|
import math |
|
import torch.nn as nn |
|
import jukebox.utils.dist_adapter as dist |
|
from jukebox.utils.checkpoint import checkpoint |
|
|
|
class ResConvBlock(nn.Module):
    """2-D residual block: ``x + conv1x1(relu(conv3x3(relu(x))))``.

    Args:
        n_in: Channel count of the input; the second convolution maps
            back to this count so the residual sum is well-formed.
        n_state: Channel count between the two convolutions.
    """

    def __init__(self, n_in, n_state):
        super().__init__()
        # Layers are listed in application order; their positions inside
        # the Sequential determine the parameter names (model.1, model.3).
        layers = [
            nn.ReLU(),
            nn.Conv2d(n_in, n_state, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_state, n_in, kernel_size=1, stride=1, padding=0),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the conv stack applied to ``x``."""
        branch = self.model(x)
        return x + branch
|
|
|
class Resnet(nn.Module):
    """Sequential stack of ``n_depth`` ``ResConvBlock`` modules.

    Args:
        n_in: Channel count handed to every block.
        n_depth: Number of blocks in the stack.
        m_conv: Multiplier for the hidden width; each block is built with
            ``n_state = int(m_conv * n_in)``.
    """

    def __init__(self, n_in, n_depth, m_conv=1.0):
        super().__init__()
        stages = []
        for _ in range(n_depth):
            stages.append(ResConvBlock(n_in, int(m_conv * n_in)))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every block in order and return the result."""
        out = self.model(x)
        return out
|
|
|
class ResConv1DBlock(nn.Module):
    """1-D residual block with an optionally dilated 3-tap convolution.

    The forward pass computes ``x + res_scale * model(x)``, where ``model``
    is ReLU -> Conv1d(kernel 3, dilated) -> ReLU -> Conv1d(kernel 1).

    Args:
        n_in: Channel count of the input and of the block's output.
        n_state: Channel count between the two convolutions.
        dilation: Dilation of the 3-tap convolution; the same value is
            used as its padding.
        zero_out: When truthy, the weight and bias of the final
            convolution are initialised to zero.
        res_scale: Factor applied to the conv branch before it is added
            to the input.
    """

    def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0):
        super().__init__()
        # Convolutions are created in application order (dilated first).
        dilated_conv = nn.Conv1d(
            n_in, n_state, kernel_size=3, stride=1,
            padding=dilation, dilation=dilation,
        )
        pointwise_conv = nn.Conv1d(n_state, n_in, kernel_size=1, stride=1, padding=0)
        self.model = nn.Sequential(nn.ReLU(), dilated_conv, nn.ReLU(), pointwise_conv)
        if zero_out:
            # Zero the last layer's weight first, then its bias.
            for tensor in (pointwise_conv.weight, pointwise_conv.bias):
                nn.init.zeros_(tensor)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``x`` plus the scaled output of the conv branch."""
        branch = self.model(x)
        return x + self.res_scale * branch
|
|
|
class Resnet1D(nn.Module):
    """Stack of ``ResConv1DBlock`` modules with a per-block dilation schedule.

    Block ``d`` (0-based) is built with dilation
    ``dilation_growth_rate ** d``; when ``dilation_cycle`` is given the
    exponent is ``d % dilation_cycle`` instead.

    Args:
        n_in: Channel count handed to every block.
        n_depth: Number of blocks.
        m_conv: Hidden-width multiplier; each block gets
            ``n_state = int(m_conv * n_in)``.
        dilation_growth_rate: Base of the dilation schedule.
        dilation_cycle: Optional period after which the exponent wraps.
        zero_out: Forwarded unchanged to every block.
        res_scale: When truthy each block uses ``1 / sqrt(n_depth)`` as its
            residual scale, otherwise ``1.0``.
        reverse_dilation: When truthy the built blocks are applied in
            reverse order.
        checkpoint_res: When equal to ``1`` the blocks are kept in
            ``self.blocks`` and each one is invoked through the imported
            ``checkpoint`` helper; otherwise they are wrapped in
            ``self.model`` as a single ``nn.Sequential``.
    """

    def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1,
                 dilation_cycle=None, zero_out=False, res_scale=False,
                 reverse_dilation=False, checkpoint_res=False):
        super().__init__()

        def _build_block(depth):
            # Everything is evaluated per block (not hoisted) so that an
            # empty stack never touches these expressions.
            n_state = int(m_conv * n_in)
            exponent = depth if dilation_cycle is None else depth % dilation_cycle
            scale = 1.0 / math.sqrt(n_depth) if res_scale else 1.0
            return ResConv1DBlock(
                n_in,
                n_state,
                dilation=dilation_growth_rate ** exponent,
                zero_out=zero_out,
                res_scale=scale,
            )

        # Blocks are constructed in ascending depth order and only
        # afterwards (optionally) flipped.
        blocks = [_build_block(depth) for depth in range(n_depth)]
        if reverse_dilation:
            blocks.reverse()

        self.checkpoint_res = checkpoint_res
        if self.checkpoint_res == 1:
            # Only rank 0 reports, as given by dist.get_rank().
            if dist.get_rank() == 0:
                print("Checkpointing convs")
            self.blocks = nn.ModuleList(blocks)
        else:
            self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply every block to ``x`` in order and return the result."""
        if self.checkpoint_res == 1:
            for block in self.blocks:
                # NOTE(review): the helper's argument meaning (inputs,
                # params, flag) is defined in jukebox.utils.checkpoint —
                # confirm there before changing this call.
                x = checkpoint(block, (x,), block.parameters(), True)
            return x
        return self.model(x)
|
|