import torch | |
import numpy as np | |
import torchvision | |
from PIL import Image | |
from torchvision.transforms.functional import InterpolationMode | |
import torchvision.transforms as transforms | |
def padding_336(b, patch_size=336):
    """Pad a PIL image vertically so its height is a multiple of ``patch_size``.

    The vertical padding is split between top and bottom; when the total is
    odd, the extra row goes to the bottom. The width is left unchanged.

    Args:
        b: PIL image to pad.
        patch_size: Height granularity to pad up to. Defaults to 336, the
            value this helper was originally hard-coded to.

    Returns:
        The padded image. If the height is already a multiple of
        ``patch_size``, zero rows are added.
    """
    _, height = b.size
    # Exact integer ceiling: smallest multiple of patch_size that is >= height.
    target_height = -(-height // patch_size) * patch_size
    pad_total = target_height - height
    top_padding = pad_total // 2
    bottom_padding = pad_total - top_padding
    # A scalar fill is broadcast by torchvision to every band. For RGB this is
    # the same (255, 255, 255) as a three-element fill. Single-band and RGBA
    # images also work, whereas a three-element fill raises ValueError for them.
    return transforms.functional.pad(
        b, [0, top_padding, 0, bottom_padding], fill=255
    )
def HD_transform(img, hd_num=16):
    """Resize and pad an image onto a grid of at most ``hd_num`` 336x336 tiles.

    Portrait images are transposed first so that width >= height. With
    ``ratio = width / height``, the number of tile columns ``scale`` is the
    largest value for which ``scale * ceil(scale / ratio) <= hd_num``.

    The image is then processed in three steps:

    1. Resize it to width ``scale * 336``, keeping the aspect ratio. This uses
       torchvision's default interpolation.
    2. Pad it to a height that is a multiple of 336 via ``padding_336``.
    3. Transpose it back if it was transposed.

    Args:
        img: PIL image.
        hd_num: Maximum number of 336x336 tiles; must be at least 1.

    Returns:
        The resized and padded PIL image. Both sides are multiples of 336.

    Raises:
        ValueError: If ``hd_num`` is less than 1.
    """
    if hd_num < 1:
        # With hd_num < 1 the search below ends at scale == 0, which would
        # request a zero-sized resize and fail deep inside the image library.
        raise ValueError(f"hd_num must be >= 1, got {hd_num}")
    width, height = img.size
    trans = False
    if width < height:
        # Work in landscape orientation so that ratio >= 1 below.
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    # Grow the column count until columns * rows would exceed hd_num, then
    # step back to the last count that fit.
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    # For extreme aspect ratios (ratio > new_w) the scaled height would
    # truncate to 0 and resize would fail. Keep at least one pixel. The tile
    # count stays within hd_num, since one row never exceeds ceil(scale / ratio).
    new_h = max(1, int(new_w / ratio))
    img = transforms.functional.resize(img, [new_h, new_w])
    img = padding_336(img)
    if trans:
        img = img.transpose(Image.TRANSPOSE)
    return img