# NOTE: hosting-page banner captured together with this file
# ("Spaces: Runtime error"); it is not part of the module.
# Copyright (c) OpenMMLab. All rights reserved.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils.checkpoint as cp | |
from mmcv.cnn import build_conv_layer, build_norm_layer | |
from mmcv.runner import BaseModule | |
from torch.nn.modules.utils import _pair | |
from mmdet.models.backbones.resnet import Bottleneck, ResNet | |
from mmdet.models.builder import BACKBONES | |
class TridentConv(BaseModule):
    """Trident Convolution Module.

    A single convolution weight (and optional bias) is shared by all
    branches; each branch applies it with its own dilation and a padding
    equal to that dilation.

    Args:
        in_channels (int): Number of channels in input.
        out_channels (int): Number of channels in output.
        kernel_size (int): Size of convolution kernel.
        stride (int, optional): Convolution stride. Default: 1.
        trident_dilations (tuple[int, int, int], optional): Dilations of
            different trident branch. Default: (1, 2, 3).
        test_branch_idx (int, optional): In inference, all 3 branches will
            be used if `test_branch_idx==-1`, otherwise only branch with
            index `test_branch_idx` will be used. Default: 1.
        bias (bool, optional): Whether to use bias in convolution or not.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 trident_dilations=(1, 2, 3),
                 test_branch_idx=1,
                 bias=False,
                 init_cfg=None):
        super(TridentConv, self).__init__(init_cfg)
        self.num_branch = len(trident_dilations)
        self.with_bias = bias
        self.test_branch_idx = test_branch_idx
        self.stride = _pair(stride)
        self.kernel_size = _pair(kernel_size)
        # `trident_dilations` is already a sequence, so this keeps one
        # padding value per branch (equal to that branch's dilation).
        self.paddings = _pair(trident_dilations)
        self.dilations = trident_dilations
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Storage is only allocated here; the values are expected to be set
        # later (weight initialisation or a loaded checkpoint).
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        fields = (
            f'in_channels={self.in_channels}',
            f'out_channels={self.out_channels}',
            f'kernel_size={self.kernel_size}',
            f'num_branch={self.num_branch}',
            f'test_branch_idx={self.test_branch_idx}',
            f'stride={self.stride}',
            f'paddings={self.paddings}',
            f'dilations={self.dilations}',
            # Report the boolean flag, not the bias Parameter itself, so the
            # repr never contains raw tensor values.
            f'bias={self.with_bias}',
        )
        return ', '.join(fields)

    def forward(self, inputs):
        """Apply the shared convolution to each branch input.

        Args:
            inputs (list[Tensor]): One tensor per branch when training or
                when `test_branch_idx == -1`; otherwise exactly one tensor.

        Returns:
            list[Tensor]: One output per processed branch.
        """
        if self.training or self.test_branch_idx == -1:
            # Branch i uses dilations[i] / paddings[i]; zip stops at the
            # shortest of the three sequences.
            outputs = [
                F.conv2d(branch_input, self.weight, self.bias, self.stride,
                         padding, dilation)
                for branch_input, dilation, padding in zip(
                    inputs, self.dilations, self.paddings)
            ]
        else:
            # Single-branch inference: only the selected branch is run.
            assert len(inputs) == 1
            outputs = [
                F.conv2d(inputs[0], self.weight, self.bias, self.stride,
                         self.paddings[self.test_branch_idx],
                         self.dilations[self.test_branch_idx])
            ]
        return outputs
# Since TridentNet is defined over ResNet50 and ResNet101, only the
# bottleneck block (TridentBottleneck) is supported here.
class TridentBottleneck(Bottleneck):
    """BottleBlock for TridentResNet.

    The block replaces the parent's `conv2` with a :class:`TridentConv` and
    runs every other layer once per branch, so its input and output are
    lists of tensors (one per branch) unless `concat_output` is set.

    Args:
        trident_dilations (tuple[int, int, int]): Dilations of different
            trident branch.
        test_branch_idx (int): In inference, all 3 branches will be used
            if `test_branch_idx==-1`, otherwise only branch with index
            `test_branch_idx` will be used.
        concat_output (bool): Whether to concat the output list to a Tensor.
            `True` only in the last Block.
        **kwargs: Forwarded unchanged to the parent `Bottleneck`.
    """

    def __init__(self, trident_dilations, test_branch_idx, concat_output,
                 **kwargs):
        super(TridentBottleneck, self).__init__(**kwargs)
        self.trident_dilations = trident_dilations
        self.num_branch = len(trident_dilations)
        self.concat_output = concat_output
        self.test_branch_idx = test_branch_idx
        # Overwrite the conv2 created by the parent constructor with the
        # weight-shared multi-branch convolution.
        self.conv2 = TridentConv(
            self.planes,
            self.planes,
            kernel_size=3,
            stride=self.conv2_stride,
            bias=False,
            trident_dilations=self.trident_dilations,
            test_branch_idx=test_branch_idx,
            init_cfg=dict(
                type='Kaiming',
                distribution='uniform',
                mode='fan_in',
                override=dict(name='conv2')))

    @staticmethod
    def _requires_grad(x):
        """Tell whether a tensor, or any tensor of a branch list, needs grad."""
        if isinstance(x, (list, tuple)):
            return any(branch.requires_grad for branch in x)
        return x.requires_grad

    def forward(self, x):
        """Run the block on a tensor or on a list of per-branch tensors.

        Args:
            x (Tensor | list[Tensor]): A single tensor (it is then shared by
                every active branch) or the branch list produced by a
                previous block.

        Returns:
            Tensor | list[Tensor]: The per-branch outputs, concatenated along
            dim 0 when `concat_output` is True.
        """

        def _inner_forward(x):
            # All branches are active in training or when
            # test_branch_idx == -1; otherwise a single branch is run.
            num_branch = (
                self.num_branch
                if self.training or self.test_branch_idx == -1 else 1)
            identity = x
            if not isinstance(x, list):
                x = (x, ) * num_branch
                identity = x
                if self.downsample is not None:
                    identity = [self.downsample(b) for b in x]

            out = [self.conv1(b) for b in x]
            out = [self.norm1(b) for b in out]
            out = [self.relu(b) for b in out]

            if self.with_plugins:
                out = [
                    self.forward_plugin(b, self.after_conv1_plugin_names)
                    for b in out
                ]

            out = self.conv2(out)
            out = [self.norm2(b) for b in out]
            out = [self.relu(b) for b in out]
            if self.with_plugins:
                out = [
                    self.forward_plugin(b, self.after_conv2_plugin_names)
                    for b in out
                ]

            out = [self.conv3(b) for b in out]
            out = [self.norm3(b) for b in out]

            if self.with_plugins:
                out = [
                    self.forward_plugin(b, self.after_conv3_plugin_names)
                    for b in out
                ]

            out = [
                out_b + identity_b for out_b, identity_b in zip(out, identity)
            ]
            return out

        # `x` is a list for every block after the first one of the stage, so
        # the grad check must not assume a Tensor.
        if self.with_cp and self._requires_grad(x):
            is_branch_list = isinstance(x, list)
            branches = tuple(x) if is_branch_list else (x, )

            def _checkpointed(*tensors):
                # Checkpointing tracks only tensor arguments and tensor /
                # tuple-of-tensor results, hence the unpack / repack here.
                restored = list(tensors) if is_branch_list else tensors[0]
                return tuple(_inner_forward(restored))

            out = list(cp.checkpoint(_checkpointed, *branches))
        else:
            out = _inner_forward(x)

        out = [self.relu(b) for b in out]
        if self.concat_output:
            out = torch.cat(out, dim=0)
        return out
def make_trident_res_layer(block,
                           inplanes,
                           planes,
                           num_blocks,
                           stride=1,
                           trident_dilations=(1, 2, 3),
                           style='pytorch',
                           with_cp=False,
                           conv_cfg=None,
                           norm_cfg=dict(type='BN'),
                           dcn=None,
                           plugins=None,
                           test_branch_idx=-1):
    """Build Trident Res Layers.

    Stacks `num_blocks` instances of `block` into an `nn.Sequential`. Only
    the first block receives `stride` and the shortcut projection; only the
    last block is created with `concat_output=True`.

    Args:
        block (type): Block class; must expose an `expansion` attribute and
            accept the keyword arguments passed below.
        inplanes (int): Input channels of the first block.
        planes (int): Base channels; each block outputs
            `planes * block.expansion` channels.
        num_blocks (int): Number of blocks to stack.
        stride (int): Stride of the first block. Default: 1.
        trident_dilations (tuple[int]): Per-branch dilations handed to every
            block. Default: (1, 2, 3).
        style, with_cp, conv_cfg, norm_cfg, dcn, plugins: Forwarded to every
            block as given.
        test_branch_idx (int): Forwarded to every block. Default: -1.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    out_channels = planes * block.expansion

    # A 1x1 conv + norm projection is needed on the shortcut whenever the
    # first block changes resolution or channel count.
    shortcut = None
    if not (stride == 1 and inplanes == out_channels):
        shortcut = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, out_channels)[1])

    # Options that are identical for every block of the stage.
    shared = dict(
        planes=planes,
        trident_dilations=trident_dilations,
        style=style,
        with_cp=with_cp,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        dcn=dcn,
        plugins=plugins,
        test_branch_idx=test_branch_idx)

    last_idx = num_blocks - 1
    stage = [
        block(
            # Blocks after the first consume the previous block's output.
            inplanes=inplanes if idx == 0 else out_channels,
            stride=stride if idx == 0 else 1,
            downsample=shortcut if idx == 0 else None,
            concat_output=idx == last_idx,
            **shared) for idx in range(num_blocks)
    ]
    return nn.Sequential(*stage)
