Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
import math | |
from torch import einsum | |
from einops import rearrange, repeat | |
from .basic_transformer_block import PatchedBasicTransformerBlock as BasicTransformerBlock | |
def Normalize(in_channels):
    """Build the group-normalisation layer used in front of the transformer.

    Args:
        in_channels: Number of channels of the tensor to be normalised.

    Returns:
        A ``torch.nn.GroupNorm`` with 32 groups, ``eps=1e-6`` and learnable
        per-channel affine parameters.
    """
    group_norm_kwargs = dict(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
    return torch.nn.GroupNorm(**group_norm_kwargs)
def init_(tensor):
    """Fill ``tensor`` in place with uniform noise scaled by its last dimension.

    Values are drawn from ``[-1/sqrt(d), 1/sqrt(d))`` where ``d`` is the size
    of the last dimension of ``tensor``.

    Args:
        tensor: Tensor to overwrite; it is modified in place.

    Returns:
        The same tensor object that was passed in.
    """
    last_dim = tensor.shape[-1]
    bound = 1 / math.sqrt(last_dim)
    # ``uniform_`` is an in-place op and hands back the tensor it was called on.
    return tensor.uniform_(-bound, bound)
def zero_module(module):
    """Set every parameter of ``module`` to zero, in place.

    Args:
        module: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The same ``module`` object, so the call can wrap a constructor.
    """
    # Autograd is disabled so the in-place write on leaf parameters is allowed
    # and is not recorded in any graph.
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.

    First, normalise and project the input (aka embedding) and reshape it
    from ``b, c, h, w`` to ``b, (h w), c``.
    Then apply ``depth`` transformer blocks.
    Finally, project back, reshape to an image and add the input as a residual.

    NEW: use_linear for more efficiency instead of the 1x1 convs

    The output projection is zero-initialised, so a freshly constructed module
    returns its input unchanged.

    Args:
        in_channels: Channel count of the input (and output) feature map.
        n_heads: Number of attention heads; forwarded to each block.
        d_head: Per-head width; ``n_heads * d_head`` is the inner width.
        depth: Number of transformer blocks.
        dropout: Dropout value forwarded to each block.
        context_dim: ``None``, a single value, or a list of values forwarded
            to the blocks as their ``context_dim``. ``None``, a single value
            or a one-element list is reused for every block; a longer list
            supplies one entry per block (entries beyond ``depth`` are
            ignored).
        disable_self_attn: Forwarded to each block.
        use_linear: Use ``nn.Linear`` projections (applied on the flattened
            sequence) instead of 1x1 ``nn.Conv2d`` projections (applied on
            the image layout).
        use_checkpoint: Forwarded to each block as ``checkpoint``.

    Raises:
        ValueError: If ``context_dim`` is a list with more than one but fewer
            than ``depth`` entries.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()

        # Normalise ``context_dim`` to one entry per block. Previously a
        # ``None`` default was subscripted directly and a scalar was wrapped
        # in a one-element list, so both failed for the common cases of "no
        # context" and "depth > 1 with a shared context width".
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        elif len(context_dim) == 1:
            context_dim = context_dim * depth
        elif len(context_dim) < depth:
            raise ValueError(
                f"context_dim has {len(context_dim)} entries but depth is "
                f"{depth}; pass one entry, or one entry per block"
            )

        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        self.use_linear = use_linear

        if use_linear:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout,
                                   context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn,
                                   checkpoint=use_checkpoint)
             for d in range(depth)]
        )

        if use_linear:
            # Maps the inner width back to the input width. The arguments were
            # previously reversed, which only worked when both widths happened
            # to be equal.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        else:
            self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                  in_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0))

    def forward(self, x, context=None):
        """Apply the transformer to a 4-D feature map and add the residual.

        Args:
            x: Tensor whose shape unpacks as ``(b, c, h, w)``.
            context: ``None``, a single context object, or a list of them.
                A single context (or one-element list) is passed to every
                block; a longer list is indexed per block.

        Returns:
            ``x`` plus the transformed features, in ``b, c, h, w`` layout.
        """
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        _, _, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # Conv projections operate on the image layout, linear projections on
        # the flattened sequence, hence the split around each rearrange.
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            # A lone context is shared by all blocks instead of raising
            # IndexError on the second block.
            block_context = context[0] if len(context) == 1 else context[i]
            x = block(x, context=block_context)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in