from .basic_layer import *  # expected to provide l2normalize, used by SpectralNorm below
import math

# explicit imports (these may also arrive via the wildcard import above)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
#from pytorch_metric_learning import losses

'''
Margin code is borrowed from https://github.com/MuggleWang/CosFace_pytorch
and https://github.com/wujiyang/Face_Pytorch.
'''

def cosine_sim(x1, x2, dim=1, eps=1e-8):
    """Pairwise cosine similarity between the rows of x1 (N, D) and x2 (C, D)."""
    ip = torch.mm(x1, x2.t())    # raw inner products, shape (N, C)
    w1 = torch.norm(x1, 2, dim)  # L2 norm of each row of x1
    w2 = torch.norm(x2, 2, dim)  # L2 norm of each row of x2
    return ip / torch.ger(w1, w2).clamp(min=eps)
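
# Illustrative check (not in the original file): cosine_sim is numerically
# equivalent to F.linear on L2-normalized inputs, the form used in
# ArcMarginProduct.forward below. Shapes here are hypothetical.
def _check_cosine_sim():
    x = torch.randn(4, 512)   # hypothetical embeddings
    w = torch.randn(7, 512)   # hypothetical class weights
    a = cosine_sim(x, w)
    b = F.linear(F.normalize(x), F.normalize(w))
    assert torch.allclose(a, b, atol=1e-6)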

class MarginCosineProduct(nn.Module):
    r"""Implementation of the large-margin cosine distance (CosFace).

    Args:
        in_features: size of each input sample
        out_features: size of each output sample (number of classes)
        s: scale applied to the cosine logits
        m: additive cosine margin
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        super(MarginCosineProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # one weight row per class, shape (out_features, in_features)
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        #stdv = 1. / math.sqrt(self.weight.size(1))
        #self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, label):
        # cosine of the angle between each embedding and each class weight
        cosine = cosine_sim(input, self.weight)
        # cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # convert label to one-hot
        # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1.0)
        # subtract the margin m from the target-class cosine only, then scale:
        # out_ij = s * (cos(theta_ij) - m * 1[j == label_i])
        output = self.s * (cosine - one_hot * self.m)
        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', s=' + str(self.s) \
               + ', m=' + str(self.m) + ')'
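
# Usage sketch (illustrative, not part of the original training code): the
# head returns scaled margin logits that plug directly into
# nn.CrossEntropyLoss, the standard CosFace training recipe. Batch size,
# feature dim, and class count below are hypothetical.
def _demo_margin_cosine_product():
    head = MarginCosineProduct(in_features=512, out_features=7)
    feats = torch.randn(4, 512)
    labels = torch.randint(0, 7, (4,))
    logits = head(feats, labels)               # (4, 7), already scaled by s
    return nn.CrossEntropyLoss()(logits, labels)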

class ArcMarginProduct(nn.Module):
    """Additive angular margin (ArcFace) classification head."""
    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # thresholds used to keep cos(theta + m) monotonically decreasing
        # for theta in [0, pi]
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label):
        # cos(theta)
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # cos(theta + m) via the angle-addition identity; the clamp guards
        # against small negative values from floating-point error in the sqrt
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
        #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        # apply the angular margin to the target class only, then scale
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s
        return output
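
# Numerical sketch (illustrative): away from the theta > pi - m fallback
# branch, the phi computed in forward() is exactly cos(theta + m) by the
# angle-addition identity. Angles below are hypothetical.
def _check_arc_margin_identity(m=0.50):
    theta = torch.rand(100) * (math.pi - m - 1e-3)
    phi = torch.cos(theta) * math.cos(m) - torch.sin(theta) * math.sin(m)
    assert torch.allclose(phi, torch.cos(theta + m), atol=1e-6)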

class MultiMarginProduct(nn.Module):
    """Combined additive angular (m1) and additive cosine (m2) margin head."""
    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False):
        super(MultiMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m1 = math.cos(m1)
        self.sin_m1 = math.sin(m1)
        # thresholds used to keep cos(theta + m1) monotonically decreasing
        # for theta in [0, pi]
        self.th = math.cos(math.pi - m1)
        self.mm = math.sin(math.pi - m1) * m1

    def forward(self, x, label):
        # cos(theta)
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # cos(theta + m1); the clamp guards against floating-point error in the sqrt
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m1 - sine * self.sin_m1
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # additive angular margin
        output = output - one_hot * self.m2                    # additive cosine margin
        output = output * self.s
        return output
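
# Usage sketch (illustrative): MultiMarginProduct stacks both margins, so the
# target-class logit is s * (cos(theta + m1) - m2) while non-target logits
# remain s * cos(theta). Shapes below are hypothetical.
def _demo_multi_margin_product():
    head = MultiMarginProduct(in_feature=128, out_feature=10)
    x = torch.randn(4, 128)
    labels = torch.randint(0, 10, (4,))
    logits = head(x, labels)                   # (4, 10)
    return nn.CrossEntropyLoss()(logits, labels)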

class CPDis(nn.Module):
    """PatchGAN discriminator."""
    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
        super(CPDis, self).__init__()
        layers = []
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        # downsampling blocks: halve the spatial size, double the channels
        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        # final 1-channel patch-score head
        if norm == 'SN':
            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

    def forward(self, x):
        if x.ndim == 5:
            x = x.squeeze(0)
        assert x.ndim == 4, x.ndim
        h = self.main(x)
        # out_real = self.conv1(h)
        out_makeup = self.conv1(h)  # one real/fake score per patch
        # return out_real.squeeze(), out_makeup.squeeze()
        return out_makeup
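
# Shape sketch (illustrative): with the default settings, a 256x256 RGB image
# passes through three stride-2 convs (256 -> 32) and two stride-1 4x4 convs
# (32 -> 31 -> 30), giving one real/fake score per patch in a 30x30 map.
def _demo_cpdis():
    disc = CPDis()
    img = torch.randn(1, 3, 256, 256)          # hypothetical input image
    patches = disc(img)
    assert patches.shape == (1, 1, 30, 30)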

class CPDis_cls(nn.Module):
    """PatchGAN discriminator with an auxiliary margin-based classifier."""
    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
        super(CPDis_cls, self).__init__()
        layers = []
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        # downsampling blocks: halve the spatial size, double the channels
        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
            # auxiliary classifier head (curr_dim is 512 with the defaults);
            # forward() uses these attributes, so this class requires norm == 'SN'
            self.classifier_pool = nn.AdaptiveAvgPool2d(1)
            self.classifier_conv = nn.Conv2d(curr_dim, curr_dim, 1, 1, 0)
            self.classifier = MarginCosineProduct(curr_dim, 7)  # alternative: ArcMarginProduct(curr_dim, 7)
            print("Using Large Margin Cosine Loss.")
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

    def forward(self, x, label):
        if x.ndim == 5:
            x = x.squeeze(0)
        assert x.ndim == 4, x.ndim
        h = self.main(x)  # e.g. (1, 512, 31, 31) for a 256x256 input
        # classifier branch: global pool -> 1x1 conv -> margin logits
        out_cls = self.classifier_pool(h)
        out_cls = self.classifier_conv(out_cls)
        out_cls = torch.squeeze(out_cls, -1)
        out_cls = torch.squeeze(out_cls, -1)
        out_cls = self.classifier(out_cls, label)
        out_makeup = self.conv1(h)  # patch map, e.g. torch.Size([1, 1, 30, 30])
        # return out_real.squeeze(), out_makeup.squeeze()
        return out_makeup, out_cls
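
# Usage sketch (illustrative): the classifying variant returns the patch map
# plus margin logits over 7 classes. The classifier branch is only built when
# norm == 'SN' (the default), so forward() requires that setting.
def _demo_cpdis_cls():
    disc = CPDis_cls()
    img = torch.randn(1, 3, 256, 256)          # hypothetical input image
    label = torch.tensor([3])                  # hypothetical class id in [0, 7)
    out_makeup, out_cls = disc(img, label)
    assert out_makeup.shape == (1, 1, 30, 30) and out_cls.shape == (1, 7)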

class SpectralNorm(object):
    """Forward pre-hook that rescales a module's weight by its spectral norm,
    estimated with power iteration (Miyato et al., 2018)."""
    def __init__(self):
        self.name = "weight"
        self.power_iterations = 1

    def compute_weight(self, module):
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # one step of power iteration on the (height x width) weight matrix
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # estimate of the largest singular value of w
        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)

    @staticmethod
    def apply(module):
        name = "weight"
        fn = SpectralNorm()
        try:
            # already wrapped: reuse the existing power-iteration buffers
            u = getattr(module, name + "_u")
            v = getattr(module, name + "_v")
            w = getattr(module, name + "_bar")
        except AttributeError:
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            w_bar = Parameter(w.data)
            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)
            # remove w from the parameter list and replace it with the
            # spectrally normalized tensor
            del module._parameters[name]
            setattr(module, name, fn.compute_weight(module))
        # recompute the weight before every forward()
        module.register_forward_pre_hook(fn)
        return fn

    def remove(self, module):
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        setattr(module, self.name, self.compute_weight(module))

def spectral_norm(module):
    """Wrap `module` so its weight is spectrally normalized on every forward."""
    SpectralNorm.apply(module)
    return module


def remove_spectral_norm(module):
    """Undo spectral_norm, restoring a plain `weight` parameter."""
    name = 'weight'
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module
    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
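
# Usage sketch (illustrative): spectral_norm wraps any module with a `weight`
# parameter; the registered hook re-estimates sigma with one power iteration
# before every forward pass, and remove_spectral_norm restores a plain weight.
# Assumes l2normalize is available from basic_layer's wildcard import.
def _demo_spectral_norm():
    conv = spectral_norm(nn.Conv2d(3, 64, 3, padding=1))
    y = conv(torch.randn(1, 3, 32, 32))        # weight rescaled by 1/sigma
    conv = remove_spectral_norm(conv)          # bakes the normalized weight back in
    return y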