# CharacterFactory / models/celeb_embeddings.py
# Uploaded by wangqinghehe -- commit "0515_first_upload" (3ab16a9), 2.4 kB.
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
import numpy as np
import os
def embedding_forward(
    self,
    input_ids=None,
    position_ids=None,
    name_batch=None,
    inputs_embeds=None,
    embedding_manager=None,
    only_embedding=True,
    random_embeddings=None,
    timesteps=None,
):
    """Compute token embeddings, optionally edited and summed with position embeddings.

    Written to be bound as the ``forward`` of a text encoder's embeddings
    module: it reads ``self.token_embedding``, ``self.position_embedding`` and
    ``self.position_ids`` (the attribute names of Hugging Face's
    ``CLIPTextEmbeddings`` -- presumably the patched class; the binding happens
    elsewhere, confirm there).

    Args:
        input_ids: Token ids whose last dim is the sequence length. Used to
            look up ``inputs_embeds`` when those are not supplied.
        position_ids: Position ids; defaults to the first ``seq_length``
            columns of ``self.position_ids``.
        name_batch: Forwarded unchanged to ``embedding_manager``.
        inputs_embeds: Precomputed token embeddings (sequence length taken
            from dim -2 when ``input_ids`` is ``None``).
        embedding_manager: Optional callable invoked as
            ``embedding_manager(input_ids, inputs_embeds, name_batch,
            random_embeddings, timesteps)`` and expected to return
            ``(new_inputs_embeds, other_return_dict)``.
        only_embedding: If true (the default), return the token embeddings
            immediately -- no manager edits and no position embeddings.
        random_embeddings: Forwarded unchanged to ``embedding_manager``.
        timesteps: Forwarded unchanged to ``embedding_manager``.

    Returns:
        ``inputs_embeds`` when ``only_embedding`` is true; otherwise the tuple
        ``(inputs_embeds + position_embeddings, other_return_dict)``, where
        ``other_return_dict`` is ``None`` when no ``embedding_manager`` is given.
    """
    seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

    if inputs_embeds is None:
        inputs_embeds = self.token_embedding(input_ids)

    if only_embedding:
        return inputs_embeds

    # Bound unconditionally: previously it was only assigned inside the manager
    # branch, so only_embedding=False without a manager hit UnboundLocalError.
    other_return_dict = None
    if embedding_manager is not None:
        inputs_embeds, other_return_dict = embedding_manager(
            input_ids, inputs_embeds, name_batch, random_embeddings, timesteps
        )

    if position_ids is None:
        position_ids = self.position_ids[:, :seq_length]

    position_embeddings = self.position_embedding(position_ids)
    embeddings = inputs_embeds + position_embeddings

    return embeddings, other_return_dict
@torch.no_grad()
def _get_celeb_embeddings_basis(tokenizer, text_encoder, good_names_txt):
device = text_encoder.device
max_length = 77
with open(good_names_txt, "r") as f:
celeb_names = f.read().splitlines()
''' get tokens and embeddings '''
all_embeddings = []
for name in celeb_names:
batch_encoding = tokenizer(name, truncation=True, return_tensors="pt")
tokens = batch_encoding["input_ids"].to(device)[:, 1:3]
embeddings = text_encoder.text_model.embeddings(input_ids=tokens, only_embedding=True)
all_embeddings.append(embeddings)
all_embeddings: torch.Tensor = torch.cat(all_embeddings, dim=0)
print('[all_embeddings loaded] shape =', all_embeddings.shape,
'max:', all_embeddings.max(),
'min={}', all_embeddings.min())
name_emb_mean = all_embeddings.mean(0)
name_emb_std = all_embeddings.std(0)
print('[name_emb_mean loaded] shape =', name_emb_mean.shape,
'max:', name_emb_mean.max(),
'min={}', name_emb_mean.min())
return name_emb_mean, name_emb_std