Spaces:
Running
on
Zero
Running
on
Zero
from PIL import Image | |
import torch | |
import torch.nn.functional as F | |
from transformers.modeling_outputs import BaseModelOutputWithPooling | |
from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask | |
from torch import autograd | |
import accelerate | |
import torch.nn as nn | |
from PIL import Image | |
import numpy as np | |
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Passing ``requires_grad=False`` (the default) freezes the networks so that
    autograd does not compute gradients for their parameters.

    Parameters:
        nets (nn.Module | list | tuple | None) -- a single network or a
            list/tuple of networks; ``None`` entries are skipped.
        requires_grad (bool) -- value assigned to ``param.requires_grad``
            for every parameter of every network.
    """
    # A tuple used to be treated as a single network and crashed on
    # ``.parameters()``; accept it exactly like a list.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
def discriminator_r1_loss_accelerator(accelerator, real_pred, real_w):
    """R1 gradient penalty, callable from an ``accelerate`` training loop.

    The previous implementation called ``accelerate.gradient``, which the
    ``accelerate`` package does not provide, so every call raised
    ``AttributeError``. The penalty only needs ``torch.autograd.grad``, the
    same call used by ``discriminator_r1_loss`` in this module.

    Args:
        accelerator: unused; kept so existing call sites keep working.
        real_pred: output computed from ``real_w``; reduced with ``sum()``
            before differentiation.
        real_w: tensor with ``requires_grad=True`` that ``real_pred``
            depends on.

    Returns:
        Scalar tensor: the mean over the leading (batch) dimension of the
        squared L2 norm of ``d sum(real_pred) / d real_w``, where the norm is
        taken over all remaining dimensions.
    """
    # create_graph=True keeps the penalty differentiable, so it can itself be
    # back-propagated.
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(), inputs=real_w, create_graph=True
    )
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
    return grad_penalty
class GANLoss(nn.Module):
    """Adversarial loss of a prediction against a constant real/fake label.

    NOTE(review): a second ``GANLoss`` class is defined later in this module.
    It is identical except that it uses ``nn.BCELoss``, and because it is
    defined last it shadows this class at import time. Confirm which variant
    is intended.

    Args:
        use_lsgan: if True, use mean-squared error against the label (LSGAN).
            Otherwise use ``nn.BCEWithLogitsLoss``, which expects raw logits.
        target_real_label: label value used when ``target_is_real`` is True.
        target_fake_label: label value used when ``target_is_real`` is False.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        # Zero-arg super(): ``super(GANLoss, self)`` would look up the global
        # name, which the later duplicate definition rebinds.
        super().__init__()
        # Buffers follow the module through .to()/.cuda() and are saved in
        # its state_dict.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label broadcast to ``input``'s shape and device."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        # Align devices so the loss also works when this module was never
        # moved to the prediction's device.
        return target_tensor.to(input.device).expand_as(input)

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against the real or fake label.

        Implemented as ``forward`` (not an ``__call__`` override) so that
        ``nn.Module.__call__`` still runs registered hooks.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
def discriminator_r1_loss(real_pred, real_w):
    """R1 regularizer: batch-mean squared gradient norm of a prediction w.r.t. its input.

    Args:
        real_pred: output computed from ``real_w``; reduced with ``sum()``
            before differentiation.
        real_w: tensor with ``requires_grad=True`` that ``real_pred``
            depends on.

    Returns:
        Scalar tensor equal to the mean over the leading dimension of
        ``||d sum(real_pred) / d real_w||^2``, where the norm covers all
        remaining dimensions.
    """
    scalar_out = real_pred.sum()
    # create_graph=True so the penalty itself remains differentiable.
    grads = autograd.grad(outputs=scalar_out, inputs=real_w, create_graph=True)
    grad_real = grads[0]
    batch_size = grad_real.shape[0]
    squared_norms = grad_real.pow(2).reshape(batch_size, -1).sum(1)
    return squared_norms.mean()
