"""Replacement ``forward`` for a transformer block that applies PAIntA rescaling.

The function takes ``self`` and reads ``attn1``, ``attn2``, ``ff`` and
``norm1``..``norm3`` from it, so it is meant to be bound to an object that
exposes those sub-modules.
"""

import torch
from torch import nn, einsum
from einops import rearrange, repeat

from ... import share
from ..attentionpatch import painta

# Module-level switch gating the self-attention + PAIntA branch of ``forward``.
use_grad = True


def _split_heads(t, heads):
    """Fold the head dimension of ``t`` into the batch axis."""
    return rearrange(t, "b n (h d) -> (b h) n d", h=heads)


def forward(self, x, context=None):
    """Run self-attention with PAIntA rescaling, then cross-attention and FF.

    The batch axis of ``x``, ``context`` and of every tensor returned by
    ``self.attn1`` is split in two with ``chunk(2)``; the first half is
    treated as the unconditional branch and the second half as the
    conditional branch.

    Args:
        x: Input hidden states; dim 0 holds ``[uncond, cond]`` halves.
        context: Cross-attention context with the same two-half layout.
            Despite the ``None`` default, it must be a tensor whenever
            ``use_grad`` is true, because ``context.chunk(2)`` is called.

    Returns:
        ``x`` after the residual self-attention (when ``use_grad`` is true),
        residual cross-attention and residual feed-forward updates.
    """
    # Todo: add batch inference support
    # NOTE(review): the original indentation of this file was lost; the
    # ``use_grad`` branch is assumed to end after ``x = x + y``. Behaviour is
    # the same for any reading while ``use_grad`` is True -- confirm against
    # the upstream file before setting it to False.
    if use_grad:
        # Self Attn. ``attn1`` is expected to return the attention output
        # together with its values and similarity tensor.
        y, self_v, self_sim = self.attn1(self.norm1(x), None)

        _, x_cond = x.chunk(2)
        _, context_cond = context.chunk(2)

        y_uncond, y_cond = y.chunk(2)
        self_sim_uncond, self_sim_cond = self_sim.chunk(2)
        self_v_uncond, self_v_cond = self_v.chunk(2)

        # Calculate CA similarities with conditional context. Only queries
        # and keys are needed for the similarity map, so no value projection
        # is computed here.
        cross_h = self.attn2.heads
        cross_q = _split_heads(
            self.attn2.to_q(self.norm2(x_cond + y_cond)), cross_h)
        cross_k = _split_heads(self.attn2.to_k(context_cond), cross_h)

        # Compute the logits in float32 with autocast disabled.
        with torch.autocast(enabled=False, device_type='cuda'):
            cross_q, cross_k = cross_q.float(), cross_k.float()
            cross_sim = einsum(
                'b i d, b j d -> b i j', cross_q, cross_k) * self.attn2.scale

        del cross_q, cross_k
        # Up to this point cross_sim is the regular cross_sim of the CA layer.
        cross_sim = cross_sim.softmax(dim=-1)
        # Average over the merged (batch, head) axis produced by _split_heads.
        cross_sim = cross_sim.mean(dim=0)

        # PAIntA rescale: both halves use the similarities computed from the
        # conditional context.
        y_cond = painta.painta_rescale(
            y_cond, self_v_cond, self_sim_cond, cross_sim,
            self.attn1.heads, self.attn1.to_out)  # Rescale cond
        y_uncond = painta.painta_rescale(
            y_uncond, self_v_uncond, self_sim_uncond, cross_sim,
            self.attn1.heads, self.attn1.to_out)  # Rescale uncond

        # Restore the original [uncond, cond] batch order.
        y = torch.cat([y_uncond, y_cond], dim=0)

        x = x + y

    x = x + self.attn2(self.norm2(x), context=context)  # Cross Attn.
    x = x + self.ff(self.norm3(x))

    return x