import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModulationConvBlock(nn.Module):
    """StyleGAN2-style modulated convolution with weight demodulation.

    Note: the ``stride``, ``padding``, ``norm``, ``activation`` and ``pad_type``
    arguments are accepted for interface compatibility but ignored; the block
    always uses stride 1, 'same' padding and a scaled LeakyReLU activation.
    """

    def __init__(self, input_dim, output_dim, kernel_size, stride=1, padding=0,
                 norm='none', activation='relu', pad_type='zero'):
        super(ModulationConvBlock, self).__init__()
        self.in_c = input_dim
        self.out_c = output_dim
        self.ksize = kernel_size
        self.stride = 1
        self.padding = kernel_size // 2
        self.eps = 1e-8

        # Equalized learning rate: the weight is sampled from N(0, 1) and scaled
        # at runtime by 1/sqrt(fan_in). The parameter is stored as
        # (out_c, in_c, k, k) but reinterpreted in forward as (k, k, in_c, out_c);
        # since it is learned from random init, only the element count matters.
        weight_shape = (output_dim, input_dim, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * input_dim
        wscale = 1.0 / np.sqrt(fan_in)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.wscale = wscale
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.activate_scale = np.sqrt(2.0)

    def forward(self, x, code):
        batch, in_channel, height, width = x.shape
        weight = self.weight * self.wscale
        # Modulation: scale the shared weight per sample with the style code
        # (one scale per input channel).
        _weight = weight.view(1, self.ksize, self.ksize, self.in_c, self.out_c)
        _weight = _weight * code.view(batch, 1, 1, self.in_c, 1)
        # Demodulation: normalize each per-sample output filter to unit norm.
        _weight_norm = torch.sqrt(torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
        _weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
        # Fused modulation: fold the batch dimension into the channel dimension
        # and run a single grouped convolution (one group per sample).
        x = x.view(1, batch * self.in_c, x.shape[2], x.shape[3])
        weight = _weight.permute(1, 2, 3, 0, 4).reshape(
            self.ksize, self.ksize, self.in_c, batch * self.out_c)
        # No transposed convolution here; just reorder to the
        # (batch * out_c, in_c, k, k) layout expected by conv2d.
        weight = weight.permute(3, 2, 0, 1)
        x = F.conv2d(x,
                     weight=weight,
                     bias=None,
                     stride=self.stride,
                     padding=self.padding,
                     groups=batch)
        # Unfold the batch dimension again (the fused path is always taken).
        x = x.view(batch, self.out_c, height, width)
        x = x + self.bias.view(1, -1, 1, 1)
        x = self.activate(x) * self.activate_scale
        return x


class AliasConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0,
                 norm='none', activation='relu', pad_type='zero'):
        super(AliasConvBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # initialize convolution
        # Note: spectral normalization is not applied here even when norm == 'sn'
        # (unlike linearBlock below); padding is handled by self.pad above.
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x
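
# Usage sketch for ModulationConvBlock (illustrative only; the dimensions below
# are assumptions, not values taken from this repository). The style code
# carries one scale per input channel, and the output keeps the spatial size
# thanks to the 'same' padding chosen in __init__:
#
#     block = ModulationConvBlock(input_dim=64, output_dim=128, kernel_size=3)
#     x = torch.randn(4, 64, 32, 32)    # (batch, in_c, H, W)
#     code = torch.randn(4, 64)         # (batch, in_c) per-sample modulation
#     y = block(x, code)                # -> (4, 128, 32, 32)
#
# Internally the batch is folded into the channel axis and a single grouped
# conv2d (groups=batch) applies a different, demodulated filter to each sample.
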
class AliasResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='in',
                 activation='relu', pad_type='zero'):
        super(AliasResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [AliasResBlock(dim,
                                         norm=norm,
                                         activation=activation,
                                         pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)


class AliasResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(AliasResBlock, self).__init__()
        model = []
        model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out


##################################################################################
# Sequential Models
##################################################################################
class ResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        self.model = []
        self.model += [linearBlock(input_dim, input_dim, norm=norm, activation=activ)]
        self.model += [linearBlock(input_dim, dim, norm=norm, activation=activ)]
        for i in range(n_blk - 2):
            self.model += [linearBlock(dim, dim, norm=norm, activation=activ)]
        self.model += [linearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
        self.model = nn.Sequential(*self.model)

    def forward(self, style0, style1=None, a=0):
        # Interpolate between two style codes in the intermediate feature space.
        # The slicing below assumes n_blk == 3, i.e. model[0:3] is the trunk and
        # model[3] is the activation-free output layer. If no second style is
        # given, fall back to style0 so the interpolation is a no-op.
        if style1 is None:
            style1 = style0
        return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) +
                             a * self.model[0:3](style1.view(style1.size(0), -1)))


##################################################################################
# Basic Blocks
##################################################################################
class ResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()
        model = []
        model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out
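
# Usage sketch for MLP (illustrative only; the dimensions are assumptions).
# With n_blk=3 the module has four linear blocks, matching the
# model[0:3] / model[3] slicing in forward. a=0 maps style0 alone; 0 < a < 1
# blends the two codes in the intermediate feature space before the final layer:
#
#     mlp = MLP(input_dim=8, output_dim=256, dim=256, n_blk=3)
#     s0 = torch.randn(4, 8, 1, 1)      # style codes are flattened inside forward
#     s1 = torch.randn(4, 8, 1, 1)
#     out = mlp(s0, s1, a=0.5)          # -> (4, 256)
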
class ConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0,
                 norm='none', activation='relu', pad_type='zero'):
        super(ConvBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # initialize convolution
        # Note: as in AliasConvBlock, spectral normalization is not applied here
        # even when norm == 'sn'; padding is handled by self.pad above.
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x


class linearBlock(nn.Module):
    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(linearBlock, self).__init__()
        use_bias = True
        # initialize fully connected layer
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
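
# Usage sketch for ConvBlock / ResBlocks (illustrative only; dimensions are
# assumptions). Padding is handled by the explicit pad layer, so the spatial
# size is preserved whenever padding == kernel_size // 2 and stride == 1:
#
#     conv = ConvBlock(3, 64, 7, 1, 3, norm='in', activation='relu', pad_type='reflect')
#     res = ResBlocks(num_blocks=4, dim=64, norm='in', activation='relu', pad_type='reflect')
#     x = torch.randn(2, 3, 128, 128)
#     y = res(conv(x))                  # -> (2, 64, 128, 128)
#
# linearBlock(..., norm='sn') wraps its nn.Linear in the SpectralNorm module
# defined at the bottom of this file; the conv blocks above do not apply it.
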
##################################################################################
# Normalization layers
##################################################################################
class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, \
            "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # Apply instance norm
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])

        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)

        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'


class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks"
    by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida and the
    Pytorch implementation
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
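

# Minimal smoke test (a sketch under assumed shapes; not part of the original
# module). AdaptiveInstanceNorm2d expects its weight/bias to be assigned
# externally, with one value per (sample, channel) pair, before forward is
# called; SpectralNorm re-estimates the top singular value on every forward.
if __name__ == "__main__":
    b, c = 2, 16
    adain = AdaptiveInstanceNorm2d(c)
    adain.weight = torch.ones(b * c)        # scale, normally predicted elsewhere
    adain.bias = torch.zeros(b * c)         # shift, normally predicted elsewhere
    feat = torch.randn(b, c, 8, 8)
    print(adain(feat).shape)                # torch.Size([2, 16, 8, 8])

    sn_fc = SpectralNorm(nn.Linear(16, 8))
    print(sn_fc(torch.randn(4, 16)).shape)  # torch.Size([4, 8])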