# Spaces: Running  (page-status banner captured with the code during export; not part of the module)
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
    """Convolution, instance normalisation and an optional ReLU, with optional dropout.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        downsample: When True the block uses ``nn.Conv2d`` with reflect
            padding; when False it uses ``nn.ConvTranspose2d`` (which is
            built without a ``padding_mode`` argument).
        use_act: When True a ``nn.ReLU(inplace=True)`` follows the
            normalisation; otherwise ``nn.Identity`` is used in its place.
        use_dropout: When True the block is followed by ``nn.Dropout(p=0.5)``.
        **kwargs: Forwarded unchanged to the selected convolution layer
            (this file passes ``kernel_size``, ``stride``, ``padding`` and,
            for the transposed variant, ``output_padding``).
    """

    def __init__(self, in_channels: int, out_channels: int, downsample: bool = True, use_act: bool = True,
                 use_dropout: bool = False, **kwargs):
        super().__init__()

        # Only the selected convolution is ever constructed, so kwargs that are
        # valid for just one of the two layer types (e.g. output_padding) are
        # fine as long as they match the chosen branch.
        if downsample:
            conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding_mode="reflect", **kwargs)
        else:
            conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, **kwargs)

        self.conv_block = nn.Sequential(
            conv,
            nn.InstanceNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

        if use_dropout:
            # NOTE: the existing Sequential is wrapped rather than extended, so
            # with dropout enabled parameter names gain one level of nesting
            # (conv_block.0.0.weight instead of conv_block.0.weight). Kept
            # as-is so existing checkpoints keep loading.
            self.conv_block = nn.Sequential(self.conv_block, nn.Dropout(p=0.5))

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.conv_block(x)
class ResidualBlock(nn.Module):
    """Two stacked 3x3 ``ConvBlock`` layers with an identity skip connection.

    Both inner blocks keep the channel count at ``features`` and use
    ``kernel_size=3`` with ``padding=1`` (no stride is passed), so their
    output can be summed with the input. The second block is built with
    ``use_act=False``, i.e. no ReLU is applied before the sum.

    Args:
        features: Number of input and output channels.
    """

    def __init__(self, features: int):
        super().__init__()
        self.residual_block = nn.Sequential(
            ConvBlock(in_channels=features, out_channels=features, kernel_size=3, padding=1),
            ConvBlock(in_channels=features, out_channels=features, kernel_size=3, padding=1, use_act=False),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the two-layer residual branch."""
        return x + self.residual_block(x)
class CycleGenerator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout, in forward order:
      1. ``base``: 7x7 reflect-padded convolution to ``latent_dim`` channels,
         followed by ReLU.
      2. ``down_blocks``: two stride-2 ``ConvBlock`` layers taking the channel
         count to ``latent_dim * 2`` and then ``latent_dim * 4``.
      3. ``residual_blocks``: ``num_residuals`` ``ResidualBlock`` layers at
         ``latent_dim * 4`` channels.
      4. ``up_blocks``: two stride-2 transposed-convolution ``ConvBlock``
         layers bringing the channel count back to ``latent_dim``.
      5. ``head``: 7x7 reflect-padded convolution to ``img_channels``,
         followed by ``tanh``.

    Args:
        img_channels: Channel count of the input and output images.
        latent_dim: Channel count after the first convolution.
        num_residuals: Number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels: int = 3, latent_dim: int = 64, num_residuals: int = 9):
        super().__init__()

        self.base = nn.Sequential(
            nn.Conv2d(in_channels=img_channels, out_channels=latent_dim, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(in_channels=latent_dim, out_channels=latent_dim * 2, kernel_size=3, stride=2, padding=1),
                ConvBlock(in_channels=latent_dim * 2, out_channels=latent_dim * 4, kernel_size=3, stride=2,
                          padding=1),
            ]
        )

        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(features=latent_dim * 4) for _ in range(num_residuals)]
        )

        # output_padding=1 makes each transposed convolution exactly double
        # the spatial size (stride 2, kernel 3, padding 1).
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(in_channels=latent_dim * 4, out_channels=latent_dim * 2, kernel_size=3, stride=2,
                          padding=1, output_padding=1, downsample=False),
                ConvBlock(in_channels=latent_dim * 2, out_channels=latent_dim, kernel_size=3, stride=2,
                          padding=1, output_padding=1, downsample=False),
            ]
        )

        self.head = nn.Conv2d(in_channels=latent_dim, out_channels=img_channels, kernel_size=7, stride=1,
                              padding=3, padding_mode="reflect")

    def forward(self, x):
        """Run the generator on ``x`` and return the ``tanh``-squashed output.

        Assumes ``x`` is an image batch shaped (N, img_channels, H, W) —
        TODO confirm against callers. With the strides used here the output
        has the same H and W as the input when both are multiples of 4.
        """
        x = self.base(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        x = self.head(x)
        return torch.tanh(x)