import ast
import gc
from collections import OrderedDict

import torch
import wandb
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.attention_processor import AttnProcessor2_0


def extract_into_tensor(a, t, x_shape):
    """Gather values of ``a`` at indices ``t`` and reshape them for broadcasting.

    ``a`` is indexed along its last dimension with ``t``. The result has shape
    ``(t.shape[0], 1, ..., 1)`` with ``len(x_shape)`` dimensions in total, so it
    broadcasts against a tensor of shape ``x_shape``.
    """
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def is_attn(name):
    """Return True if the last dotted component of ``name`` is ``attn1`` or ``attn2``."""
    # The previous form ``"attn1" or "attn2" == ...`` parsed as
    # ``"attn1" or (...)`` and therefore always returned the truthy "attn1".
    return name.split(".")[-1] in ("attn1", "attn2")


def set_processors(attentions):
    """Install a fresh ``AttnProcessor2_0`` on each attention module given."""
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    """Switch transformer-block attention in ``unet`` to ``AttnProcessor2_0``.

    Every ``BasicTransformerBlock`` found directly inside a ``torch.nn.ModuleList``
    gets the processor on its ``attn1`` and ``attn2`` modules. The number of
    blocks patched is printed when it is non-zero.
    """
    optim_count = 0
    for module in unet.modules():
        # No ``is_attn`` name filter here: ``attn1``/``attn2`` are the attention
        # modules themselves (they receive ``set_processor`` below), never a
        # ModuleList, so such a filter would match nothing. The old code only
        # worked because ``is_attn`` was always truthy.
        if not isinstance(module, torch.nn.ModuleList):
            continue
        for m in module:
            if isinstance(m, BasicTransformerBlock):
                # attn2 may be None on blocks without cross/second attention.
                set_processors([a for a in (m.attn1, m.attn2) if a is not None])
                optim_count += 1
    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")


# From LatentConsistencyModel.get_guidance_scale_embedding
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """Sinusoidal embedding of guidance scales.

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales; each value is multiplied by 1000 before embedding.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate (odd sizes are zero-padded by one)
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`,
        first half sine terms and second half cosine terms.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on w's device so GPU inputs don't hit a device mismatch.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Return ``(c_skip, c_out)`` for the consistency-model boundary condition.

    With ``s = timestep_scaling * timestep``:
    ``c_skip = sigma_data**2 / (s**2 + sigma_data**2)`` and
    ``c_out = s / sqrt(s**2 + sigma_data**2)``.
    """
    scaled_timestep = timestep_scaling * timestep
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out


# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Recover the predicted clean sample x_0 from a model output.

    ``alphas``/``sigmas`` are indexed at ``timesteps`` and broadcast to
    ``sample.shape``. Supported ``prediction_type`` values:
    ``"epsilon"``: x_0 = (x - sigma * eps) / alpha;
    ``"sample"``: x_0 is the model output itself;
    ``"v_prediction"``: x_0 = alpha * x - sigma * v.

    Raises:
        ValueError: for any other ``prediction_type``.
    """
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas
    elif prediction_type == "sample":
        pred_x_0 = model_output
    elif prediction_type == "v_prediction":
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
            f" are supported."
        )

    return pred_x_0


# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Recover the predicted noise epsilon from a model output.

    ``alphas``/``sigmas`` are indexed at ``timesteps`` and broadcast to
    ``sample.shape`` via ``extract_into_tensor``. Supported ``prediction_type``:
    ``"epsilon"``: eps is the model output itself;
    ``"sample"``: eps = (x - alpha * x_0) / sigma;
    ``"v_prediction"``: eps = alpha * v + sigma * x.

    Raises:
        ValueError: for any other ``prediction_type``.
    """
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        pred_epsilon = model_output
    elif prediction_type == "sample":
        pred_epsilon = (sample - alphas * model_output) / sigmas
    elif prediction_type == "v_prediction":
        pred_epsilon = alphas * model_output + sigmas * sample
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
            f" are supported."
        )

    return pred_epsilon


def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    """Bundle a model and its training settings into a spec dict.

    The returned dict's keys are, in order: ``model``, ``condition``,
    ``extra_params``, ``is_lora``, ``negation``. ``create_optimizer_params``
    unpacks the values positionally, so this order must not change.
    A missing (``None``) or empty ``extra_params`` is normalised to ``None``.
    """
    # ``extra_params`` defaults to None; the old ``len(extra_params.keys())``
    # raised AttributeError in that case.
    extra_params = extra_params if extra_params else None
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation,
    }


def create_optim_params(name="param", params=None, lr=5e-6, extra_params=None):
    """Return one optimizer parameter-group dict.

    The group always carries ``name``, ``params`` and ``lr``. Entries from
    ``extra_params`` are merged on top and may override those three keys.
    """
    group = dict(name=name, params=params, lr=lr)
    if extra_params is None:
        return group
    group.update(extra_params.items())
    return group


