CharacterFactory / utils.py
wangqinghehe's picture
0515_first_upload
3ab16a9
raw
history blame
8.88 kB
from PIL import Image
import torch
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
from torch import autograd
import accelerate
import torch.nn as nn
from PIL import Image
import numpy as np
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Passing ``requires_grad=False`` (the default) freezes the networks so that
    autograd does not compute gradients for their parameters.

    Parameters:
        nets (nn.Module | list | tuple) -- a single network, or a list/tuple of
            networks; ``None`` entries are skipped
        requires_grad (bool) -- value assigned to each parameter's
            ``requires_grad`` flag
    """
    # Accept tuples as well as lists; anything else is treated as one network.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
def discriminator_r1_loss_accelerator(accelerator, real_pred, real_w):
    """Gradient penalty on the discriminator's real-sample predictions.

    Returns the batch mean of the squared L2 norm of
    ``d(sum(real_pred)) / d(real_w)``, with the first gradient dimension
    treated as the batch dimension.

    Parameters:
        accelerator -- currently unused; kept so existing call sites keep working
        real_pred (Tensor) -- discriminator outputs for the real inputs
        real_w (Tensor) -- the inputs ``real_pred`` was computed from; must
            require grad
    """
    # The ``accelerate`` package has no ``gradient`` function, so compute the
    # input gradient with torch.autograd directly. create_graph=True keeps the
    # penalty differentiable so it can be back-propagated.
    grad_real, = autograd.grad(
        outputs=real_pred.sum(), inputs=real_w, create_graph=True
    )
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
    return grad_penalty
class GANLoss(nn.Module):
    """Adversarial loss that compares predictions to a constant real/fake label.

    NOTE(review): this module defines ``GANLoss`` twice. This first class is
    rebound by the later definition, which uses ``nn.BCELoss`` rather than
    ``nn.BCEWithLogitsLoss`` when ``use_lsgan=False``, so importers only ever
    see the later class. Consider deleting one of the two.

    Args:
        use_lsgan: if True use ``nn.MSELoss`` (least-squares GAN); otherwise use
            ``nn.BCEWithLogitsLoss``, which expects raw (pre-sigmoid) logits.
        target_real_label: label value used when ``target_is_real`` is True.
        target_fake_label: label value used when ``target_is_real`` is False.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        # Registered as buffers so they move with the module and appear in state_dict.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label broadcast to ``input``'s shape, dtype and device."""
        label = self.real_label if target_is_real else self.fake_label
        # Match dtype/device so half-precision inputs, or a module that was never
        # moved to the input's device, still produce a valid target.
        return label.to(input).expand_as(input)

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against the real (True) or fake (False) label."""
        # Implemented as ``forward`` (not ``__call__``) so nn.Module hooks still run.
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
def discriminator_r1_loss(real_pred, real_w):
    """Gradient penalty on the discriminator's real-sample predictions.

    Differentiates ``sum(real_pred)`` with respect to ``real_w``, squares the
    gradient, sums it over every non-leading dimension and averages the
    result over the leading (batch) dimension. ``create_graph=True`` keeps the
    penalty itself differentiable.
    """
    total_score = real_pred.sum()
    (input_grad,) = autograd.grad(outputs=total_score, inputs=real_w, create_graph=True)
    per_sample = input_grad.pow(2).reshape(input_grad.shape[0], -1)
    return per_sample.sum(1).mean()
def add_noise_return_paras(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> "tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]":
    """Noise ``original_samples`` to ``timesteps`` and also return the mixing coefficients.

    Takes the scheduler as ``self`` (presumably patched onto a diffusers-style
    scheduler — confirm at the call site); only ``self.alphas_cumprod`` is read.
    With ``a_t = alphas_cumprod[timesteps]`` it computes
    ``sqrt(a_t) * original_samples + sqrt(1 - a_t) * noise``.

    Returns:
        ``(noisy_samples, sqrt_alpha_prod, sqrt_one_minus_alpha_prod)``. Both
        coefficients are flattened and padded with trailing singleton dims so
        they broadcast against ``original_samples``.
    """
    # Match device/dtype of the samples so the arithmetic below does not mix them.
    alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    alpha_t = alphas_cumprod[timesteps.to(original_samples.device)]
    target_ndim = original_samples.dim()

    def _broadcastable(coeff):
        # Flatten to 1-D, then append singleton dims up to the samples' rank.
        coeff = coeff.flatten()
        while coeff.dim() < target_ndim:
            coeff = coeff.unsqueeze(-1)
        return coeff

    sqrt_alpha_prod = _broadcastable(alpha_t ** 0.5)
    sqrt_one_minus_alpha_prod = _broadcastable((1 - alpha_t) ** 0.5)
    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
    return noisy_samples, sqrt_alpha_prod, sqrt_one_minus_alpha_prod
def text_encoder_forward(
    text_encoder = None,
    input_ids = None,
    name_batch = None,
    attention_mask = None,
    position_ids = None,
    output_attentions = None,
    output_hidden_states = None,
    return_dict = None,
    embedding_manager = None,
    only_embedding=False,
    random_embeddings = None,
    timesteps = None,
):
    """Run a CLIP text encoder whose embedding layer accepts extra project inputs.

    Mirrors the forward pass of the ``transformers`` CLIP text transformer. The
    embedding call (``text_encoder.text_model.embeddings``) additionally
    receives ``name_batch``, ``embedding_manager``, ``only_embedding``,
    ``random_embeddings`` and ``timesteps``. It must return a pair
    ``(hidden_states, other_return_dict)``.

    NOTE(review): the stock ``CLIPTextEmbeddings`` does not accept these
    arguments, so the embedding module is presumably patched or replaced
    elsewhere in the project — confirm before changing this call.

    Args:
        text_encoder: CLIP text model; this function reads ``config``,
            ``text_model.embeddings``, ``text_model.encoder``,
            ``text_model.final_layer_norm`` and ``text_model.eos_token_id``.
        input_ids: token ids (required); viewed as ``(-1, input_ids.size(-1))``.
        attention_mask: optional padding mask, expanded with ``_expand_mask``.
        position_ids: forwarded to the embedding layer.
        output_attentions, output_hidden_states, return_dict: when None, fall
            back to the corresponding ``text_encoder.config`` values.
        only_embedding: if True, return right after the embedding layer.
        name_batch, embedding_manager, random_embeddings, timesteps: forwarded
            unchanged to the embedding layer.

    Returns:
        - ``only_embedding`` truthy: ``hidden_states`` from the embedding layer.
        - ``return_dict`` falsy: the tuple ``(last_hidden_state, pooled_output,
          *encoder_outputs[1:])``. ``other_return_dict`` is NOT returned here.
        - otherwise: the 2-tuple ``(last_hidden_state, other_return_dict)``.
          Indexing the ``BaseModelOutputWithPooling`` with ``[0]`` yields its
          first field, so the pooled output is computed but discarded.

    Raises:
        ValueError: if ``input_ids`` is None.
    """
    # Unset output flags fall back to the encoder's config.
    output_attentions = output_attentions if output_attentions is not None else text_encoder.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else text_encoder.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else text_encoder.config.use_return_dict
    if input_ids is None:
        raise ValueError("You have to specify either input_ids")
    # NOTE(review): input_shape is captured before the view below and is later
    # used for the causal mask; this presumably assumes input_ids is already
    # 2-D (batch, seq) — confirm with callers.
    input_shape = input_ids.size()
    input_ids = input_ids.view(-1, input_shape[-1])
    # Project-specific embedding call: returns the embedded tokens plus an
    # extra dict that is passed back to the caller in the return_dict branch.
    hidden_states, other_return_dict = text_encoder.text_model.embeddings(input_ids=input_ids,
                                                                          position_ids=position_ids,
                                                                          name_batch = name_batch,
                                                                          embedding_manager=embedding_manager,
                                                                          only_embedding=only_embedding,
                                                                          random_embeddings = random_embeddings,
                                                                          timesteps = timesteps,
                                                                          )
    if only_embedding:
        return hidden_states
    # NOTE(review): _make_causal_mask/_expand_mask are private transformers
    # helpers and may not exist in newer releases; pin the transformers version.
    causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
    if attention_mask is not None:
        attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=hidden_states,
        attention_mask=attention_mask,
        causal_attention_mask=causal_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    last_hidden_state = encoder_outputs[0]
    last_hidden_state = text_encoder.text_model.final_layer_norm(last_hidden_state)
    if text_encoder.text_model.eos_token_id == 2:
        # Legacy-config path: pool at the position of the largest token id in
        # each sequence (argmax over ids).
        # NOTE(review): if added tokens have ids above the EOS id, this selects
        # the wrong position.
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]
    else:
        # Pool at the first position whose id equals eos_token_id (argmax of the
        # 0/1 mask returns the first maximum).
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == text_encoder.text_model.eos_token_id)
            .int()
            .argmax(dim=-1),
        ]
    if not return_dict:
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
    # [0] on the ModelOutput returns last_hidden_state; the result is paired
    # with the embedding layer's extra dict.
    return BaseModelOutputWithPooling(
        last_hidden_state=last_hidden_state,
        pooler_output=pooled_output,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )[0], other_return_dict
def downsampling(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Bilinearly resize a single 2-D map to shape ``(w, h)``.

    Args:
        img: 2-D tensor. It gains two leading singleton dims below, and
            ``mode="bilinear"`` requires a 4-D input.
        w: size of the output's first dimension.
        h: size of the output's second dimension.

    Returns:
        2-D tensor of shape ``(w, h)``, resampled with ``align_corners=True``.

    NOTE(review): ``F.interpolate`` reads ``size`` as (height, width), so the
    output's *first* dimension is ``w``. This is kept as-is because callers may
    rely on it.
    """
    resized = F.interpolate(
        img.unsqueeze(0).unsqueeze(1),  # add batch and channel dims -> (1, 1, *img.shape)
        size=(w, h),
        mode="bilinear",
        align_corners=True,
    )
    # Remove only the two dims added above; a bare .squeeze() would also drop
    # an output dimension of size 1.
    return resized[0, 0]
