# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file defines the Layout Transformer for predicting the layout of keywords.
# ------------------------------------------
import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel
class TextConditioner(nn.Module):

    def __init__(self):
        super(TextConditioner, self).__init__()
        self.transformer = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

        # freeze the CLIP text encoder
        self.transformer.eval()
        for param in self.transformer.parameters():
            param.requires_grad = False

    def forward(self, prompt_list):
        batch_encoding = self.tokenizer(prompt_list, truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        text_embedding = self.transformer(batch_encoding["input_ids"].cuda())
        return text_embedding.last_hidden_state.cuda(), batch_encoding["attention_mask"].cuda()  # (B, 77, 768) / (B, 77)
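
# A minimal usage sketch (illustrative, not from the original repo); it assumes
# a CUDA device is available, since the module above hardcodes .cuda() calls:
#
#   conditioner = TextConditioner().cuda()
#   embedding, attn_mask = conditioner(['a sign that reads "OPEN"'])
#   # embedding: (1, 77, 768), attn_mask: (1, 77)
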
class LayoutTransformer(nn.Module):

    def __init__(self, layer_number=2):
        super(LayoutTransformer, self).__init__()

        # transformer encoder over the 77 CLIP token features
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.transformer = torch.nn.TransformerEncoder(
            self.encoder_layer, num_layers=layer_number
        )

        # autoregressive transformer decoder over the box sequence
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        self.decoder_transformer = torch.nn.TransformerDecoder(
            self.decoder_layer, num_layers=layer_number
        )

        self.mask_embedding = nn.Embedding(2, 512)
        self.length_embedding = nn.Embedding(256, 512)
        self.width_embedding = nn.Embedding(256, 512)
        self.position_embedding = nn.Embedding(256, 512)
        self.state_embedding = nn.Embedding(256, 512)
        self.match_embedding = nn.Embedding(256, 512)

        # coordinate embeddings: box coordinates are quantized into 512 bins
        self.x_embedding = nn.Embedding(512, 512)
        self.y_embedding = nn.Embedding(512, 512)
        self.w_embedding = nn.Embedding(512, 512)
        self.h_embedding = nn.Embedding(512, 512)

        self.encoder_target_embedding = nn.Embedding(256, 512)

        # project the 768-d CLIP features down to the 512-d model width
        self.input_layer = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
        )

        # regress the four box coordinates (x, y, w, h) from each decoder state
        self.output_layer = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
        )
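
    # Coordinate quantization used by forward(), for reference: a normalized
    # coordinate v in [0, 1) maps to bin floor(v * 512), clamped to [0, 511].
    # For example, (x, y, w, h) = (0.25, 0.5, 0.1, 0.2) becomes bins
    # (128, 256, 51, 102), each indexing one of the coordinate embeddings above.
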
    def forward(self, x, length, width, mask, state, match, target, right_shifted_boxes, train=False, encoder_embedding=None):
        # reuse the encoder output when it is cached (e.g. during autoregressive decoding)
        if encoder_embedding is None:
            # augmentation: jitter the quantized widths by a random offset in [-3, 2]
            if train:
                width = width + torch.randint(-3, 3, (width.shape[0], width.shape[1])).cuda()

            x = self.input_layer(x)  # (1, 77, 512)
            width_embedding = self.width_embedding(torch.clamp(width, 0, 255).long())  # (1, 77, 512)
            encoder_target_embedding = self.encoder_target_embedding(target[:, :, 0].long())  # (1, 77, 512)
            pe_embedding = self.position_embedding(torch.arange(77).cuda()).unsqueeze(0)  # (1, 77, 512)
            total_embedding = x + width_embedding + pe_embedding + encoder_target_embedding  # combine all the embeddings (1, 77, 512)
            total_embedding = total_embedding.permute(1, 0, 2)  # (77, 1, 512)
            encoder_embedding = self.transformer(total_embedding)  # (77, 1, 512)

        # quantize the right-shifted boxes into 512 bins
        right_shifted_boxes_resize = (right_shifted_boxes * 512).long()  # (1, 8, 4)
        right_shifted_boxes_resize = torch.clamp(right_shifted_boxes_resize, 0, 511)  # (1, 8, 4)

        # decoder positional embedding
        pe_decoder = torch.arange(8).cuda()  # (8,)
        pe_embedding_decoder = self.position_embedding(pe_decoder).unsqueeze(0)  # (1, 8, 512)
        decoder_input = pe_embedding_decoder + self.x_embedding(right_shifted_boxes_resize[:, :, 0]) + self.y_embedding(right_shifted_boxes_resize[:, :, 1]) + self.w_embedding(right_shifted_boxes_resize[:, :, 2]) + self.h_embedding(right_shifted_boxes_resize[:, :, 3])  # (1, 8, 512)
        decoder_input = decoder_input.permute(1, 0, 2)  # (8, 1, 512)

        # generate the triangular (causal) mask; note this overwrites the unused `mask` argument
        mask = nn.Transformer.generate_square_subsequent_mask(8)  # (8, 8)
        mask = mask.cuda()  # (8, 8)

        decoder_result = self.decoder_transformer(decoder_input, encoder_embedding, tgt_mask=mask)  # (8, 1, 512)
        decoder_result = decoder_result.permute(1, 0, 2)  # (1, 8, 512)
        box_prediction = self.output_layer(decoder_result)  # (1, 8, 4)
        return box_prediction, encoder_embedding
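
if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # TextDiffuser training or inference scripts). It runs one forward pass on
    # random inputs and requires a CUDA device, since the modules above
    # hardcode .cuda() calls. The input formats below are assumptions inferred
    # from the shapes used in forward().
    conditioner = TextConditioner().cuda()
    layout_model = LayoutTransformer().cuda()

    # (1, 77, 768) CLIP features and (1, 77) attention mask for one prompt
    text_embedding, attention_mask = conditioner(['a poster that says "TextDiffuser"'])

    width = torch.randint(0, 256, (1, 77)).cuda()     # quantized per-token rendered widths (assumed format)
    target = torch.zeros(1, 77, 1).cuda()             # per-token keyword indicator (assumed format)
    right_shifted_boxes = torch.rand(1, 8, 4).cuda()  # normalized (x, y, w, h) boxes, shifted right by one step

    # length / mask / state / match are not used inside forward(), so pass None
    box_prediction, encoder_embedding = layout_model(
        text_embedding, None, width, None, None, None,
        target, right_shifted_boxes, train=False,
    )
    print(box_prediction.shape)  # torch.Size([1, 8, 4])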