# MasonCrinr's picture
# Upload 331 files
# 8026e91
import math
import torch.nn as nn
import jukebox.utils.dist_adapter as dist
from jukebox.utils.checkpoint import checkpoint
class ResConvBlock(nn.Module):
    """2-D residual block computing ``x + conv1x1(relu(conv3x3(relu(x))))``.

    Args:
        n_in: channel count of the input; the output has the same channel count.
        n_state: channel count of the hidden activation between the two convs.
    """

    def __init__(self, n_in, n_state):
        super().__init__()
        layers = [
            nn.ReLU(),
            # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
            nn.Conv2d(n_in, n_state, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # 1x1 projection back to n_in channels so the residual sum lines up.
            nn.Conv2d(n_state, n_in, kernel_size=1, stride=1, padding=0),
        ]
        # Kept as a single Sequential named `model` so parameter names
        # (model.1.*, model.3.*) stay stable for saved state dicts.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        branch = self.model(x)
        return x + branch
class Resnet(nn.Module):
    """Sequential stack of ``n_depth`` ResConvBlock layers.

    Args:
        n_in: channel count of input and output.
        n_depth: number of residual blocks in the stack.
        m_conv: hidden-width multiplier; each block's hidden channel count is
            ``int(m_conv * n_in)``.
    """

    def __init__(self, n_in, n_depth, m_conv=1.0):
        super().__init__()

        def make_block():
            return ResConvBlock(n_in, int(m_conv * n_in))

        self.model = nn.Sequential(*(make_block() for _ in range(n_depth)))

    def forward(self, x):
        out = self.model(x)
        return out
class ResConv1DBlock(nn.Module):
    """1-D residual block computing ``x + res_scale * conv1x1(relu(conv3(relu(x))))``.

    Args:
        n_in: channel count of the input; the output has the same channel count.
        n_state: channel count of the hidden activation between the two convs.
        dilation: dilation of the kernel-3 conv. Padding is set equal to it,
            so with stride 1 the sequence length is preserved.
        zero_out: if truthy, the final 1x1 conv's weight and bias start at
            zero, so the block initially returns ``x`` unchanged.
        res_scale: multiplier applied to the residual branch before the add.
    """

    def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0):
        super().__init__()
        dilated_conv = nn.Conv1d(
            n_in, n_state, kernel_size=3, stride=1, padding=dilation, dilation=dilation
        )
        projection = nn.Conv1d(n_state, n_in, kernel_size=1, stride=1, padding=0)
        if zero_out:
            nn.init.zeros_(projection.weight)
            nn.init.zeros_(projection.bias)
        # Layer order/indices (model.1 = dilated conv, model.3 = projection)
        # match the parameter names used by saved state dicts.
        self.model = nn.Sequential(nn.ReLU(), dilated_conv, nn.ReLU(), projection)
        self.res_scale = res_scale

    def forward(self, x):
        branch = self.model(x)
        return x + self.res_scale * branch
class Resnet1D(nn.Module):
    """Stack of ``n_depth`` ResConv1DBlock layers with a growing dilation schedule.

    Block ``i`` is built with dilation ``dilation_growth_rate ** i``. When
    ``dilation_cycle`` is given, it uses ``dilation_growth_rate ** (i % dilation_cycle)``
    instead, so the schedule restarts every ``dilation_cycle`` blocks.

    Args:
        n_in: channel count of input and output.
        n_depth: number of residual blocks.
        m_conv: hidden-width multiplier; each block's hidden channel count is
            ``int(m_conv * n_in)``.
        dilation_growth_rate: base of the per-block dilation exponent.
        dilation_cycle: optional period after which the exponent wraps to 0.
        zero_out: forwarded unchanged to every block.
        res_scale: if truthy, every block's residual branch is scaled by
            ``1 / sqrt(n_depth)``; otherwise by 1.0.
        reverse_dilation: if truthy, the built blocks are applied in reverse order.
        checkpoint_res: when equal to 1, the blocks are held in ``self.blocks``
            (a ModuleList) and each is run through ``checkpoint`` in ``forward``.
            Otherwise they are held in ``self.model`` (a Sequential). Presumably
            this trades compute for memory via activation recomputation; the
            ``checkpoint`` helper lives in jukebox.utils.checkpoint, so confirm there.
    """

    def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1,
                 dilation_cycle=None, zero_out=False, res_scale=False,
                 reverse_dilation=False, checkpoint_res=False):
        super().__init__()

        def exponent_for(index):
            # Wrap the exponent only when a cycle length has been configured.
            return index if dilation_cycle is None else index % dilation_cycle

        blocks = []
        for index in range(n_depth):
            blocks.append(
                ResConv1DBlock(
                    n_in,
                    int(m_conv * n_in),
                    dilation=dilation_growth_rate ** exponent_for(index),
                    zero_out=zero_out,
                    # Evaluated per block (not hoisted above the loop), so
                    # n_depth == 0 never evaluates 1 / sqrt(0).
                    res_scale=(1.0 / math.sqrt(n_depth)) if res_scale else 1.0,
                )
            )
        if reverse_dilation:
            blocks.reverse()

        self.checkpoint_res = checkpoint_res
        # NOTE: the two modes register the blocks under different attribute
        # names ("model" vs "blocks"), so their state_dict keys differ.
        if self.checkpoint_res != 1:
            self.model = nn.Sequential(*blocks)
        else:
            if dist.get_rank() == 0:
                print("Checkpointing convs")
            self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        if self.checkpoint_res != 1:
            return self.model(x)
        out = x
        for block in self.blocks:
            out = checkpoint(block, (out, ), block.parameters(), True)
        return out