class TridentResNet(ResNet):
    r"""The stem layer, stage 1 and stage 2 in Trident ResNet are identical to
    ResNet, while in stage 3, Trident BottleBlock is utilized to replace the
    normal BottleBlock to yield trident output. Different branch shares the
    convolution weight but uses different dilations to achieve multi-scale
    output.

                               / stage3(b0) \
    x - stem - stage1 - stage2 - stage3(b1) - output
                               \ stage3(b2) /

    Args:
        depth (int): Depth of resnet, from {50, 101, 152}.
        num_branch (int): Number of branches in TridentNet.
        test_branch_idx (int): In inference, all 3 branches will be used
            if `test_branch_idx==-1`, otherwise only branch with index
            `test_branch_idx` will be used.
        trident_dilations (tuple[int]): Dilations of different trident branch.
            len(trident_dilations) should be equal to num_branch.
        **kwargs: Forwarded to `ResNet`; the resulting network must have
            exactly 3 stages.
    """

    def __init__(self, depth, num_branch, test_branch_idx, trident_dilations,
                 **kwargs):

        assert num_branch == len(trident_dilations)
        assert depth in (50, 101, 152)
        super(TridentResNet, self).__init__(depth, **kwargs)
        assert self.num_stages == 3
        self.test_branch_idx = test_branch_idx
        self.num_branch = num_branch

        # Only the last stage is rebuilt with trident blocks.
        last_stage_idx = self.num_stages - 1
        stride = self.strides[last_stage_idx]
        dilation = trident_dilations
        dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None
        if self.plugins is not None:
            stage_plugins = self.make_stage_plugins(self.plugins,
                                                    last_stage_idx)
        else:
            stage_plugins = None
        planes = self.base_channels * 2**last_stage_idx
        res_layer = make_trident_res_layer(
            TridentBottleneck,
            # Channel count produced by the stage just before the last one.
            inplanes=(self.block.expansion * self.base_channels *
                      2**(last_stage_idx - 1)),
            planes=planes,
            num_blocks=self.stage_blocks[last_stage_idx],
            stride=stride,
            trident_dilations=dilation,
            style=self.style,
            with_cp=self.with_cp,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=dcn,
            plugins=stage_plugins,
            test_branch_idx=self.test_branch_idx)

        # Swap the stage built by the parent constructor for the trident one
        # and keep the stage-name list in step with it.
        layer_name = f'layer{last_stage_idx + 1}'
        setattr(self, layer_name, res_layer)
        self.res_layers[last_stage_idx] = layer_name

        # Called again after the stage replacement (presumably so frozen
        # stages also cover the new layer — confirm against ResNet).
        self._freeze_stages()