# neural-style-transfer / loss_functions.py
# Uploaded by georgescutelnicu ("Upload 11 files", commit f6ca457); 618 bytes.
import torch
from data_setup import gram_matrix
def content_loss(target, content):
    """Return the mean squared error between ``target`` and ``content``.

    Both arguments are tensors of broadcast-compatible shape (presumably
    feature maps taken from the same network layer -- confirm with caller).
    The result is a 0-dim tensor: the element-wise squared difference
    averaged over every element.
    """
    residual = target - content
    return residual.pow(2).mean()
def style_loss(target_features, style_grams, layer_weights=0.2):
    """Return the weighted, size-normalised Gram-matrix style loss.

    For every layer in ``target_features``:
      1. build the Gram matrix of the target activation with ``gram_matrix``
         (imported from ``data_setup``),
      2. take the mean squared difference to ``style_grams[layer]``,
      3. multiply by that layer's weight,
      4. divide by ``c * h * w`` of the target activation,
    and sum the results over all layers.

    Args:
        target_features: Mapping of layer name -> activation tensor. Each
            activation is unpacked as ``(batch, c, h, w)``, so it must be 4-D.
        style_grams: Mapping of layer name -> precomputed style Gram matrix.
            It must contain every key of ``target_features``.
        layer_weights: Either one number applied to every layer (default
            ``0.2``), or a mapping of layer name -> weight for per-layer
            weighting. A mapping needs an entry for every layer in
            ``target_features``.

    Returns:
        The summed loss as a tensor, or the int ``0`` when
        ``target_features`` is empty.

    Raises:
        KeyError: if a layer is missing from ``style_grams``, or from
            ``layer_weights`` when that is a mapping.
        ValueError: if a target activation does not have exactly 4 dims.
    """
    from collections.abc import Mapping

    # NOTE(review): the 0.2 default presumably gives equal weight to five
    # style layers (5 * 0.2 == 1) -- confirm against the training script.
    per_layer = isinstance(layer_weights, Mapping)

    loss = 0
    for layer, target_f in target_features.items():
        target_gram = gram_matrix(target_f)
        style_gram = style_grams[layer]
        # The batch dim is not part of the normalisation. Whether
        # gram_matrix handles batch > 1 depends on data_setup -- confirm.
        _, c, h, w = target_f.shape
        weight = layer_weights[layer] if per_layer else layer_weights
        layer_loss = weight * torch.mean((target_gram - style_gram) ** 2)
        # Normalise by activation size so large layers don't dominate.
        loss += layer_loss / (c * h * w)
    return loss
def total_loss(content_loss, style_loss, alpha, beta):
    """Combine content and style losses into one weighted objective.

    Returns ``alpha * content_loss + beta * style_loss``. The arguments may
    be Python numbers or tensors; the usual arithmetic/broadcasting rules of
    the operand types apply.
    """
    weighted_content = alpha * content_loss
    weighted_style = beta * style_loss
    return weighted_content + weighted_style