import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
import accelerate
from PIL import Image
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask


def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given network(s).

    Typically called with ``requires_grad=False`` to freeze networks and avoid
    unnecessary gradient computation.

    Parameters:
        nets (network, or list/tuple of networks) -- networks to update;
            ``None`` entries are skipped.
        requires_grad (bool) -- whether the networks' parameters require gradients.
    """
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad


def discriminator_r1_loss(real_pred, real_w):
    """R1 gradient penalty.

    Returns the batch mean of the squared L2 norm (over all non-batch dims) of
    the gradient of ``real_pred.sum()`` with respect to ``real_w``.

    Parameters:
        real_pred (Tensor) -- discriminator outputs on real inputs.
        real_w (Tensor) -- the inputs the gradient is taken with respect to;
            must have ``requires_grad=True`` and be part of ``real_pred``'s graph.
            Its first dim is treated as the batch dim.
    """
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(),
        inputs=real_w,
        # Keep the graph so the penalty itself can be backpropagated.
        create_graph=True,
    )
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()


def discriminator_r1_loss_accelerator(accelerator, real_pred, real_w):
    """R1 gradient penalty for Accelerate-driven training loops.

    The previous implementation called ``accelerate.gradient``, which the
    ``accelerate`` package does not provide, so it failed with AttributeError.
    The gradient is now computed with ``torch.autograd.grad`` exactly as in
    :func:`discriminator_r1_loss`.

    Parameters:
        accelerator -- accepted for backward compatibility; currently unused.
        real_pred (Tensor) -- discriminator outputs on real inputs.
        real_w (Tensor) -- inputs the gradient is taken with respect to.
    """
    return discriminator_r1_loss(real_pred, real_w)


class GANLoss(nn.Module):
    """Adversarial loss comparing discriminator outputs to constant targets.

    This file previously defined ``GANLoss`` twice; the later definition (using
    ``nn.BCELoss`` for the non-LSGAN case) was the one in effect. That behavior
    is the default here, and the earlier logits-based variant is available via
    ``use_logits=True``.

    Parameters:
        use_lsgan (bool) -- use least-squares GAN loss (``nn.MSELoss``);
            otherwise binary cross-entropy.
        target_real_label (float) -- target value for real samples.
        target_fake_label (float) -- target value for fake samples.
        use_logits (bool) -- only consulted when ``use_lsgan`` is False. If True,
            use ``nn.BCEWithLogitsLoss`` (predictions are raw logits); if False,
            use ``nn.BCELoss`` (predictions must already be in [0, 1]).
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 use_logits=False):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        elif use_logits:
            self.loss = nn.BCEWithLogitsLoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a target tensor shaped like ``input`` filled with the real/fake label."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        # Match the prediction's device and dtype so the loss also works when this
        # module was never moved with .to() and under reduced-precision predictions.
        return target_tensor.to(device=input.device, dtype=input.dtype).expand_as(input)

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against the real (True) or fake (False) target."""
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)


def _unsqueeze_to_ndim(tensor, ndim):
    """Append trailing singleton dims to ``tensor`` until it has ``ndim`` dims."""
    while tensor.dim() < ndim:
        tensor = tensor.unsqueeze(-1)
    return tensor


def add_noise_return_paras(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> tuple:
    """Noise ``original_samples`` to ``timesteps`` and also return the coefficients.

    Written to be bound as a method on a scheduler object exposing
    ``alphas_cumprod`` (presumably a diffusers-style scheduler — confirm at the
    call site). With ``a_t = self.alphas_cumprod[timesteps]`` it computes
    ``noisy = sqrt(a_t) * original_samples + sqrt(1 - a_t) * noise``.

    Returns:
        tuple ``(noisy_samples, sqrt_alpha_prod, sqrt_one_minus_alpha_prod)``.
        Both coefficient tensors are flattened and given trailing singleton dims
        so they have ``original_samples.ndim`` dims and broadcast against it.
    """
    # Make sure alphas_cumprod and timesteps have the same device (and, for
    # alphas_cumprod, dtype) as original_samples.
    alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)
    alpha_t = alphas_cumprod[timesteps]

    target_ndim = len(original_samples.shape)
    sqrt_alpha_prod = _unsqueeze_to_ndim((alpha_t ** 0.5).flatten(), target_ndim)
    sqrt_one_minus_alpha_prod = _unsqueeze_to_ndim(((1 - alpha_t) ** 0.5).flatten(), target_ndim)

    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
    return noisy_samples, sqrt_alpha_prod, sqrt_one_minus_alpha_prod


def _pool_eos(text_model, last_hidden_state, input_ids):
    """Pick, for each sequence, the hidden state at its end-of-text position.

    Presumably mirrors the upstream transformers CLIP text model pooling: when
    ``eos_token_id == 2`` the position of the largest token id is used,
    otherwise the first position equal to ``eos_token_id``.
    """
    batch_index = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    ids = input_ids.to(dtype=torch.int, device=last_hidden_state.device)
    if text_model.eos_token_id == 2:
        eos_positions = ids.argmax(dim=-1)
    else:
        # argmax over a 0/1 mask returns the first matching position.
        eos_positions = (ids == text_model.eos_token_id).int().argmax(dim=-1)
    return last_hidden_state[batch_index, eos_positions]


def text_encoder_forward(
    text_encoder=None,
    input_ids=None,
    name_batch=None,
    attention_mask=None,
    position_ids=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    embedding_manager=None,
    only_embedding=False,
    random_embeddings=None,
    timesteps=None,
):
    """Run ``text_encoder.text_model`` with a customized embedding step.

    ``text_encoder.text_model.embeddings`` is called with extra keyword arguments
    (``name_batch``, ``embedding_manager``, ``only_embedding``,
    ``random_embeddings``, ``timesteps``) and must return a pair
    ``(hidden_states, other_return_dict)``. This assumes a project-patched
    embeddings module — confirm; a stock one will not accept these arguments.
    The rest follows the encoder -> final_layer_norm sequence.

    ``output_attentions``, ``output_hidden_states`` and ``return_dict`` default
    to the corresponding values on ``text_encoder.config`` when None.

    Returns (the shape depends on the flags):
        * ``only_embedding`` true: the embedding output ``hidden_states`` only.
        * ``return_dict`` truthy: ``(last_hidden_state, other_return_dict)``.
        * ``return_dict`` falsy: ``(last_hidden_state, pooled_output, *extras)``
          where ``extras`` are the encoder's remaining tuple outputs;
          ``other_return_dict`` is NOT returned on this path.

    Raises:
        ValueError: if ``input_ids`` is None.
    """
    config = text_encoder.config
    output_attentions = output_attentions if output_attentions is not None else config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else config.use_return_dict

    if input_ids is None:
        raise ValueError("You have to specify input_ids")

    text_model = text_encoder.text_model
    input_shape = input_ids.size()
    # Flatten any leading dims into a single batch dim; reshape (unlike view)
    # also accepts non-contiguous ids.
    input_ids = input_ids.reshape(-1, input_shape[-1])

    hidden_states, other_return_dict = text_model.embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        name_batch=name_batch,
        embedding_manager=embedding_manager,
        only_embedding=only_embedding,
        random_embeddings=random_embeddings,
        timesteps=timesteps,
    )
    if only_embedding:
        return hidden_states

    # Causal mask and (optional) padding mask built with transformers' CLIP helpers.
    causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
    if attention_mask is not None:
        attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

    encoder_outputs = text_model.encoder(
        inputs_embeds=hidden_states,
        attention_mask=attention_mask,
        causal_attention_mask=causal_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    last_hidden_state = text_model.final_layer_norm(encoder_outputs[0])

    if return_dict:
        # Only the final hidden states are returned here (the pooled output and
        # any attentions/hidden states are dropped), paired with the extra
        # outputs of the customized embeddings module.
        return last_hidden_state, other_return_dict

    pooled_output = _pool_eos(text_model, last_hidden_state, input_ids)
    return (last_hidden_state, pooled_output) + encoder_outputs[1:]


def downsampling(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Bilinearly resize ``img`` to ``size=(w, h)`` and squeeze singleton dims.

    A batch dim and a channel dim are prepended before ``F.interpolate`` (so a
    2D map is presumably expected), ``align_corners=True`` is used, and every
    size-1 dim of the result is squeezed away.

    NOTE(review): ``F.interpolate`` reads ``size`` as (out_height, out_width),
    so the result has ``w`` rows and ``h`` columns. Callers may depend on this
    ordering — confirm before swapping.
    """
    batched = img.unsqueeze(0).unsqueeze(1)
    resized = F.interpolate(batched, size=(w, h), mode="bilinear", align_corners=True)
    return resized.squeeze()


