from .basic_layer import *  # expected to provide l2normalize and shared layers
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
'''
Margin code is borrowed from https://github.com/MuggleWang/CosFace_pytorch and https://github.com/wujiyang/Face_Pytorch.
'''
def cosine_sim(x1, x2, dim=1, eps=1e-8):
    """Pairwise cosine similarity between the rows of x1 (N, D) and x2 (C, D).

    Returns an (N, C) matrix: inner products divided by the outer product
    of the row norms, clamped to avoid division by zero.
    """
    ip = torch.mm(x1, x2.t())
    w1 = torch.norm(x1, 2, dim)
    w2 = torch.norm(x2, 2, dim)
    return ip / torch.ger(w1, w2).clamp(min=eps)
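# Note (editorial): for unit-length rows this matches
# F.linear(F.normalize(x1), F.normalize(x2)) up to the eps handling,
# which is the form used in ArcMarginProduct.forward below.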
class MarginCosineProduct(nn.Module):
    r"""Implementation of the large-margin cosine distance (CosFace).

    Args:
        in_features: size of each input sample
        out_features: number of classes (size of each output sample)
        s: scale applied to the cosine logits
        m: additive cosine margin
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        super(MarginCosineProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # Class-center weights, shape (out_features, in_features).
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
    def forward(self, input, label):
        # Cosine similarity between embeddings (N, in_features) and the
        # class centers -> (N, out_features) logits before scaling.
        cosine = cosine_sim(input, self.weight)
        # Convert integer labels to one-hot so the margin is subtracted
        # only at the target class:
        # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1.0)
        # CosFace logits: s * (cos(theta) - m) at the target class.
        output = self.s * (cosine - one_hot * self.m)
        return output
    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'in_features={self.in_features}, '
                f'out_features={self.out_features}, '
                f's={self.s}, m={self.m})')
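# Usage sketch (editorial, not from the original training code): the margin
# heads return scaled logits, so they pair with plain cross-entropy.
# Shapes are illustrative.
#
#   head = MarginCosineProduct(in_features=512, out_features=7)
#   feats = torch.randn(4, 512)            # batch of 4 embeddings
#   labels = torch.randint(0, 7, (4,))     # target class indices
#   logits = head(feats, labels)           # (4, 7), margin at target class
#   loss = nn.CrossEntropyLoss()(logits, labels)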
class ArcMarginProduct(nn.Module):
    """Additive angular margin (ArcFace): target logits become s * cos(theta + m)."""
    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Keep cos(theta + m) monotonically decreasing for theta in [0°, 180°]:
        # past theta = pi - m, fall back to cos(theta) - mm instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
    def forward(self, x, label):
        # cos(theta) between normalized embeddings and class centers.
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # cos(theta + m); the clamp guards the sqrt against small negative
        # values introduced by floating-point error.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        # Apply the margin only at the target class, then scale.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s
        return output
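# Note (editorial): phi implements the angle-addition identity
#   cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m).
# With easy_margin, the fallback to plain cos(theta) applies whenever
# cos(theta) <= 0; otherwise it applies past theta = pi - m, where
# cos(theta + m) would stop being monotonic in theta.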
class MultiMarginProduct(nn.Module):
    """Combined margin: ArcFace-style angular margin m1 plus CosFace-style
    additive cosine margin m2."""
    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False):
        super(MultiMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m1 = math.cos(m1)
        self.sin_m1 = math.sin(m1)
        # Keep cos(theta + m1) monotonically decreasing for theta in [0°, 180°].
        self.th = math.cos(math.pi - m1)
        self.mm = math.sin(math.pi - m1) * m1
    def forward(self, x, label):
        # cos(theta) between normalized embeddings and class centers.
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # cos(theta + m1); the clamp guards the sqrt against float error.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m1 - sine * self.sin_m1
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # additive angular margin m1
        output = output - one_hot * self.m2                    # additive cosine margin m2
        output = output * self.s
        return output
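# Summary of the three margin heads (editorial): at the target class,
#   MarginCosineProduct : s * (cos(theta) - m)
#   ArcMarginProduct    : s * cos(theta + m)
#   MultiMarginProduct  : s * (cos(theta + m1) - m2)
# Non-target classes always receive plain s * cos(theta).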
class CPDis(nn.Module):
    """PatchGAN discriminator: maps an image to a grid of real/fake patch scores."""
    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
super(CPDis, self).__init__()
layers = []
if norm == 'SN':
layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
else:
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = conv_dim
        # Stride-2 blocks: halve the spatial size, double the channels.
        for i in range(1, repeat_num):
if norm == 'SN':
layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
else:
layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = curr_dim * 2
        # Final stride-1 block widens the channels once more before the patch head.
if norm == 'SN':
layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
else:
layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = curr_dim * 2
self.main = nn.Sequential(*layers)
if norm == 'SN':
self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
else:
self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
    def forward(self, x):
        # Drop a leading singleton dimension if a 5-D batch slips through.
        if x.ndim == 5:
            x = x.squeeze(0)
        assert x.ndim == 4, x.ndim
        h = self.main(x)
        out_makeup = self.conv1(h)  # (N, 1, H', W') patch score map
        return out_makeup
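# Shape trace for the defaults (editorial): a (N, 3, 256, 256) input passes
# through three stride-2 convs (256 -> 128 -> 64 -> 32), one stride-1 conv
# (32 -> 31), then conv1 (31 -> 30), giving a (N, 1, 30, 30) patch map.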
class CPDis_cls(nn.Module):
    """PatchGAN discriminator with an auxiliary large-margin classifier head."""
    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
super(CPDis_cls, self).__init__()
layers = []
if norm == 'SN':
layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
else:
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = conv_dim
for i in range(1, repeat_num):
if norm == 'SN':
layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
else:
layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = curr_dim * 2
        # Final stride-1 block widens the channels once more before the patch head.
if norm == 'SN':
layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
else:
layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = curr_dim * 2
self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
        # Classifier head. Built unconditionally, because forward() always
        # uses it regardless of the norm setting.
        self.classifier_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier_conv = nn.Conv2d(curr_dim, curr_dim, 1, 1, 0)
        self.classifier = MarginCosineProduct(curr_dim, 7)  # alternative: ArcMarginProduct(curr_dim, 7)
        print("Using Large Margin Cosine Loss.")
    def forward(self, x, label):
        # Drop a leading singleton dimension if a 5-D batch slips through.
        if x.ndim == 5:
            x = x.squeeze(0)
        assert x.ndim == 4, x.ndim
        h = self.main(x)                            # (N, 512, 31, 31) for the defaults
        out_cls = self.classifier_pool(h)           # (N, 512, 1, 1)
        out_cls = self.classifier_conv(out_cls)
        out_cls = torch.squeeze(out_cls, -1)
        out_cls = torch.squeeze(out_cls, -1)        # (N, 512)
        out_cls = self.classifier(out_cls, label)   # (N, 7) margin logits
        out_makeup = self.conv1(h)                  # (N, 1, 30, 30) patch score map
        return out_makeup, out_cls
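# Usage sketch (editorial): the two outputs feed separate losses, e.g. a GAN
# loss on the patch map and cross-entropy on the margin logits.
#
#   d = CPDis_cls()
#   x = torch.randn(2, 3, 256, 256)
#   labels = torch.randint(0, 7, (2,))
#   patch_scores, cls_logits = d(x, labels)   # (2, 1, 30, 30), (2, 7)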
class SpectralNorm(object):
    """Spectral normalization via power iteration, installed as a forward
    pre-hook (after Miyato et al., "Spectral Normalization for GANs")."""
    def __init__(self):
        self.name = "weight"
        self.power_iterations = 1
    def compute_weight(self, module):
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")
        height = w.data.shape[0]
        # Power iteration: estimate the leading left/right singular vectors
        # of the weight reshaped to (height, -1).
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
        # sigma approximates the largest singular value; divide it out.
        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)
    @staticmethod
    def apply(module):
        name = "weight"
        fn = SpectralNorm()
        try:
            # Already wrapped: reuse the existing u / v / w_bar parameters.
            u = getattr(module, name + "_u")
            v = getattr(module, name + "_v")
            w = getattr(module, name + "_bar")
        except AttributeError:
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            w_bar = Parameter(w.data)
            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)
            # Remove the raw weight from the parameter list and expose the
            # normalized tensor in its place.
            del module._parameters[name]
            setattr(module, name, fn.compute_weight(module))
        # Recompute the normalized weight before every forward().
        module.register_forward_pre_hook(fn)
        return fn
    def remove(self, module):
        # Restore a plain weight parameter and drop the u / v / w_bar buffers.
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))
    def __call__(self, module, inputs):
        # Forward pre-hook: refresh the normalized weight on each call.
        setattr(module, self.name, self.compute_weight(module))
def spectral_norm(module):
    """Wrap module so its weight is spectrally normalized on every forward."""
    SpectralNorm.apply(module)
    return module
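# Note (editorial): this hand-rolled wrapper plays the same role as
# torch.nn.utils.spectral_norm; the built-in stores different parameter
# names (weight_orig vs. weight_bar), so checkpoints are not interchangeable.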
def remove_spectral_norm(module):
    """Undo spectral_norm(), restoring a plain weight parameter on module."""
    name = 'weight'
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module
    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
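if __name__ == "__main__":
    # Smoke test (editorial sketch). Because of the relative import at the
    # top, run as a module from the repo root, e.g. `python -m models.c2pDis`;
    # it also assumes basic_layer provides l2normalize.
    x = torch.randn(2, 3, 256, 256)
    d = CPDis()
    print(d(x).shape)  # expected: torch.Size([2, 1, 30, 30])
    d_cls = CPDis_cls()
    labels = torch.randint(0, 7, (2,))
    patch_scores, cls_logits = d_cls(x, labels)
    print(patch_scores.shape, cls_logits.shape)  # (2, 1, 30, 30) and (2, 7)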