AndranikSargsyan
add support for diffusers checkpoint loading
f1cc496
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from ... import share
from ..attentionpatch import painta
use_grad = True
def forward(self, x, context=None):
# Todo: add batch inference support
if use_grad:
y, self_v, self_sim = self.attn1(self.norm1(x), None) # Self Attn.
x_uncond, x_cond = x.chunk(2)
context_uncond, context_cond = context.chunk(2)
y_uncond, y_cond = y.chunk(2)
self_sim_uncond, self_sim_cond = self_sim.chunk(2)
self_v_uncond, self_v_cond = self_v.chunk(2)
# Calculate CA similarities with conditional context
cross_h = self.attn2.heads
cross_q = self.attn2.to_q(self.norm2(x_cond+y_cond))
cross_k = self.attn2.to_k(context_cond)
cross_v = self.attn2.to_v(context_cond)
cross_q, cross_k, cross_v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=cross_h), (cross_q, cross_k, cross_v))
with torch.autocast(enabled=False, device_type = 'cuda'):
cross_q, cross_k = cross_q.float(), cross_k.float()
cross_sim = einsum('b i d, b j d -> b i j', cross_q, cross_k) * self.attn2.scale
del cross_q, cross_k
cross_sim = cross_sim.softmax(dim=-1) # Up to this point cross_sim is regular cross_sim in CA layer
cross_sim = cross_sim.mean(dim=0) # Calculate mean across heads
# PAIntA rescale
y_cond = painta.painta_rescale(
y_cond, self_v_cond, self_sim_cond, cross_sim, self.attn1.heads, self.attn1.to_out) # Rescale cond
y_uncond = painta.painta_rescale(
y_uncond, self_v_uncond, self_sim_uncond, cross_sim, self.attn1.heads, self.attn1.to_out) # Rescale uncond
y = torch.cat([y_uncond, y_cond], dim=0)
x = x + y
x = x + self.attn2(self.norm2(x), context=context) # Cross Attn.
x = x + self.ff(self.norm3(x))
return x