def image_grid(images, rows=2, cols=2):
    """Tile PIL images row-major into a ``rows`` x ``cols`` RGB grid.

    Every cell has the size of ``images[0]``. Images beyond ``rows * cols``
    are placed outside the canvas bounds.
    """
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def latents_to_images(vae, latents, scale_factor=0.18215):
    """Decode latents to PIL images.

    Latents are multiplied by ``1 / scale_factor`` (the default 0.18215 is
    presumably the Stable Diffusion VAE scaling factor), decoded with
    ``vae.decode(...).sample``, mapped with ``x / 2 + 0.5`` and clamped to
    [0, 1] (so decoder outputs are assumed to lie in [-1, 1]). The decoded
    batch is treated as (B, C, H, W) and converted to uint8 HWC images.

    Returns:
        list of ``PIL.Image``, one per batch element.
    """
    # The result is detached below, so no autograd graph is needed for decoding.
    with torch.no_grad():
        scaled_latents = 1.0 / scale_factor * latents
        images = vae.decode(scaled_latents).sample
    images = (images / 2 + 0.5).clamp(0, 1)
    # Cast to float32 before .numpy(): bfloat16 tensors cannot be converted.
    images = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]


def merge_and_save_images(output_images):
    """Concatenate PIL images left-to-right onto a white RGB canvas.

    Each image is placed at x offset ``i * width`` where ``width`` is that of
    ``output_images[0]``, so all images are assumed to share its size. Despite
    the name, nothing is written to disk — the merged image is returned.
    An empty list raises IndexError.
    """
    width, height = output_images[0].size
    merged_image = Image.new('RGB', (len(output_images) * width, height), (255, 255, 255))
    for i, image in enumerate(output_images):
        merged_image.paste(image, (i * width, 0))
    return merged_image