import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
from .c2pGen import *
from .p2cGen import *
from .c2pDis import *


class Identity(nn.Module):
    def forward(self, x):
        return x
def get_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters and do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x): return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
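
# Usage sketch (illustrative only; this helper is not called anywhere in this
# module). get_norm_layer returns a constructor, so it is first applied to a
# channel count and only then to a tensor. The shapes below are arbitrary.
def _example_get_norm_layer():
    norm_layer = get_norm_layer(norm_type='instance')
    layer = norm_layer(64)               # InstanceNorm2d over 64 channels
    feat = torch.randn(2, 64, 32, 32)    # N x C x H x W feature map
    return layer(feat)                   # normalized features, same shape as input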
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of the learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For the other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
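
# Usage sketch (illustrative; the option object is a hand-built stand-in for the
# real BaseOptions instance, carrying only the fields get_scheduler reads).
# With epoch_count=1, n_epochs=100 and n_epochs_decay=100, the 'linear' policy
# keeps the LR multiplier at 1.0 for the first 100 epochs and then decays it
# linearly toward zero over the next 100 epochs.
def _example_get_scheduler():
    import argparse
    opt = argparse.Namespace(lr_policy='linear', epoch_count=1,
                             n_epochs=100, n_epochs_decay=100)
    model = nn.Conv2d(3, 3, kernel_size=3)
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    scheduler = get_scheduler(optimizer, opt)
    for _ in range(5):          # one scheduler step per training epoch
        optimizer.step()
        scheduler.step()
    return optimizer.param_groups[0]['lr']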
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
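
# Usage sketch (illustrative): init_weights walks every submodule via net.apply,
# so the Conv / BatchNorm2d layers of this throwaway model are reinitialized in
# place while layers without weights (e.g. ReLU) are skipped.
def _example_init_weights():
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Conv2d(8, 3, kernel_size=3, padding=1),
    )
    init_weights(model, init_type='kaiming', init_gain=0.02)
    return model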
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    gpu_ids = [0]  # hard-coded here so the DataParallel wrapper is always applied; the model itself stays on CPU below
    if len(gpu_ids) > 0:
        # assert(torch.cuda.is_available())  # uncomment this when running on GPU
        net.to(torch.device("cpu"))  # change "cpu" to gpu_ids[0] when running on GPU
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
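
# Usage sketch (illustrative): in this CPU-only variant, init_net ignores the
# gpu_ids argument, keeps the model on CPU, wraps it in DataParallel (which
# falls back to plain forwarding when CUDA is unavailable), and initializes it.
def _example_init_net():
    model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
    model = init_net(model, init_type='normal', init_gain=0.02, gpu_ids=[])
    return model    # DataParallel-wrapped, freshly initialized module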
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: c2pGen | p2cGen | antialias
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- whether to use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'c2pGen':  # style_dim mlp_dim
        net = C2PGen(input_nc, output_nc, ngf, 2, 4, 256, 256, activ='relu', pad_type='reflect')
        # print('c2pgen resblock is 8')
    elif netG == 'p2cGen':
        net = P2CGen(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
    elif netG == 'antialias':
        net = AliasNet(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
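
# Usage sketch (illustrative; the 3-channel / 64-filter values below are typical
# RGB defaults, not values read from this repository's option files). The forward
# call is omitted because C2PGen's forward signature is defined in c2pGen.py.
def _example_define_G():
    netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='c2pGen')
    return netG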
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: CPDis | CPDis_cls
        n_layers_D (int)   -- the number of conv layers in the discriminator; kept for interface compatibility, unused by the architectures below
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    if netD == 'CPDis':
        net = CPDis(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    elif netD == 'CPDis_cls':
        net = CPDis_cls(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)
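
# Usage sketch (illustrative; the input_nc/ndf values are typical defaults). Note
# that both discriminator branches above hard-code their own sizes (image_size=256,
# conv_dim=64), so input_nc and ndf only document the intended interface here.
def _example_define_D():
    netD = define_D(input_nc=3, ndf=64, netD='CPDis')
    return netD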
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create a target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image

        Note: Do not use sigmoid as the last layer of the Discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor)   -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with the ground truth label, with the same size as the input
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given the Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss
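
# Usage sketch (illustrative): one GANLoss object serves both real and fake
# targets; the prediction tensor below is a stand-in for a discriminator output.
def _example_gan_loss():
    criterion = GANLoss('lsgan')
    prediction = torch.randn(4, 1, 30, 30)       # arbitrary patch-style scores
    loss_real = criterion(prediction, True)      # compared against real_label (1.0)
    loss_fake = criterion(prediction, False)     # compared against fake_label (0.0)
    return loss_real + loss_fake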