def create_optimizer_params(model_list, lr):
    """Build optimizer parameter groups from ``param_optim`` spec dicts.

    For each spec:
    - LoRA + trainable + ``model`` is a list: one group chaining all iterables in the list.
    - LoRA + trainable + single model: one group per parameter whose name contains "lora".
    - trainable, not LoRA: one group per parameter, skipping parameters whose name contains "lora".
    - not trainable (``condition`` falsy): nothing.

    Every group uses ``lr`` unless the spec's ``extra_params`` overrides it.
    """
    import itertools

    optimizer_params = []

    for optim in model_list:
        # Positional unpacking relies on the key order produced by param_optim.
        model, condition, extra_params, is_lora, _negation = optim.values()

        # Check if we are doing LoRA training.
        if is_lora and condition and isinstance(model, list):
            # Previously ``lr`` was not forwarded here, so this group silently
            # used the 5e-6 default instead of the caller's learning rate.
            params = create_optim_params(
                params=itertools.chain(*model), lr=lr, extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        if is_lora and condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if "lora" in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        # If this is true, we can train it.
        if condition:
            for n, p in model.named_parameters():
                # Outside LoRA mode, LoRA weights are excluded from training.
                should_negate = "lora" in n and not is_lora
                if should_negate:
                    continue

                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params


def handle_trainable_modules(
    model, trainable_modules=None, is_enabled=True, negation=None
):
    """Set ``requires_grad`` on ``model``'s parameters according to ``trainable_modules``.

    Args:
        model: module whose parameters are (un)frozen.
        trainable_modules: iterable of name substrings, or ``None`` to leave
            ``model`` untouched. If any entry equals ``"all"``, every parameter
            is made trainable. Otherwise every parameter is first frozen. Then
            each parameter whose name contains any entry, and does not contain
            "lora", gets ``requires_grad = is_enabled``.
        is_enabled: flag applied to matched parameters (not used in the "all" case).
        negation: unused; kept for interface compatibility.

    Returns:
        int: number of parameters whose flag was set by a match, all parameters
        in the "all" case, or 0 when ``trainable_modules`` is None.
        (The original version computed this count but returned None.)
    """
    if trainable_modules is None:
        return 0

    # Materialise once so a generator is not exhausted by the "all" scan
    # before being re-scanned for every parameter name.
    modules = tuple(trainable_modules)

    if any(tm == "all" for tm in modules):
        model.requires_grad_(True)
        return len(list(model.parameters()))

    model.requires_grad_(False)
    unfrozen_params = 0
    for name, param in model.named_parameters():
        # Parameters with "lora" in their name are never unlocked here.
        if "lora" in name:
            continue
        # named_parameters() yields unique names, so a parameter matched by
        # several substrings is still set and counted once.
        if any(tm in name for tm in modules):
            param.requires_grad_(is_enabled)
            unfrozen_params += 1
    return unfrozen_params


def huber_loss(pred, target, huber_c=0.001):
    """Mean pseudo-Huber loss: mean(sqrt((pred - target)^2 + c^2) - c).

    Both inputs are upcast to float32 before the difference is taken.
    """
    diff = pred.float() - target.float()
    per_element = torch.sqrt(diff.pow(2) + huber_c**2) - huber_c
    return per_element.mean()


@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    # In place: target <- rate * target + (1 - rate) * source.
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def log_validation_video(pipeline, args, accelerator, save_fps, validation_prompts=None):
    """Generate validation videos and log them to any wandb tracker.

    Args:
        pipeline: callable video pipeline; called once per prompt with
            ``num_inference_steps=4`` and ``num_videos_per_prompt=2``.
        args: needs ``seed`` (None for unseeded) and ``n_frames``.
        accelerator: provides ``device`` and ``trackers``.
        save_fps: frame rate passed to ``wandb.Video``.
        validation_prompts: optional list of prompts; defaults to a built-in set.
    """
    if validation_prompts is None:
        validation_prompts = [
            "An astronaut riding a horse.",
            "Darth vader surfing in waves.",
            "Robot dancing in times square.",
            "Clown fish swimming through the coral reef.",
            "A child excitedly swings on a rusty swing set, laughter filling the air.",
            "With the style of van gogh, A young couple dances under the moonlight by the lake.",
            "A young woman with glasses is jogging in the park wearing a pink headband.",
            "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
        ]

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    video_logs = []
    for prompt in validation_prompts:
        with torch.autocast("cuda"):
            videos = pipeline(
                prompt=prompt,
                frames=args.n_frames,
                num_inference_steps=4,
                num_videos_per_prompt=2,
                generator=generator,
            )
        # Map [-1, 1] to [0, 255] uint8. The permute swaps axes 1 and 2;
        # presumably (batch, C, T, H, W) -> (batch, T, C, H, W) for wandb.Video — confirm.
        videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0
        videos = (videos * 255).to(torch.uint8).permute(0, 2, 1, 3, 4).cpu().numpy()
        video_logs.append({"validation_prompt": prompt, "videos": videos})

    for tracker in accelerator.trackers:
        # Previously ``if == "wandb":`` — a syntax error; the tracker name was intended.
        if tracker.name == "wandb":
            formatted_videos = [
                wandb.Video(video, caption=log["validation_prompt"], fps=save_fps)
                for log in video_logs
                for video in log["videos"]
            ]
            tracker.log({"validation": formatted_videos})

    # Only drops this function's reference; the caller must release its own.
    del pipeline
    gc.collect()


def tuple_type(s):
    """Parse ``s`` as a Python tuple literal; tuples are returned unchanged.

    Raises:
        TypeError: if ``s`` does not evaluate to a tuple, including when it is
            not a valid literal at all.
    """
    if isinstance(s, tuple):
        return s
    try:
        value = ast.literal_eval(s)
    except (ValueError, SyntaxError) as err:
        raise TypeError("Argument must be a tuple") from err
    if isinstance(value, tuple):
        return value
    raise TypeError("Argument must be a tuple")


def load_model_checkpoint(model, ckpt):
    """Strictly load weights from the file ``ckpt`` into ``model`` and return it.

    The file may hold a state dict directly, or a dict that wraps it under a
    ``"state_dict"`` key.
    """
    # NOTE(security): torch.load may unpickle arbitrary objects depending on
    # the torch version / weights_only default — only load trusted checkpoints.
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    # Release the CPU copy of the weights promptly.
    del state_dict
    gc.collect()
    print(">>> model checkpoint loaded.")
    return model