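# Drop-in attention forward passes for Stable Diffusion's CrossAttention
# modules: a plain einsum implementation (forward_sd2), an xformers
# memory-efficient variant (forward_xformers), and a variant that also
# records cross-attention similarity maps into `share` (forward_and_save).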
from ... import share  # relative import of the project's shared-state module
import xformers
import xformers.ops
import torch
from torch import nn, einsum
import torchvision.transforms.functional as TF
from einops import rearrange, repeat
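
# Set to "fp32" to compute the attention logits in fp32 under autocast
# (avoids fp16 overflow in the q.k products); None keeps the ambient dtype.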
_ATTN_PRECISION = None

def forward_sd2(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = x if context is None else context  # self-attention when no context is given
    k = self.to_k(context)
    v = self.to_v(context)

    # Split heads: (b, n, h*d) -> (b*h, n, d)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    if _ATTN_PRECISION == "fp32":  # force cast to fp32 to avoid overflowing
        with torch.autocast(enabled=False, device_type='cuda'):
            q, k = q.float(), k.float()
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    del q, k

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', sim, v)
    # Merge heads back: (b*h, n, d) -> (b, n, h*d)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)
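
# xformers variant: identical projections, but the attention kernel is
# delegated to memory_efficient_attention; the reshapes below perform the
# same (b, n, h*d) -> (b*h, n, d) head split by hand.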
def forward_xformers(self, x, context=None, mask=None):
    q = self.to_q(x)
    context = x if context is None else context
    k = self.to_k(context)
    v = self.to_v(context)

    if mask is not None:
        raise NotImplementedError  # masking is not supported on this path

    b, _, _ = q.shape
    # Split heads: (b, n, h*d) -> (b*h, n, d), contiguous for the xformers kernel
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, t.shape[1], self.heads, self.dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * self.heads, t.shape[1], self.dim_head)
        .contiguous(),
        (q, k, v),
    )

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

    # Merge heads back: (b*h, n, d) -> (b, n, h*d)
    out = (
        out.unsqueeze(0)
        .reshape(b, self.heads, out.shape[1], self.dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, out.shape[1], self.heads * self.dim_head)
    )
    return self.to_out(out)
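
# Module-level default: bind the xformers implementation as `forward`.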
forward = forward_xformers

def forward_and_save(self, x, context=None, mask=None):
    # `mask` is accepted for interface compatibility but ignored on this path.
    att_type = "self" if context is None else "cross"
    h = self.heads

    q = self.to_q(x)
    context = x if context is None else context
    k = self.to_k(context)
    v = self.to_v(context)

    # Split heads: (b, n, h*d) -> (b*h, n, d)
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

    sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

    # Record the cross-attention similarity map at whichever UNet resolution
    # (8/16/32/64) this call corresponds to, if `share` is collecting it.
    for res in ("res8", "res16", "res32", "res64"):
        attr = f"_crossattn_similarity_{res}"
        if hasattr(share, attr) and x.shape[1] == getattr(share.input_shape, res) and att_type == "cross":
            # Chunk into 2 to separate the unconditional and conditional halves of the batch.
            getattr(share, attr).append(torch.stack(share.reshape(sim).chunk(2)))

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum("b i j, b j d -> b i d", sim, v)
    out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
    return self.to_out(out)
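
# --- Usage sketch (an assumption, not part of the original file) ---
# The functions above are written as unbound methods: each takes `self` and
# expects to_q/to_k/to_v/to_out, heads, dim_head, scale, and attention_op on
# the module it is bound to. A minimal monkey-patching sketch, assuming a
# diffusion `model` whose attention blocks are named `CrossAttention`:
#
#     import types
#     for m in model.modules():
#         if type(m).__name__ == "CrossAttention":
#             m.forward = types.MethodType(forward_and_save, m)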