neural-style-transfer / data_setup.py
georgescutelnicu's picture
Upload 11 files
f6ca457
raw
history blame contribute delete
No virus
1.22 kB
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
def preprocess(img, imsize=196):
    """Convert a raw image array into a normalized, batched tensor.

    Args:
        img: Image data accepted by ``PIL.Image.fromarray`` (presumably an
            H x W x C ``uint8`` NumPy array -- confirm against the caller).
        imsize: Side length, in pixels, of the square the image is resized
            to. Defaults to 196, the value that used to be hard-coded here.

    Returns:
        A tensor with a leading batch dimension of size 1 holding the
        RGB image resized to ``imsize`` x ``imsize`` and normalized per
        channel with the mean/std listed below.
    """
    # Force three channels so grayscale / RGBA inputs match the 3-element
    # mean and std passed to Normalize.
    image = Image.fromarray(img).convert('RGB')

    transform = transforms.Compose([
        # A (height, width) tuple resizes to exactly that size, so the
        # original aspect ratio is not preserved.
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor(),
        # Must stay in sync with the constants used by `deprocess`.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Add the batch dimension expected by convolutional models.
    return transform(image).unsqueeze(dim=0)
def deprocess(image):
    """Turn a normalized, batched image tensor back into a displayable array.

    Reverses the per-channel normalization (multiply by std, add mean),
    moves channels last and clamps the result to ``[0, 1]``.

    Args:
        image: Tensor whose first dimension is a batch dimension of size 1
            followed by (channels, height, width). The input is cloned, so
            the caller's tensor is not modified.

    Returns:
        A NumPy array of shape (height, width, channels) with values
        clipped to the range ``[0, 1]``.
    """
    # Drop the batch dimension and move channels last: (C, H, W) -> (H, W, C).
    image = image.clone().squeeze(0).permute(1, 2, 0)

    # `.numpy()` only works on CPU tensors without grad tracking; `.cpu()`
    # is a no-op for tensors already on the CPU and makes GPU tensors work.
    image = image.detach().cpu().numpy()

    # Undo Normalize(mean, std): x * std + mean, broadcast over the last axis.
    image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])

    return image.clip(0, 1)
def get_features(image, model, layers=None):
    """Run ``image`` through ``model``'s children and collect activations.

    The direct child modules of ``model`` are applied one after another, in
    registration order, and the output of every child whose name appears in
    ``layers`` is stored under the corresponding label.

    Args:
        image: Input passed to the first child module.
        model: Module whose direct children are applied sequentially
            (presumably the ``features`` part of a VGG-style network --
            confirm against the caller).
        layers: Optional mapping from child-module name to the label used
            in the returned dict. Defaults to the five layers
            ``'0'``, ``'5'``, ``'10'``, ``'19'`` and ``'28'``.

    Returns:
        Dict mapping each label to the activation produced by that layer.
        Layers named in ``layers`` but absent from ``model`` are simply
        missing from the result.
    """
    if layers is None:
        layers = {
            '0': 'layer_1',
            '5': 'layer_2',
            '10': 'layer_3',
            '19': 'layer_4',
            '28': 'layer_5',
        }

    features = {}
    pending = set(layers)  # names still to be captured
    x = image
    # `_modules` (rather than `named_children()`) keeps repeated module
    # instances, so every registered child is applied exactly once per entry.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            pending.discard(name)
            if not pending:
                # Everything requested has been captured; running the
                # remaining children would only waste compute.
                break
    return features
def gram_matrix(image):
    """Compute the Gram matrix of a single feature map.

    Args:
        image: 4-D tensor of shape ``(1, channels, height, width)``. Only a
            batch size of 1 is supported; any other batch size makes the
            reshape below raise ``RuntimeError``.

    Returns:
        A ``(channels, channels)`` tensor holding the inner products
        between the flattened channel maps.
    """
    _, c, h, w = image.size()
    # `reshape` instead of `view`: identical for contiguous inputs, but it
    # also accepts non-contiguous tensors (e.g. after a transpose), which
    # `view` rejects.
    flat = image.reshape(c, h * w)
    return torch.mm(flat, flat.t())