def add_noise_return_paras(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> tuple:
    """Noise ``original_samples`` at ``timesteps`` and also return the mixing coefficients.

    Computes ``noisy = sqrt(a) * original_samples + sqrt(1 - a) * noise``,
    where ``a = self.alphas_cumprod[timesteps]``. ``self`` is presumably a
    scheduler-like object exposing an ``alphas_cumprod`` tensor (this is a
    free function taking ``self``); confirm at the binding site.

    Args:
        original_samples: clean samples; their device and dtype are used for
            all intermediate tensors.
        noise: tensor combined with ``original_samples`` (same shape assumed).
        timesteps: indices into ``self.alphas_cumprod``, one per leading
            element of ``original_samples``.

    Returns:
        Tuple ``(noisy_samples, sqrt_alpha_prod, sqrt_one_minus_alpha_prod)``.
        Both coefficient tensors are flattened and then padded with trailing
        singleton dims up to ``original_samples.ndim``, so they broadcast.
        (The annotation previously claimed a single tensor.)
    """
    # Make sure alphas_cumprod and timesteps have the same device (and dtype)
    # as original_samples.
    alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)
    alpha_bar = alphas_cumprod[timesteps]

    def _broadcast(coef):
        # Flatten, then append singleton dims until coef can broadcast
        # against original_samples.
        coef = coef.flatten()
        while len(coef.shape) < len(original_samples.shape):
            coef = coef.unsqueeze(-1)
        return coef

    sqrt_alpha_prod = _broadcast(alpha_bar ** 0.5)
    sqrt_one_minus_alpha_prod = _broadcast((1 - alpha_bar) ** 0.5)

    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
    return noisy_samples, sqrt_alpha_prod, sqrt_one_minus_alpha_prod
