File size: 1,215 Bytes
f6ca457 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
def preprocess(img, imsize=196):
    """Convert an image array into a normalized, batched tensor.

    Args:
        img: Array accepted by ``PIL.Image.fromarray`` (presumably an
            ``H x W`` or ``H x W x C`` ``uint8`` array — confirm with
            callers). It is converted to 3-channel RGB first.
        imsize: Side length of the output. The image is resized to exactly
            ``imsize x imsize``, so the aspect ratio is not preserved.
            Defaults to 196, the value that used to be hard-coded.

    Returns:
        Tensor of shape ``(1, 3, imsize, imsize)``: ``ToTensor`` output with
        each channel normalized as ``(x - mean) / std``.
    """
    image = Image.fromarray(img).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor(),
        # Per-channel mean/std; presumably the ImageNet statistics expected
        # by the pretrained feature extractor (TODO confirm). deprocess()
        # uses the same constants to undo this step.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Add the leading batch dimension.
    return transform(image).unsqueeze(dim=0)
def deprocess(image):
    """Convert a normalized image tensor back into a displayable NumPy array.

    Undoes the per-channel normalization used by ``preprocess``: the values
    are multiplied by the std, shifted by the mean and clipped to ``[0, 1]``.

    Args:
        image: Tensor of shape ``(1, 3, H, W)`` or ``(3, H, W)``. It may
            require grad and may live on any device. It is not modified.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, 3)`` with values in ``[0, 1]``
        (``float64``, because the data is combined with float64 constants).
    """
    # detach() drops the autograd graph. cpu() is required because
    # Tensor.numpy() only works on CPU tensors; without it, CUDA input crashed.
    array = image.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # The arithmetic allocates a new array, so the caller's tensor is never
    # written to even though .numpy() may share its memory; no clone() needed.
    return (array * std + mean).clip(0, 1)
def get_features(image, model, layers=None):
    """Run ``image`` through ``model``'s child modules and collect activations.

    Args:
        image: Input tensor fed to the first child module.
        model: Module whose direct children (``model._modules``) are applied
            one after another in insertion order. Presumably a VGG-style
            ``features`` Sequential — TODO confirm against callers.
        layers: Optional mapping from child-module name (e.g. ``'0'``) to
            the key under which that child's output is stored. Defaults to
            ``{'0': 'layer_1', '5': 'layer_2', '10': 'layer_3',
            '19': 'layer_4', '28': 'layer_5'}``.

    Returns:
        dict mapping each requested key to the output of the matching child.
        Names that do not occur in ``model`` are simply absent.

    Note:
        The stored tensors are references, not copies. All children run, even
        those after the last requested one. A later child that modifies its
        input in place (e.g. ``ReLU(inplace=True)``) therefore changes the
        stored tensor. This is preserved deliberately: adding an early
        ``break`` or a ``clone()`` would change the returned values.
    """
    if layers is None:
        layers = {
            '0': 'layer_1',
            '5': 'layer_2',
            '10': 'layer_3',
            '19': 'layer_4',
            '28': 'layer_5',
        }
    features = {}
    x = image
    # _modules is used instead of named_children() on purpose: the latter
    # skips repeated module instances, which would change the forward pass.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features
def gram_matrix(image):
    """Return the (unnormalized) Gram matrix of a feature-map tensor.

    Each channel is flattened over its spatial positions, and the inner
    products between every pair of channels are computed.

    Args:
        image: Tensor of shape ``(B, C, H, W)``. It does not need to be
            contiguous (e.g. the output of ``permute``).

    Returns:
        ``(C, C)`` tensor when ``B == 1``; otherwise a ``(B, C, C)`` tensor
        holding one Gram matrix per batch item.
    """
    b, c, h, w = image.size()
    # reshape() instead of view(): view() raises on non-contiguous input,
    # whereas reshape() copies only when the memory layout requires it.
    # Keeping the batch axis separate avoids mixing items when B > 1.
    features = image.reshape(b, c, h * w)
    if b == 1:
        flat = features[0]
        return torch.mm(flat, flat.t())
    return torch.bmm(features, features.transpose(1, 2))
|