def image_grid(images, rows=2, cols=2):
    """Tile images row-major onto a ``rows`` x ``cols`` RGB canvas.

    Every cell is sized like ``images[0]``. Tiles are filled left to right,
    then top to bottom.
    """
    cell_w, cell_h = images[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
def latents_to_images(vae, latents, scale_factor=0.18215):
    """
    Decode latents to PIL images.

    Args:
        vae: decoder whose ``decode(z).sample`` returns an NCHW image batch,
            presumably in [-1, 1] (it is mapped to [0, 1] below — confirm for
            your VAE).
        latents: batch of latents that were multiplied by ``scale_factor``.
            The tensor is not modified.
        scale_factor: latent scaling factor that is divided out before decoding.

    Returns:
        list of ``PIL.Image.Image``, one per batch element, built from uint8 arrays.
    """
    # Out-of-place multiply; no clone is needed to protect the caller's tensor.
    images = vae.decode(1.0 / scale_factor * latents).sample
    images = (images / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    # NCHW -> NHWC. permute(0, 2, 3, 1) only accepts 4-D tensors, so the array
    # is always batched here.
    images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]
def merge_and_save_images(output_images):
    """Concatenate images left to right on a white RGB canvas and return it.

    The canvas height and the horizontal stride between tiles both come from
    the first image's size. Despite the name, nothing is written to disk.
    """
    tile_w, tile_h = output_images[0].size
    canvas = Image.new('RGB', (len(output_images) * tile_w, tile_h), (255, 255, 255))
    for index, tile in enumerate(output_images):
        canvas.paste(tile, (index * tile_w, 0))
    return canvas
class GANLoss(nn.Module):
    """Adversarial loss that compares predictions to a constant real/fake label.

    NOTE(review): this definition rebinds the name ``GANLoss``. An earlier class
    of the same name in this module (which uses ``nn.BCEWithLogitsLoss``) is
    therefore unreachable. Consider deleting one of the two.

    Args:
        use_lsgan: if True use ``nn.MSELoss`` (least-squares GAN); otherwise use
            ``nn.BCELoss``, which expects probabilities in [0, 1] (i.e. outputs
            already passed through a sigmoid), not raw logits.
        target_real_label: label value used when ``target_is_real`` is True.
        target_fake_label: label value used when ``target_is_real`` is False.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        # Registered as buffers so they move with the module and appear in state_dict.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label broadcast to ``input``'s shape, dtype and device."""
        label = self.real_label if target_is_real else self.fake_label
        # Match dtype/device so half-precision inputs, or a module that was never
        # moved to the input's device, still produce a valid target.
        return label.to(input).expand_as(input)

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against the real (True) or fake (False) label."""
        # Implemented as ``forward`` (not ``__call__``) so nn.Module hooks still run.
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)