def text_encoder_forward(
    text_encoder=None,
    input_ids=None,
    name_batch=None,
    attention_mask=None,
    position_ids=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    embedding_manager=None,
    only_embedding=False,
    random_embeddings=None,
    timesteps=None,
):
    """Run a text encoder's text model step by step through a customized embedding layer.

    The embedding call forwards extra arguments (``name_batch``,
    ``embedding_manager``, ``only_embedding``, ``random_embeddings``,
    ``timesteps``) to ``text_encoder.text_model.embeddings``. That layer must
    accept them and return a ``(hidden_states, other_return_dict)`` pair.

    Args:
        text_encoder: model exposing ``config`` and ``text_model``, the latter
            with ``embeddings``, ``encoder``, ``final_layer_norm`` and (for the
            tuple path) ``eos_token_id``.
        input_ids: token ids; required. Viewed as ``(-1, input_ids.size(-1))``.
        attention_mask: optional mask, expanded with ``_expand_mask``.
        position_ids: forwarded to the embedding layer.
        output_attentions, output_hidden_states, return_dict: default to the
            matching ``text_encoder.config`` values when None.
        only_embedding: if True, return the embedding output immediately.
        name_batch, embedding_manager, random_embeddings, timesteps:
            forwarded unchanged to the embedding layer.

    Returns:
        * ``only_embedding=True``: the embedding ``hidden_states``.
        * ``return_dict`` truthy: ``(last_hidden_state, other_return_dict)``.
        * ``return_dict`` falsy: ``(last_hidden_state, pooled_output) + encoder_outputs[1:]``.

    NOTE(review): the two non-embedding return shapes differ. The tuple path
    drops ``other_return_dict`` and the dict path drops ``pooled_output``.
    This is kept as-is for existing callers; confirm the asymmetry is intended.

    Raises:
        ValueError: if ``input_ids`` is None.
    """
    output_attentions = (
        output_attentions if output_attentions is not None else text_encoder.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else text_encoder.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else text_encoder.config.use_return_dict

    if input_ids is None:
        # Previous message ("either input_ids") was a truncated leftover.
        raise ValueError("You have to specify input_ids")

    input_shape = input_ids.size()
    input_ids = input_ids.view(-1, input_shape[-1])

    text_model = text_encoder.text_model
    hidden_states, other_return_dict = text_model.embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        name_batch=name_batch,
        embedding_manager=embedding_manager,
        only_embedding=only_embedding,
        random_embeddings=random_embeddings,
        timesteps=timesteps,
    )
    if only_embedding:
        return hidden_states

    causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
    if attention_mask is not None:
        attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

    encoder_outputs = text_model.encoder(
        inputs_embeds=hidden_states,
        attention_mask=attention_mask,
        causal_attention_mask=causal_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    last_hidden_state = text_model.final_layer_norm(encoder_outputs[0])

    if return_dict:
        # The previous code built a BaseModelOutputWithPooling only to take
        # element [0], which is last_hidden_state; return it directly.
        return last_hidden_state, other_return_dict

    # Pool the hidden state at one token per sequence (tuple path only).
    batch_index = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    ids = input_ids.to(dtype=torch.int, device=last_hidden_state.device)
    if text_model.eos_token_id == 2:
        # Special case kept from the original: picks the position of the
        # largest token id (presumably the EOS token in such configs; confirm).
        eos_position = ids.argmax(dim=-1)
    else:
        # First position whose id equals eos_token_id.
        eos_position = (ids == text_model.eos_token_id).int().argmax(dim=-1)
    pooled_output = last_hidden_state[batch_index, eos_position]
    return (last_hidden_state, pooled_output) + encoder_outputs[1:]
def downsampling(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Bilinearly resize a single 2-D map to ``size=(w, h)`` (``align_corners=True``).

    ``F.interpolate`` interprets ``size`` as (rows, cols), so the result has
    ``w`` rows and ``h`` columns.
    NOTE(review): the names suggest (width, height), but the output shape is
    ``(w, h)``. The order is kept unchanged for existing callers.

    Args:
        img: 2-D tensor; batch and channel dims of size 1 are added because
            bilinear interpolation needs a 4-D input.
        w: size of the output's first dimension.
        h: size of the output's second dimension.

    Returns:
        Tensor of shape ``(w, h)``.
    """
    resized = F.interpolate(
        img.unsqueeze(0).unsqueeze(1),
        size=(w, h),
        mode="bilinear",
        align_corners=True,
    )
    # Drop only the two dims added above. A bare squeeze() would also drop
    # an output dim when w or h equals 1.
    return resized[0, 0]
def image_grid(images, rows=2, cols=2):
    """Tile ``images`` row-major onto a ``rows`` x ``cols`` RGB canvas.

    Every cell takes the size of ``images[0]``. Image ``k`` is pasted at
    row ``k // cols``, column ``k % cols``.
    NOTE(review): there is no check that ``len(images) <= rows * cols``.

    Args:
        images: indexable sequence of PIL images.
        rows: number of cell rows in the canvas.
        cols: number of cell columns in the canvas.

    Returns:
        A new ``PIL.Image`` in mode ``'RGB'``.
    """
    cell_w, cell_h = images[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
def latents_to_images(vae, latents, scale_factor=0.18215):
    """Decode a batch of latents with ``vae`` into a list of PIL images.

    Args:
        vae: object whose ``decode(x)`` returns something with a ``.sample``
            tensor. That tensor must be 4-D (``permute(0, 2, 3, 1)`` below);
            it is assumed to be (batch, channels, height, width) with values in
            [-1, 1] — TODO confirm.
        latents: latent batch; multiplied by ``1.0 / scale_factor`` before
            decoding.
        scale_factor: latent scaling constant (default 0.18215, presumably the
            Stable Diffusion VAE factor — confirm).

    Returns:
        List of ``PIL.Image`` objects, one per batch element.
    """
    # The decoded tensor is detached and converted to numpy below, so no
    # autograd graph is needed; no_grad avoids storing activations.
    with torch.no_grad():
        # Multiplication already yields a new tensor, so no clone() is needed.
        images = vae.decode(1.0 / scale_factor * latents).sample
    # Map [-1, 1] to [0, 1].
    images = (images / 2 + 0.5).clamp(0, 1)
    # .float(): bfloat16 tensors cannot be converted with .numpy(), and
    # float32 avoids half-precision rounding in the *255 step below.
    # permute() requires a 4-D tensor, so the result is always batched
    # (the old ndim == 3 fix-up branch was unreachable).
    images = images.detach().float().cpu().permute(0, 2, 3, 1).numpy()
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]
def merge_and_save_images(output_images):
    """Concatenate images left-to-right into one RGB strip on a white background.

    Every slot is as wide as ``output_images[0]``, and the strip is as tall as
    that first image. Despite the name, nothing is written to disk; the merged
    image is returned.

    Args:
        output_images: non-empty sequence of PIL images.

    Returns:
        A new ``PIL.Image`` of size ``(len(output_images) * w0, h0)``, where
        ``(w0, h0)`` is the size of the first image.
    """
    slot_w, strip_h = output_images[0].size
    strip = Image.new('RGB', (slot_w * len(output_images), strip_h), (255, 255, 255))
    for slot, picture in enumerate(output_images):
        strip.paste(picture, (slot * slot_w, 0))
    return strip
class GANLoss(nn.Module):
    """Adversarial loss of a prediction against a constant real/fake label.

    NOTE(review): this module defines ``GANLoss`` twice. An earlier definition
    uses ``nn.BCEWithLogitsLoss``; this later one uses ``nn.BCELoss`` and,
    being defined last, is the class the name refers to after import.
    ``nn.BCELoss`` expects probabilities in [0, 1], not logits, so confirm
    that callers apply a sigmoid when ``use_lsgan=False``.

    Args:
        use_lsgan: if True, use mean-squared error against the label (LSGAN).
            Otherwise use ``nn.BCELoss``.
        target_real_label: label value used when ``target_is_real`` is True.
        target_fake_label: label value used when ``target_is_real`` is False.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        # Buffers follow the module through .to()/.cuda() and are saved in
        # its state_dict.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label broadcast to ``input``'s shape and device."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        # Align devices so the loss also works when this module was never
        # moved to the prediction's device.
        return target_tensor.to(input.device).expand_as(input)

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against the real or fake label.

        Implemented as ``forward`` (not an ``__call__`` override) so that
        ``nn.Module.__call__`` still runs registered hooks.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)