diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f5ba256c2767b7b8dabdbeee17ff205b5dec74e3 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+input/transition/1/2-Wide[[:space:]]angle[[:space:]]shot[[:space:]]of[[:space:]]an[[:space:]]alien[[:space:]]planet[[:space:]]with[[:space:]]cherry[[:space:]]blossom[[:space:]]forest-2.png filter=lfs diff=lfs merge=lfs -text
diff --git a/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4 b/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4597d28c2c40c9453bb7ad0d50e97c941256db2a
Binary files /dev/null and b/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4 differ
diff --git a/Close-up_essence_is_poured_from_bottleKodak_Vision.png b/Close-up_essence_is_poured_from_bottleKodak_Vision.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b8b907e01d44e7384bd405fcc33738e03743c0a
Binary files /dev/null and b/Close-up_essence_is_poured_from_bottleKodak_Vision.png differ
diff --git a/README.md b/README.md
index 2b1b9352da3370aa6a359f52027aaa0c1cbd42cf..3d44a9eb6f817c0400a464757222b47bb91e6382 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,65 @@
----
-title: Seine Gradio
-emoji: 🦀
-colorFrom: pink
-colorTo: blue
-sdk: gradio
-sdk_version: 4.7.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# SEINE
+This repository is the official implementation of [SEINE](https://arxiv.org/abs/2310.20700).
+
+**[SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction](https://arxiv.org/abs/2310.20700)**
+
+[Arxiv Report](https://arxiv.org/abs/2310.20700) | [Project Page](https://vchitect.github.io/SEINE-project/)
+
+## Setup for Inference
+
+### Prepare Environment
+```
+conda env create -f env.yaml
+conda activate seine
+```
+
+### Download our model and the T2I base model
+Download our model checkpoint from [Google Drive](https://drive.google.com/drive/folders/1cWfeDzKJhpb0m6HA5DoMOH0_ItuUY95b?usp=sharing) and save it to the ```pre-trained``` directory.
+
+Our model is based on Stable Diffusion v1.4; download [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) into the ```pre-trained``` directory as well.
+
+Now, under `./pre-trained`, you should see the following:
+```
+├── pre-trained
+│   ├── seine.pt
+│   ├── stable-diffusion-v1-4
+│   │   ├── ...
+│   │   └── ...
+```
+
+#### Inference for I2V
+```bash
+python sample_scripts/with_mask_sample.py --config configs/sample_i2v.yaml
+```
+The generated video will be saved in ```./results/i2v```.
+
+#### Inference for Transition
+```bash
+python sample_scripts/with_mask_sample.py --config configs/sample_transition.yaml
+```
+The generated video will be saved in ```./results/transition```.
+
+#### More Details
+You can modify ```configs/sample_i2v.yaml``` (or ```configs/sample_transition.yaml```) to change the generation conditions.
+For example:
+```ckpt``` specifies the model checkpoint to load.
+```text_prompt``` describes the content of the video.
+```input_path``` specifies the path to the input image.
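+
+As a rough illustration, an edited I2V config might look like the sketch below (the field names follow ```configs/sample_i2v.yaml``` shipped in this repository; the paths and prompt are placeholders, not values we distribute):
+```yaml
+ckpt: "pre-trained/seine.pt"                                 # checkpoint to load (placeholder path)
+pretrained_model_path: "pre-trained/stable-diffusion-v1-4/"  # T2I base model directory
+input_path: "input/i2v/your_image.png"                       # placeholder path to the conditioning image
+text_prompt: ["a sailboat drifting on a calm sea"]           # placeholder description of the video content
+num_sampling_steps: 250
+cfg_scale: 8.0
+save_img_path: "./results/i2v/"
+```
+Running the I2V command above with such a config writes the generated video to ```./results/i2v```.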
+ + +## BibTeX +```bibtex +@article{chen2023seine, +title={SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction}, +author={Chen, Xinyuan and Wang, Yaohui and Zhang, Lingjun and Zhuang, Shaobin and Ma, Xin and Yu, Jiashuo and Wang, Yali and Lin, Dahua and Qiao, Yu and Liu, Ziwei}, +journal={arXiv preprint arXiv:2310.20700}, +year={2023} +} +``` \ No newline at end of file diff --git a/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4 b/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..26919637200af54d10dc81edc3aee61db78b975d Binary files /dev/null and b/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4 differ diff --git a/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4 b/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..465b1835c7ae07c6fe7d5b00b30860fdd115578e Binary files /dev/null and b/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4 differ diff --git a/The_picture_shows_the_beauty_of_the_sea.png b/The_picture_shows_the_beauty_of_the_sea.png new file mode 100644 index 0000000000000000000000000000000000000000..6f5847286d1507fe4c66f218755eff2d5d154aca Binary files /dev/null and b/The_picture_shows_the_beauty_of_the_sea.png differ diff --git a/The_picture_shows_the_beauty_of_the_sea_.jpg b/The_picture_shows_the_beauty_of_the_sea_.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cd2b1e2b31adf8f5e150df283d2129f78a64308e Binary files /dev/null and b/The_picture_shows_the_beauty_of_the_sea_.jpg differ diff --git a/__pycache__/download.cpython-310.pyc b/__pycache__/download.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38c01e289a088357ff4d3badca50d3d4aa6b57ca Binary files /dev/null and b/__pycache__/download.cpython-310.pyc differ diff --git a/__pycache__/download.cpython-311.pyc b/__pycache__/download.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7eede41205a3440a8bfd431afbcdcf884908cf1 Binary files /dev/null and b/__pycache__/download.cpython-311.pyc differ diff --git a/__pycache__/download.cpython-39.pyc b/__pycache__/download.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c313425340f296f3ffcbc19e7d4d8f85324b0a02 Binary files /dev/null and b/__pycache__/download.cpython-39.pyc differ diff --git a/__pycache__/utils.cpython-310.pyc b/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26147e0d3fbc6bd37be6964535a3efb5bf018cd2 Binary files /dev/null and b/__pycache__/utils.cpython-310.pyc differ diff --git a/__pycache__/utils.cpython-311.pyc b/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caa8846ff5a21ce40d0485a2dee58476f3a76b81 Binary files /dev/null and b/__pycache__/utils.cpython-311.pyc differ diff --git a/__pycache__/utils.cpython-39.pyc b/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..616aae63f41717e7b04df098aed77a775746acfa Binary files /dev/null and b/__pycache__/utils.cpython-39.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..70a69d821d6526af85a6cdd1b44f402c4b0b4c07 --- 
/dev/null +++ b/app.py @@ -0,0 +1,183 @@ +import gradio as gr +from image_to_video import model_i2v_fun, get_input, auto_inpainting, setup_seed +from omegaconf import OmegaConf +import torch +from diffusers.utils.import_utils import is_xformers_available +import torchvision +from utils import mask_generation_before +import os +import cv2 + +config_path = "/mnt/petrelfs/zhouyan/project/i2v/configs/sample_i2v.yaml" +args = OmegaConf.load(config_path) +device = "cuda" if torch.cuda.is_available() else "cpu" + +# ------- get model --------------- +# model_i2V = model_i2v_fun() +# model_i2V.to("cuda") + +# vae, model, text_encoder, diffusion = model_i2v_fun(args) +# vae.to(device) +# model.to(device) +# text_encoder.to(device) + +# if args.use_fp16: +# vae.to(dtype=torch.float16) +# model.to(dtype=torch.float16) +# text_encoder.to(dtype=torch.float16) + +# if args.enable_xformers_memory_efficient_attention and device=="cuda": +# if is_xformers_available(): +# model.enable_xformers_memory_efficient_attention() +# else: +# raise ValueError("xformers is not available. Make sure it is installed correctly") + + +css = """ +h1 { + text-align: center; +} +#component-0 { + max-width: 730px; + margin: auto; +} +""" + +def infer(prompt, image_inp, seed_inp, ddim_steps): + setup_seed(seed_inp) + args.num_sampling_steps = ddim_steps + ###先测试Image的返回类型 + print(prompt, seed_inp, ddim_steps, type(image_inp)) + img = cv2.imread(image_inp) + new_size = [img.shape[0],img.shape[1]] + # if(img.shape[0]==512 and img.shape[1]==512): + # args.image_size = [512,512] + # elif(img.shape[0]==320 and img.shape[1]==512): + # args.image_size = [320, 512] + # elif(img.shape[0]==292 and img.shape[1]==512): + # args.image_size = [292,512] + # else: + # raise ValueError("Please enter image of right size") + # print(args.image_size) + args.image_size = new_size + + vae, model, text_encoder, diffusion = model_i2v_fun(args) + vae.to(device) + model.to(device) + text_encoder.to(device) + + if args.use_fp16: + vae.to(dtype=torch.float16) + model.to(dtype=torch.float16) + text_encoder.to(dtype=torch.float16) + + if args.enable_xformers_memory_efficient_attention and device=="cuda": + if is_xformers_available(): + model.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + + video_input, reserve_frames = get_input(image_inp, args) + video_input = video_input.to(device).unsqueeze(0) + mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) + masked_video = video_input * (mask == 0) + prompt = "tilt up, high quality, stable " + prompt = prompt + args.additional_prompt + video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,) + video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) + torchvision.io.write_video(os.path.join(args.save_img_path, prompt+ '.mp4'), video_, fps=8) + + + + # video = model_i2V(prompt, image_inp, seed_inp, ddim_steps) + + return os.path.join(args.save_img_path, prompt+ '.mp4') + + + +def clean(): + # return gr.Image.update(value=None, visible=False), gr.Video.update(value=None) + return gr.Video.update(value=None) + + +title = """ +
+    SEINE: Image-to-Video generation
+    Apply SEINE to generate a video
+""" + + + +with gr.Blocks(css='style.css') as demo: + gr.Markdown("
SEINE: Image-to-Video generation
") + with gr.Column(elem_id="col-container"): + # gr.HTML(title) + + with gr.Row(): + with gr.Column(): + image_inp = gr.Image(type='filepath') + + with gr.Column(): + + prompt = gr.Textbox(label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in") + + with gr.Row(): + # control_task = gr.Dropdown(label="Task", choices=["Text-2-video", "Image-2-video"], value="Text-2-video", multiselect=False, elem_id="controltask-in") + ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1) + seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=250, elem_id="seed-in") + + # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1) + + + + submit_btn = gr.Button("Generate video") + clean_btn = gr.Button("Clean video") + + video_out = gr.Video(label="Video result", elem_id="video-output", width = 800) + inputs = [prompt,image_inp, seed_inp, ddim_steps] + outputs = [video_out] + ex = gr.Examples( + examples = [["/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea_.jpg","A video of the beauty of the sea",123,50], + ["/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea.png","A video of the beauty of the sea",123,50], + ["/mnt/petrelfs/zhouyan/project/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png","A video of close-up essence is poured from bottleKodak Vision",123,50]], + fn = infer, + inputs = [image_inp, prompt, seed_inp, ddim_steps], + outputs=[video_out], + cache_examples=False + + + ) + ex.dataset.headers = [""] + # gr.Markdown("
some examples
") + # with gr.Row(): + # gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea_.jpg") + # gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea.png") + # gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png") + # with gr.Row(): + # gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4") + # gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4") + # gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4") + # control_task.change(change_task_options, inputs=[control_task], outputs=[canny_opt, hough_opt, normal_opt], queue=False) + clean_btn.click(clean, inputs=[], outputs=[video_out], queue=False) + submit_btn.click(infer, inputs, outputs) + # share_button.click(None, [], [], _js=share_js) + + +demo.queue(max_size=12).launch(server_name="0.0.0.0",server_port=7861) + + diff --git a/configs/sample_i2v.yaml b/configs/sample_i2v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1adc88be3d46317bc3abdeba0f428d371e05319e --- /dev/null +++ b/configs/sample_i2v.yaml @@ -0,0 +1,36 @@ + +ckpt: "/mnt/petrelfs/share_data/chenxinyuan/code/SEINE-release/pre-trained/seine.pt" +# save_img_path: "./results/i2v/" +save_img_path: "/mnt/petrelfs/share_data/zhouyan/gradio_i2v/" +pretrained_model_path: "pre-trained/stable-diffusion-v1-4/" + +# model config: +model: TAVU +num_frames: 16 +frame_interval: 1 +image_size: [512, 512] +#image_size: [320, 512] +# image_size: [512, 512] + +# model speedup +use_compile: False +use_fp16: True +enable_xformers_memory_efficient_attention: True +img_path: "/mnt/petrelfs/zhouyan/tmp/last" +# sample config: +seed: +run_time: 13 +cfg_scale: 8.0 +sample_method: 'ddpm' +num_sampling_steps: 250 +text_prompt: ["slow motion"] +additional_prompt: ", slow motion." +negative_prompt: "" +do_classifier_free_guidance: True + +# autoregressive config: +# input_path: "/mnt/petrelfs/zhouyan/tmp/未来上海/WechatIMG9434.jpg" +input_path: "/mnt/petrelfs/zhouyan/tmp/last" +researve_frame: 1 +mask_type: "first1" +use_mask: True diff --git a/configs/sample_transition.yaml b/configs/sample_transition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f656444a90e63b08c908cb8736cca94774ccfd7 --- /dev/null +++ b/configs/sample_transition.yaml @@ -0,0 +1,33 @@ + +ckpt: "pre-trained/0020000.pt" +save_img_path: "./results/transition/" +pretrained_model_path: "pre-trained/stable-diffusion-v1-4/" + +# model config: +model: TAVU +num_frames: 16 +frame_interval: 1 +#image_size: [240, 560] +#image_size: [320, 512] +image_size: [512, 512] + +# model speedup +use_compile: False +use_fp16: True +enable_xformers_memory_efficient_attention: True + +# sample config: +seed: +run_time: 13 +cfg_scale: 8.0 +sample_method: 'ddpm' +num_sampling_steps: 250 +text_prompt: ['smooth transition'] +additional_prompt: "smooth transition." 
+negative_prompt: "" +do_classifier_free_guidance: True + +# autoregressive config: +input_path: 'input/transition/1' +mask_type: "onelast1" +use_mask: True diff --git a/datasets/__pycache__/video_transforms.cpython-311.pyc b/datasets/__pycache__/video_transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e8defe15cd33423279052745be6019274e7ca01 Binary files /dev/null and b/datasets/__pycache__/video_transforms.cpython-311.pyc differ diff --git a/datasets/__pycache__/video_transforms.cpython-39.pyc b/datasets/__pycache__/video_transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfc7a93b58f4080c24a9bc2765d04386a75dbd6f Binary files /dev/null and b/datasets/__pycache__/video_transforms.cpython-39.pyc differ diff --git a/datasets/video_transforms.py b/datasets/video_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..30607a68859fb804370e4bfe8664eada2a840743 --- /dev/null +++ b/datasets/video_transforms.py @@ -0,0 +1,472 @@ +import torch +import random +import numbers +from torchvision.transforms import RandomCrop, RandomResizedCrop +from PIL import Image + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + +def resize_scale(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size[0] / min(H, W) + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + +def resize_with_scale_factor(clip, scale_factor, interpolation_mode): + return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False) + +def resize_scale_with_height(clip, target_size, interpolation_mode): + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size / H + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + +def resize_scale_with_weight(clip, target_size, interpolation_mode): + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size / W + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + # print(clip.shape) + th, tw = crop_size + if h < th or w < tw: + # print(h, w) + raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w)) + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def center_crop_using_short_edge(clip): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + if h < w: + th, tw = h, h + i = 0 + j = int(round((w - tw) / 2.0)) + else: + th, tw = w, w + i = int(round((h - th) / 2.0)) + j = 0 + return crop(clip, i, j, th, tw) + + +def random_shift_crop(clip): + ''' + Slide along the long edge, with the short edge as crop size + ''' + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + + if h <= w: + long_edge = w + short_edge = h + else: + long_edge = h + short_edge =w + + th, tw = short_edge, short_edge + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + # print(mean) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + Returns: + flipped clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) + + +class RandomCropVideo: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: randomly cropped video clip. 
+ size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + th, tw = self.size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + + return i, j, th, tw + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + +class CenterCropResizeVideo: + ''' + First use the short side for cropping length, + center crop video, then resize to the specified size + ''' + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + # print(clip.shape) + clip_center_crop = center_crop_using_short_edge(clip) + # print(clip_center_crop.shape) 320 512 + clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode) + return clip_center_crop_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + +class WebVideo320512: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + # add aditional one pixel for avoiding error in center crop + h, w = clip.size(-2), clip.size(-1) + # print('before resize', clip.shape) + if h < 320: + clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode) + # print('after h resize', clip.shape) + if w < 512: + clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode) + # print('after w resize', clip.shape) + clip_center_crop = center_crop(clip, self.size) + # print(clip_center_crop.shape) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + +class UCFCenterCropVideo: + ''' + First scale to the specified size in equal proportion to the short edge, + then center cropping + ''' + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. 
+ size is (T, C, crop_size, crop_size) + """ + clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + clip_center_crop = center_crop(clip_resize, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class CenterCropVideo: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop(clip, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) + """ + return normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class ResizeVideo(): + ''' + First use the short side for cropping length, + center crop video, then resize to the specified size + ''' + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + return clip_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + +# ------------------------------------------------------------ +# --------------------- Sampling --------------------------- +# ------------------------------------------------------------ +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. 
+ + Args: + size (int): Desired length of frames will be seen in the model. + """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index diff --git a/diffusion/__init__.py b/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9dbf6cf0bd6b9d1a8f65e0a31e9a84cacc03189 --- /dev/null +++ b/diffusion/__init__.py @@ -0,0 +1,47 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + # learn_sigma=True, + learn_sigma=False, # for unet + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/diffusion/__pycache__/__init__.cpython-310.pyc b/diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbe0bf5ae45331b0b01b6e2be8d6bf98856616b9 Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffusion/__pycache__/__init__.cpython-311.pyc b/diffusion/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71cfaa457c807a557188101f385fbd4897dd5e7d Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-311.pyc differ diff --git a/diffusion/__pycache__/__init__.cpython-38.pyc b/diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a48b1fe9b85aa42792e946f74fae965d2f650f8d Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusion/__pycache__/__init__.cpython-39.pyc b/diffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22042b311c16bebb698b2bcea485fb1d8a764288 Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/diffusion/__pycache__/diffusion_utils.cpython-310.pyc b/diffusion/__pycache__/diffusion_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f767c0cec4a07a8fc7514932f33cb098d0d50b06 Binary files /dev/null and 
b/diffusion/__pycache__/diffusion_utils.cpython-310.pyc differ diff --git a/diffusion/__pycache__/diffusion_utils.cpython-311.pyc b/diffusion/__pycache__/diffusion_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a027d079ff091e66ede1025bdd7f0b55c3ada2e Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-311.pyc differ diff --git a/diffusion/__pycache__/diffusion_utils.cpython-38.pyc b/diffusion/__pycache__/diffusion_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a34e84da7fa4e0d9308261b7d36086bc937e12c Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-38.pyc differ diff --git a/diffusion/__pycache__/diffusion_utils.cpython-39.pyc b/diffusion/__pycache__/diffusion_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91a512ba1145dcdf80d9657678b5863086d05530 Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-39.pyc differ diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..498d9cef36b220c5a5e114f669526f63c9bd8c24 Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc differ diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15b8678feb0ae06f08579a74a4f4a3a8c89ce9f4 Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc differ diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9001e55c9a831fe3d90d37eddd22ff76638d8a91 Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc differ diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2b448aed2786f0987d1dece996873d25f4b513b Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc differ diff --git a/diffusion/__pycache__/respace.cpython-310.pyc b/diffusion/__pycache__/respace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dff3f1610d0085135fe532a76001f486031a8bf Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-310.pyc differ diff --git a/diffusion/__pycache__/respace.cpython-311.pyc b/diffusion/__pycache__/respace.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99072dc324cfa6cc9cb0b4ad6f4ca2e388a0a6a2 Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-311.pyc differ diff --git a/diffusion/__pycache__/respace.cpython-38.pyc b/diffusion/__pycache__/respace.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fd2fd7fc04c8c2e92cb1c1be53610014f606bf6 Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-38.pyc differ diff --git a/diffusion/__pycache__/respace.cpython-39.pyc b/diffusion/__pycache__/respace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..901b269c5cd95268c3d3d06ff9a4a7835dbd8776 Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-39.pyc differ diff --git 
a/diffusion/diffusion_utils.py b/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..fe1b571e088fa245a072d2bb4320eea3d240df02 --- /dev/null +++ b/diffusion/gaussian_diffusion.py @@ -0,0 +1,931 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. 
+ """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + # diffuser stable diffusion + # beta_start=scale * 0.00085, + # beta_end=scale * 0.012, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. 
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, + mask=None, x_start=None, use_concat=False): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, F, C = x.shape[:3] + assert t.shape == (B,) + if use_concat: + model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs) + else: + model_output = model(x, t, **model_kwargs) + try: + model_output = model_output.sample # for tav unet + except: + pass + # model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + mask=None, + x_start=None, + use_concat=False + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. 
+ :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + x_start=None, + use_concat=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + mask=mask, + x_start=x_start, + use_concat=use_concat + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + x_start=None, + use_concat=False + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + mask=None, + x_start=None, + use_concat=False + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + mask=None, + x_start=None, + use_concat=False + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). 
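+        With eta == 0.0 the per-step noise is zero and the update is the
+        deterministic DDIM rule; eta == 1.0 recovers stochastic, DDPM-like
+        ancestral sampling (see the sigma computation in ddim_sample()).
+        The mask / x_start / use_concat arguments are forwarded to
+        p_mean_variance exactly as in p_sample_loop().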
+ """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + mask=mask, + x_start=x_start, + use_concat=use_concat + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + mask=None, + x_start=None, + use_concat=False + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, use_mask=False): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
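+        If use_mask is True, the conditioning channels of x_start (index 4 and
+        up in the latent) are copied into x_t unchanged and the MSE is computed
+        only over the first 4 (noised) latent channels.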
+ """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + if use_mask: + x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1) + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + try: + # model_output = model(x_t, t, **model_kwargs).sample + model_output = model_output.sample # for tav unet + except: + pass + # model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, F, C = x_t.shape[:3] + assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + # assert model_output.shape == target.shape == x_start.shape + if use_mask: + terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2) + else: + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. 
+ - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/diffusion/respace.py b/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4 --- /dev/null +++ b/diffusion/respace.py @@ -0,0 +1,130 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import torch +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
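+    Example: space_timesteps(100, "ddim25") returns the 25 evenly strided steps
+    {0, 4, 8, ..., 96}, and space_timesteps(100, [10]) returns the 10 steps
+    {0, 11, 22, ..., 99}.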
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
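+        # _WrappedModel.__call__ already maps the spaced indices back onto the
+        # original timestep values via timestep_map, so no rescaling is needed here.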
+ return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/diffusion/timestep_sampler.py b/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. 
+ :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/download.py b/download.py new file mode 100644 index 0000000000000000000000000000000000000000..56602603b15868f2533ca0d083f61ee10f82e72f --- /dev/null +++ b/download.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Functions for downloading pre-trained DiT models +""" +from torchvision.datasets.utils import download_url +import torch +import os + + + +def find_model(model_name): + + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + + if "ema" in checkpoint: # supports checkpoints from train.py + print('Ema existing!') + checkpoint = checkpoint["ema"] + return checkpoint + + +def download_model(model_name): + """ + Downloads a pre-trained DiT model from the web. 
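+    Note: this relies on a module-level ``pretrained_models`` collection of
+    checkpoint names (present in the original DiT download.py but not defined
+    in this trimmed copy), so only find_model() above is usable as-is.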
+ """ + assert model_name in pretrained_models + local_path = f'pretrained_models/{model_name}' + if not os.path.isfile(local_path): + os.makedirs('pretrained_models', exist_ok=True) + web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}' + download_url(web_path, 'pretrained_models') + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +if __name__ == "__main__": + # Download all DiT checkpoints + for model in pretrained_models: + download_model(model) + print('Done.') diff --git a/env.yaml b/env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..50ed00de7e2555d5b73d953adaad552c81d36639 --- /dev/null +++ b/env.yaml @@ -0,0 +1,20 @@ +name: seine +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - python=3.9.16 + - pytorch=2.0.1 + - pytorch-cuda=11.7 + - torchvision=0.15.2 + - pip + - pip: + - decord==0.6.0 + - diffusers==0.15.0 + - imageio==2.29.0 + - transformers==4.29.2 + - xformers==0.0.20 + - einops + - omegaconf diff --git a/huggingface-i2v/__init__.py b/huggingface-i2v/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/huggingface-i2v/requirements.txt b/huggingface-i2v/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/image_to_video/__init__.py b/image_to_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1aaa4fe07d2c1869909ed060f6b5545490d832b5 --- /dev/null +++ b/image_to_video/__init__.py @@ -0,0 +1,221 @@ +import os +import sys +import math +import docx +try: + import utils + + from diffusion import create_diffusion + from download import find_model +except: + # sys.path.append(os.getcwd()) + sys.path.append(os.path.split(sys.path[0])[0]) + # sys.path[0] + # os.path.split(sys.path[0]) + + + import utils + + from diffusion import create_diffusion + from download import find_model + +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import argparse +import torchvision + +from einops import rearrange +from models import get_models +from torchvision.utils import save_image +from diffusers.models import AutoencoderKL +from models.clip import TextEmbedder +from omegaconf import OmegaConf +from PIL import Image +import numpy as np +from torchvision import transforms +sys.path.append("..") +from datasets import video_transforms +from utils import mask_generation_before +from natsort import natsorted +from diffusers.utils.import_utils import is_xformers_available + +config_path = "/mnt/petrelfs/zhouyan/project/i2v/configs/sample_i2v.yaml" +args = OmegaConf.load(config_path) +device = "cuda" if torch.cuda.is_available() else "cpu" +print(args) + +def model_i2v_fun(args): + if args.seed: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + if args.ckpt is None: + raise ValueError("Please specify a checkpoint path using --ckpt ") + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + args.image_h = args.image_size[0] + args.image_w = args.image_size[1] + args.latent_h = latent_h + args.latent_w = latent_w + print("loading model") + model = get_models(args).to(device) + + if args.use_compile: + model = torch.compile(model) + ckpt_path = args.ckpt + state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema'] + model.load_state_dict(state_dict) + + print('loading success') + 
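+    # Switch to eval mode, then build the rest of the sampling pipeline: the
+    # diffusion process (respaced to args.num_sampling_steps), the Stable
+    # Diffusion VAE, and the CLIP text encoder.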
+ model.eval() + pretrained_model_path = args.pretrained_model_path + diffusion = create_diffusion(str(args.num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device) + text_encoder = TextEmbedder(pretrained_model_path).to(device) + # if args.use_fp16: + # print('Warning: using half precision for inference') + # vae.to(dtype=torch.float16) + # model.to(dtype=torch.float16) + # text_encoder.to(dtype=torch.float16) + + return vae, model, text_encoder, diffusion + + +def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,): + b,f,c,h,w=video_input.shape + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + + # prepare inputs + if args.use_fp16: + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w + masked_video = masked_video.to(dtype=torch.float16) + mask = mask.to(dtype=torch.float16) + else: + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w + + + masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous() + masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215) + masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous() + mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1) + + # classifier_free_guidance + if args.do_classifier_free_guidance: + masked_video = torch.cat([masked_video] * 2) + mask = torch.cat([mask] * 2) + z = torch.cat([z] * 2) + prompt_all = [prompt] + [args.negative_prompt] + + else: + masked_video = masked_video + mask = mask + z = z + prompt_all = [prompt] + + text_prompt = text_encoder(text_prompts=prompt_all, train=False) + model_kwargs = dict(encoder_hidden_states=text_prompt, + class_labels=None, + cfg_scale=args.cfg_scale, + use_fp16=args.use_fp16,) # tav unet + + # Sample images: + if args.sample_method == 'ddim': + samples = diffusion.ddim_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \ + mask=mask, x_start=masked_video, use_concat=args.use_mask + ) + elif args.sample_method == 'ddpm': + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \ + mask=mask, x_start=masked_video, use_concat=args.use_mask + ) + samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] + if args.use_fp16: + samples = samples.to(dtype=torch.float16) + + video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] + video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] + return video_clip + +def get_input(path,args): + input_path = path + # input_path = args.input_path + transform_video = transforms.Compose([ + video_transforms.ToTensorVideo(), # TCHW + video_transforms.ResizeVideo((args.image_h, args.image_w)), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval) + if input_path is not None: + print(f'loading video from {input_path}') + if os.path.isdir(input_path): + file_list = os.listdir(input_path) + video_frames = [] + if args.mask_type.startswith('onelast'): + num = int(args.mask_type.split('onelast')[-1]) + # get first and last frame + first_frame_path = os.path.join(input_path, 
natsorted(file_list)[0]) + last_frame_path = os.path.join(input_path, natsorted(file_list)[-1]) + first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0) + last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0) + for i in range(num): + video_frames.append(first_frame) + # add zeros to frames + num_zeros = args.num_frames-2*num + for i in range(num_zeros): + zeros = torch.zeros_like(first_frame) + video_frames.append(zeros) + for i in range(num): + video_frames.append(last_frame) + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + else: + for file in file_list: + if file.endswith('jpg') or file.endswith('png'): + image = torch.as_tensor(np.array(Image.open(os.path.join(input_path,file)), dtype=np.uint8, copy=True)).unsqueeze(0) + video_frames.append(image) + else: + continue + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + return video_frames, n + elif os.path.isfile(input_path): + _, full_file_name = os.path.split(input_path) + file_name, extention = os.path.splitext(full_file_name) + if extention == '.jpg' or extention == '.png': + # raise TypeError('a single image is not supported yet!!') + print("reading video from a image") + video_frames = [] + num = int(args.mask_type.split('first')[-1]) + first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0) + for i in range(num): + video_frames.append(first_frame) + num_zeros = args.num_frames-num + for i in range(num_zeros): + zeros = torch.zeros_like(first_frame) + video_frames.append(zeros) + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + return video_frames, n + else: + raise TypeError(f'{extention} is not supported !!') + else: + raise ValueError('Please check your path input!!') + else: + # raise ValueError('Need to give a video or some images') + print('given video is None, using text to video') + video_frames = torch.zeros(16,3,args.latent_h,args.latent_w,dtype=torch.uint8) + args.mask_type = 'all' + video_frames = transform_video(video_frames) + n = 0 + return video_frames, n + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + \ No newline at end of file diff --git a/image_to_video/__pycache__/__init__.cpython-311.pyc b/image_to_video/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46430872ea451064ec2f7ed3af75ef5b30bde4eb Binary files /dev/null and b/image_to_video/__pycache__/__init__.cpython-311.pyc differ diff --git a/input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png b/input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png new file mode 100644 index 0000000000000000000000000000000000000000..5b8b907e01d44e7384bd405fcc33738e03743c0a Binary files /dev/null and b/input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png differ diff --git a/input/i2v/The_picture_shows_the_beauty_of_the_sea.png b/input/i2v/The_picture_shows_the_beauty_of_the_sea.png new file mode 100644 index 0000000000000000000000000000000000000000..6f5847286d1507fe4c66f218755eff2d5d154aca Binary files /dev/null and b/input/i2v/The_picture_shows_the_beauty_of_the_sea.png differ diff --git 
a/input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png b/input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png new file mode 100644 index 0000000000000000000000000000000000000000..9fc7d9c55d24fd1c167ba02facb016014806bd60 Binary files /dev/null and b/input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png differ diff --git a/input/transition/1/1-Close-up shot of a blooming cherry tree, realism-1.png b/input/transition/1/1-Close-up shot of a blooming cherry tree, realism-1.png new file mode 100644 index 0000000000000000000000000000000000000000..25f5a3408bf8d590347732bf0ee96d395aba4807 Binary files /dev/null and b/input/transition/1/1-Close-up shot of a blooming cherry tree, realism-1.png differ diff --git a/input/transition/1/2-Wide angle shot of an alien planet with cherry blossom forest-2.png b/input/transition/1/2-Wide angle shot of an alien planet with cherry blossom forest-2.png new file mode 100644 index 0000000000000000000000000000000000000000..d36bae5d3d0551aeacf73999769e8e87031a0540 --- /dev/null +++ b/input/transition/1/2-Wide angle shot of an alien planet with cherry blossom forest-2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e98a00408112aee3d726be231a2c1f44c3ac58e3f2f54b4bded2e0a183f66529 +size 1005564 diff --git a/input/transition/2/1-Overhead view of a bustling city street at night, realism-1.png b/input/transition/2/1-Overhead view of a bustling city street at night, realism-1.png new file mode 100644 index 0000000000000000000000000000000000000000..95e9e244c4d7422fefb14d6c1a5947e90a83c430 Binary files /dev/null and b/input/transition/2/1-Overhead view of a bustling city street at night, realism-1.png differ diff --git a/input/transition/2/2-Aerial view of a futuristic city bathed in neon lights-2.png b/input/transition/2/2-Aerial view of a futuristic city bathed in neon lights-2.png new file mode 100644 index 0000000000000000000000000000000000000000..27c3776798133990666d421e50c8b0139f514121 Binary files /dev/null and b/input/transition/2/2-Aerial view of a futuristic city bathed in neon lights-2.png differ diff --git a/input/transition/3/1-Close-up shot of a candle lit in the darkness, realism-1png.png b/input/transition/3/1-Close-up shot of a candle lit in the darkness, realism-1png.png new file mode 100644 index 0000000000000000000000000000000000000000..54bdd1da39818d5c1b2274d5662b7447bb4d900c Binary files /dev/null and b/input/transition/3/1-Close-up shot of a candle lit in the darkness, realism-1png.png differ diff --git a/input/transition/3/2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png b/input/transition/3/2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png new file mode 100644 index 0000000000000000000000000000000000000000..17ba096583432c67b2011afcbec5e13fec259a96 Binary files /dev/null and b/input/transition/3/2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png differ diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69577f55b92072e09420e545df6d82c07212408f --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,54 @@ +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from .dit import DiT_models +from .uvit import UViT_models +from .unet import UNet3DConditionModel +from torch.optim.lr_scheduler import LambdaLR + +def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit + from torch.optim.lr_scheduler import 
LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'warmup': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + +def get_models(args): + + if 'DiT' in args.model: + return DiT_models[args.model]( + input_size=args.latent_size, + num_classes=args.num_classes, + class_guided=args.class_guided, + num_frames=args.num_frames, + use_lora=args.use_lora, + attention_mode=args.attention_mode + ) + elif 'UViT' in args.model: + return UViT_models[args.model]( + input_size=args.latent_size, + num_classes=args.num_classes, + class_guided=args.class_guided, + num_frames=args.num_frames, + use_lora=args.use_lora, + attention_mode=args.attention_mode + ) + elif 'TAV' in args.model: + pretrained_model_path = args.pretrained_model_path + return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask) + else: + raise '{} Model Not Supported!'.format(args.model) + \ No newline at end of file diff --git a/models/__pycache__/__init__.cpython-310.pyc b/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e38774aa7011ed9e7cb03f3530bfe238eee0e3d Binary files /dev/null and b/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/__pycache__/__init__.cpython-311.pyc b/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5921f914f6b92591485258426a52502a98ec31b2 Binary files /dev/null and b/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28ca91347df8cc40ce3df1ce8358596454159992 Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c853d1fc24036b6d7d63f1135188773934b60261 Binary files /dev/null and b/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/__pycache__/attention.cpython-310.pyc b/models/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dae6ab0508a6a216cbcdfaf1e5dc9d59edd9400 Binary files /dev/null and b/models/__pycache__/attention.cpython-310.pyc differ diff --git a/models/__pycache__/attention.cpython-311.pyc b/models/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28861d5cbdedd15d1e1eec858106a4662cf1b544 Binary files /dev/null and b/models/__pycache__/attention.cpython-311.pyc differ diff --git a/models/__pycache__/attention.cpython-38.pyc b/models/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ca1ff5762230c5921687794783619636ac0f36c Binary files /dev/null and b/models/__pycache__/attention.cpython-38.pyc differ diff --git a/models/__pycache__/attention.cpython-39.pyc b/models/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21888f95c7191e914acd8084b4a4ffbb1abc67f4 Binary files /dev/null 
and b/models/__pycache__/attention.cpython-39.pyc differ diff --git a/models/__pycache__/clip.cpython-310.pyc b/models/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..620862b1a6897caeb8178546df6fcdbb22464447 Binary files /dev/null and b/models/__pycache__/clip.cpython-310.pyc differ diff --git a/models/__pycache__/clip.cpython-311.pyc b/models/__pycache__/clip.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cef805947a4de50152b8361dc68996458471e406 Binary files /dev/null and b/models/__pycache__/clip.cpython-311.pyc differ diff --git a/models/__pycache__/clip.cpython-38.pyc b/models/__pycache__/clip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b951c75c84546c47051527873cc5b76cf4c18bd Binary files /dev/null and b/models/__pycache__/clip.cpython-38.pyc differ diff --git a/models/__pycache__/clip.cpython-39.pyc b/models/__pycache__/clip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..139b4d17dd6b6e6f7cfd0d6a28276b9b95721e66 Binary files /dev/null and b/models/__pycache__/clip.cpython-39.pyc differ diff --git a/models/__pycache__/dit.cpython-310.pyc b/models/__pycache__/dit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5435bde751ef3af56297c7e7943baed748722a4d Binary files /dev/null and b/models/__pycache__/dit.cpython-310.pyc differ diff --git a/models/__pycache__/dit.cpython-311.pyc b/models/__pycache__/dit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce2ee43d6f041740f11adf6828786e73b658afca Binary files /dev/null and b/models/__pycache__/dit.cpython-311.pyc differ diff --git a/models/__pycache__/dit.cpython-38.pyc b/models/__pycache__/dit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dfaaf0356dff9964c21916d5bb1576668e699c7 Binary files /dev/null and b/models/__pycache__/dit.cpython-38.pyc differ diff --git a/models/__pycache__/dit.cpython-39.pyc b/models/__pycache__/dit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..083eabc53fa89168564d9ffa335998b839ee11ee Binary files /dev/null and b/models/__pycache__/dit.cpython-39.pyc differ diff --git a/models/__pycache__/resnet.cpython-310.pyc b/models/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a538bd768a726c9b895b450a23562eb80192570 Binary files /dev/null and b/models/__pycache__/resnet.cpython-310.pyc differ diff --git a/models/__pycache__/resnet.cpython-311.pyc b/models/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bae6d6fd9729b1a61c3dbd4fa73d0a9b2da2ddc Binary files /dev/null and b/models/__pycache__/resnet.cpython-311.pyc differ diff --git a/models/__pycache__/resnet.cpython-38.pyc b/models/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc7845c892b85f62b7b9b9788a5aacfe858f9f55 Binary files /dev/null and b/models/__pycache__/resnet.cpython-38.pyc differ diff --git a/models/__pycache__/resnet.cpython-39.pyc b/models/__pycache__/resnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0b02b1579bbdec188f32ba8687e1fdcf7ab9acd Binary files /dev/null and b/models/__pycache__/resnet.cpython-39.pyc differ diff --git a/models/__pycache__/unet.cpython-310.pyc b/models/__pycache__/unet.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..25bbfb2f81f9166c59a3530ee0343739896c37df Binary files /dev/null and b/models/__pycache__/unet.cpython-310.pyc differ diff --git a/models/__pycache__/unet.cpython-311.pyc b/models/__pycache__/unet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0af6d539b3bef28575dd35bc33422b2f6c4cf896 Binary files /dev/null and b/models/__pycache__/unet.cpython-311.pyc differ diff --git a/models/__pycache__/unet.cpython-38.pyc b/models/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55ac6ce89e455d280789c3b2921ffe4e9d737301 Binary files /dev/null and b/models/__pycache__/unet.cpython-38.pyc differ diff --git a/models/__pycache__/unet.cpython-39.pyc b/models/__pycache__/unet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3ae31907c6dd8d167137f2d63282a1f06029ee7 Binary files /dev/null and b/models/__pycache__/unet.cpython-39.pyc differ diff --git a/models/__pycache__/unet_blocks.cpython-310.pyc b/models/__pycache__/unet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2aa1f327d934ad77e0ea205b6ab0587190acff7a Binary files /dev/null and b/models/__pycache__/unet_blocks.cpython-310.pyc differ diff --git a/models/__pycache__/unet_blocks.cpython-311.pyc b/models/__pycache__/unet_blocks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ffef0cc9e7a3aef69839bed4899f52763d90e92 Binary files /dev/null and b/models/__pycache__/unet_blocks.cpython-311.pyc differ diff --git a/models/__pycache__/unet_blocks.cpython-38.pyc b/models/__pycache__/unet_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49be4a6652f1971787067f57bccff3720ae3477e Binary files /dev/null and b/models/__pycache__/unet_blocks.cpython-38.pyc differ diff --git a/models/__pycache__/unet_blocks.cpython-39.pyc b/models/__pycache__/unet_blocks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bb1a58be97d8841d1476fb68629d3a8cb4aaa6a Binary files /dev/null and b/models/__pycache__/unet_blocks.cpython-39.pyc differ diff --git a/models/__pycache__/uvit.cpython-310.pyc b/models/__pycache__/uvit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fd6de89d77f345f7c2286c7f99979953b1b25b9 Binary files /dev/null and b/models/__pycache__/uvit.cpython-310.pyc differ diff --git a/models/__pycache__/uvit.cpython-311.pyc b/models/__pycache__/uvit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83671c53187e3aec10eb42a62d68e8c3ce723439 Binary files /dev/null and b/models/__pycache__/uvit.cpython-311.pyc differ diff --git a/models/__pycache__/uvit.cpython-38.pyc b/models/__pycache__/uvit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..502d94ca74801f895cc9d8d1f8daf33c12d67020 Binary files /dev/null and b/models/__pycache__/uvit.cpython-38.pyc differ diff --git a/models/__pycache__/uvit.cpython-39.pyc b/models/__pycache__/uvit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49767875b4891eb4ddf002df80a2dee370879e26 Binary files /dev/null and b/models/__pycache__/uvit.cpython-39.pyc differ diff --git a/models/attention.py b/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..689f017d36b431f730257d3c94413a0eddc7287c --- /dev/null +++ b/models/attention.py @@ -0,0 +1,968 @@ +# Adapted from 
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from dataclasses import dataclass +from typing import Optional + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm +from rotary_embedding_torch import RotaryEmbedding +from typing import Callable, Optional +from einops import rearrange, repeat + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def exists(x): + return x is not None + + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + # print('num head', heads) + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # print(use_relative_position) + self.use_relative_position = use_relative_position + 
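+        # When relative positions are enabled, a rotary embedding (RoPE, capped
+        # at 32 dims per attention head) is used to encode positions in attention.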
if self.use_relative_position: + self.rotary_emb = RotaryEmbedding(min(32, dim_head)) + # # print(dim_head) + # # print(heads) + # # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265 + # self.max_position_embeddings = 32 + # self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head) + + # self.dropout = nn.Dropout(dropout) + + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + + # print('before reshpape query shape', query.shape) + dim = query.shape[-1] + if not self.use_relative_position: + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # print('after reshape query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if not self.use_relative_position: + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if 
attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + + if attention_mask is not None: + # print('attention_mask', attention_mask.shape) + # print('attention_scores', attention_scores.shape) + # exit() + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + # print(attention_probs.shape) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + # print(attention_probs.shape) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + # print(hidden_states.shape) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # print(hidden_states.shape) + # exit() + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return 
hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + if self.training: + video_length = hidden_states.shape[2] - use_image_num + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states_length = encoder_hidden_states.shape[1] + encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...] 
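            # Joint video-image training: the video caption (the first
            # m - use_image_num entries along dim 1) has just been repeated once
            # per frame, while each auxiliary image keeps its own caption. After
            # the cat + rearrange below, every entry of the flattened
            # "(b f) c h w" hidden states has a matching "(b m) n c" text
            # embedding. Illustrative shapes, assuming one caption per clip,
            # b=2, video_length=16, use_image_num=4 and CLIP ViT-L/14 text
            # features (77 tokens, 768 dims):
            #   video part: (2, 1, 77, 768) -> (2, 16, 77, 768)
            #   image part: (2, 4, 77, 768)
            #   cat + rearrange: (2, 20, 77, 768) -> (40, 77, 768)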
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous() + else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous() + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + use_image_num=use_image_num, + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous() + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + # print(only_cross_attention) + self.use_ada_layer_norm = num_embeds_ada_norm is not None + # print(self.use_ada_layer_norm) + self.use_first_frame = use_first_frame + + # Spatial-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # # SC-Attn + # self.attn1 = SparseCausalAttention( + # query_dim=dim, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # cross_attention_dim=cross_attention_dim if only_cross_attention else None, + # upcast_attention=upcast_attention, + # ) + # self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Text Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + 
upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # # Temp Frame-Cross-Attn; add tahn scale factor + # self.attn_fcross = SparseCausalAttention( + # query_dim=dim, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # cross_attention_dim=cross_attention_dim if only_cross_attention else None, + # upcast_attention=upcast_attention, + # ) + # self.norm_fcross = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + # nn.init.zeros_(self.attn_fcross.to_out[0].weight.data) + + # Temp + self.attn_temp = TemporalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + rotary_emb=rotary_emb, + ) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None): + + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_fcross._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states + + # # SparseCausal-Attention + # norm_hidden_states = ( + # self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + # ) + + # if self.only_cross_attention: + # hidden_states = ( + # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + # ) + # else: + # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + # # Temporal FrameCross Attention + # norm_hidden_states = ( + # self.norm_fcross(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_fcross(hidden_states) + # ) + # hidden_states = self.attn_fcross( + # norm_hidden_states, attention_mask=attention_mask, video_length=video_length, use_image_num=use_image_num) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Temporal Attention + if self.training: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous() + hidden_states_video = hidden_states[:, :video_length, :] + hidden_states_image = hidden_states[:, video_length:, :] + # print(hidden_states_video.shape) + # print(hidden_states_image.shape) + # if self.training: + # # prepare attention mask; mask images in temporal attention + # attention_mask_shape = (video_length + use_image_num) // 8 + 1 + # video_image_length = video_length + use_image_num + # attention_mask = torch.zeros([8 * attention_mask_shape, 8 * attention_mask_shape], + # dtype=hidden_states.dtype, device=hidden_states.device)[:video_image_length, :video_image_length] + # attention_mask[:, video_length:] = -math.inf + norm_hidden_states_video = ( + self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video) + ) + # print(norm_hidden_states.shape) + hidden_states_video = 
self.attn_temp(norm_hidden_states_video) + hidden_states_video + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + else: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous() + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + # print(norm_hidden_states.shape) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward_video(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length).contiguous() + key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) + key = rearrange(key, "b f d c -> (b f) d c").contiguous() + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length).contiguous() + value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) + value = rearrange(value, "b f d c -> (b f) d c").contiguous() + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def forward_image(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = 
self.to_q(hidden_states) # [b (h w)] f (nd * d) + # if self.use_relative_position: + # print('before attention query shape', query.shape) + dim = query.shape[-1] + if not self.use_relative_position: + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # if self.use_relative_position: + # print('before attention query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if not self.use_relative_position: + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_image_num=None): + if self.training: + # print(use_image_num) + hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous() + hidden_states_video = hidden_states[:, :video_length, ...] + hidden_states_image = hidden_states[:, video_length:, ...] 
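            # Route the two kinds of frames through different attention paths:
            # forward_video applies the sparse-causal pattern (every frame
            # attends to the first frame and to its immediately preceding
            # frame, see the key/value construction in forward_video), while
            # forward_image runs plain per-image self-attention for the
            # auxiliary images. The two outputs are concatenated back along
            # the batch dimension below.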
+ hidden_states_video = rearrange(hidden_states_video, 'b f d c -> (b f) d c').contiguous() + hidden_states_image = rearrange(hidden_states_image, 'b f d c -> (b f) d c').contiguous() + hidden_states_video = self.forward_video(hidden_states=hidden_states_video, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + video_length=video_length) + # print('hidden_states_video', hidden_states_video.shape) + hidden_states_image = self.forward_image(hidden_states=hidden_states_image, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask) + # print('hidden_states_image', hidden_states_image.shape) + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=0) + return hidden_states + # exit() + else: + return self.forward_video(hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + video_length=video_length) + +class TemporalAttention(CrossAttention): + def __init__(self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + rotary_emb=None): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups) + # relative time positional embeddings + self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet + self.rotary_emb = rotary_emb + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device) + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of 
xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + # reshape for adding time positional bais + query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + + # torch.baddbmm only accepte 3-D tensor + # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm + # attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2)) + if exists(self.rotary_emb): + query = self.rotary_emb.rotate_queries_or_keys(query) + key = self.rotary_emb.rotate_queries_or_keys(key) + + attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key) + # print('attention_scores shape', attention_scores.shape) + # print('time_rel_pos_bias shape', time_rel_pos_bias.shape) + # print('attention_mask shape', attention_mask.shape) + + attention_scores = attention_scores + time_rel_pos_bias + # print(attention_scores.shape) + + # bert from huggin face + # attention_scores = attention_scores / math.sqrt(self.dim_head) + + # # Normalize the attention scores to probabilities. + # attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + if attention_mask is not None: + # add attention mask + attention_scores = attention_scores + attention_mask + + # vdm + attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach() + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + # print(attention_probs[0][0]) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + # hidden_states = torch.matmul(attention_probs, value) + hidden_states = torch.einsum('... h i j, ... h j d -> ... 
h i d', attention_probs, value) + # print(hidden_states.shape) + # hidden_states = self.same_batch_dim_to_heads(hidden_states) + hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)') + # print(hidden_states.shape) + # exit() + return hidden_states + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames \ No newline at end of file diff --git a/models/clip.py b/models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..249ff2a1680d580d94d4a18c1db5f538a81c043d --- /dev/null +++ b/models/clip.py @@ -0,0 +1,123 @@ +import numpy +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPTextModel + +import transformers +transformers.logging.set_verbosity_error() + +""" +Will encounter following warning: +- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task +or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). +- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model +that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). + +https://github.com/CompVis/stable-diffusion/issues/97 +according to this issue, this warning is safe. + +This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. +You can safely ignore the warning, it is not an error. + +This clip usage is from U-ViT and same with Stable Diffusion. 
+""" + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77): + def __init__(self, path, device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer") + self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder') + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class TextEmbedder(nn.Module): + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + def __init__(self, path, dropout_prob=0.1): + super().__init__() + self.text_encodder = FrozenCLIPEmbedder(path=path) + self.dropout_prob = dropout_prob + + def token_drop(self, text_prompts, force_drop_ids=None): + """ + Drops text to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob + else: + # TODO + drop_ids = force_drop_ids == 1 + labels = list(numpy.where(drop_ids, "", text_prompts)) + # print(labels) + return labels + + def forward(self, text_prompts, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + text_prompts = self.token_drop(text_prompts, force_drop_ids) + embeddings = self.text_encodder(text_prompts) + return embeddings + + +if __name__ == '__main__': + + r""" + Returns: + + Examples from CLIPTextModel: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base', + dropout_prob=0.00001).to(device) + + text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]] + # text_prompt = ('None', 'None', 'None') + output = text_encoder(text_prompts=text_prompt, train=False) + # print(output) + print(output.shape) + # print(output.shape) \ No newline at end of file diff --git a/models/dit.py b/models/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..5ffebb93dc4acd3da5b876fa75decb32fbc5b559 --- /dev/null +++ b/models/dit.py @@ -0,0 +1,617 @@ +# 
Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import math +import torch +import torch.nn as nn +import numpy as np + +from einops import rearrange, repeat +from timm.models.vision_transformer import Mlp, PatchEmbed + + + +import os +import sys +# sys.path.append(os.getcwd()) +sys.path.append(os.path.split(sys.path[0])[0]) +# 代码解释 +# sys.path[0] : 得到C:\Users\maxu\Desktop\blog_test\pakage2 +# os.path.split(sys.path[0]) : 得到['C:\Users\maxu\Desktop\blog_test',pakage2'] +# mmcls 里面跨包引用是因为安装了mmcls + + +# for i in sys.path: +# print(i) + +# the xformers lib allows less memory, faster training and inference +try: + import xformers + import xformers.ops +except: + XFORMERS_IS_AVAILBLE = False + +# from timm.models.layers.helpers import to_2tuple +# from timm.models.layers.trace_utils import _assert + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + +################################################################################# +# Attention Layers from TIMM # +################################################################################# + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.attention_mode = attention_mode + self.use_lora = use_lora + + if self.use_lora: + self.qkv = lora.MergedLinear(dim, dim * 3, r=500, enable_lora=[True, False, True]) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + if self.attention_mode == 'xformers': # cause loss nan while using with amp + x = xformers.ops.memory_efficient_attention(q, k, v).reshape(B, N, C) + + elif self.attention_mode == 'flash': + # cause loss nan while using with amp + # Optionally use the context manager to ensure one of the fused kerenels is run + with torch.backends.cuda.sdp_kernel(enable_math=False): + x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C) # require pytorch 2.0 + + elif self.attention_mode == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + + else: + raise NotImplemented + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. 
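
    Illustrative usage (a sketch; hidden_size=1152 matches the DiT-XL
    configs defined at the bottom of this file):

        embedder = TimestepEmbedder(hidden_size=1152)
        t = torch.tensor([0, 250, 999])   # one diffusion timestep per sample
        emb = embedder(t)                 # sinusoidal features -> 2-layer MLP
        assert emb.shape == (3, 1152)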
+ """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + print(drop_ids) + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + print('******labels******', labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core DiT Model # +################################################################################# + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
+ """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + num_frames=16, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + class_guided=False, + use_lora=False, + attention_mode='math', + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.class_guided = class_guided + self.num_frames = num_frames + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + + if self.class_guided: + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + + num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + self.time_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False) + + if use_lora: + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode, use_lora=False if num % 2 ==0 else True) for num in range(depth) + ]) + else: + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth) + ]) + + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize (and freeze) 
pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + time_embed = get_1d_sincos_time_embed(self.time_embed.shape[-1], self.time_embed.shape[-2]) + self.time_embed.data.copy_(torch.from_numpy(time_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + if self.class_guided: + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + # @torch.cuda.amp.autocast() + # @torch.compile + def forward(self, x, t, y=None): + """ + Forward pass of DiT. 
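        In this video variant x is a 5-D batch of shape (N, F, C, H, W); the
        frames are folded into the batch dimension before patch embedding, so
        the shapes below describe the per-frame view.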
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + # print('label: {}'.format(y)) + batches, frames, channels, high, weight = x.shape # for example, 3, 16, 3, 32, 32 + # 这里rearrange后每隔f是同一个视频 + x = rearrange(x, 'b f c h w -> (b f) c h w') + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(t) # (N, D) + # timestep_spatial的repeat需要保证每f帧为同一个timesteps + timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.time_embed.shape[1]) # 48, 1152; c=num_frames + timestep_time = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1]) # 768, 1152; c=num tokens + + if self.class_guided: + y = self.y_embedder(y, self.training) + y_spatial = repeat(y, 'n d -> (n c) d', c=self.time_embed.shape[1]) # 48, 1152; c=num_frames + y_time = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1]) # 768, 1152; c=num tokens + + # if self.class_guided: + # y = self.y_embedder(y, self.training) # (N, D) + # c = timestep_spatial + y + # else: + # c = timestep_spatial + + # for block in self.blocks: + # x = block(x, c) # (N, T, D) + + for i in range(0, len(self.blocks), 2): + # print('The {}-th run'.format(i)) + spatial_block, time_block = self.blocks[i:i+2] + # print(spatial_block) + # print(time_block) + # print(x.shape) + + if self.class_guided: + c = timestep_spatial + y_spatial + else: + c = timestep_spatial + x = spatial_block(x, c) + # print(c.shape) + + x = rearrange(x, '(b f) t d -> (b t) f d', b=batches) # t 代表单帧token数; 768, 16, 1152 + # Add Time Embedding + if i == 0: + x = x + self.time_embed # 768, 16, 1152 + + if self.class_guided: + c = timestep_time + y_time + else: + # timestep_time = repeat(t, 'n d -> (n c) d', c=x.shape[0] // batches) # 768, 1152 + # print(timestep_time.shape) + c = timestep_time + + x = time_block(x, c) + # print(x.shape) + x = rearrange(x, '(b t) f d -> (b f) t d', b=batches) + + # x = rearrange(x, '(b t) f d -> (b f) t d', b=batches) + if self.class_guided: + c = timestep_spatial + y_spatial + else: + c = timestep_spatial + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + x = rearrange(x, '(b f) c h w -> b f c h w', b=batches) + # print(x.shape) + return x + + def forward_motion(self, motions, t, base_frame, y=None): + """ + Forward pass of DiT. 
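        Same signature as forward() plus a base_frame argument; `motions` is a
        5-D batch of shape (N, F, C, H, W) and is folded into the batch
        dimension before patch embedding, so the shapes below describe the
        per-frame view.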
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + # print('label: {}'.format(y)) + batches, frames, channels, high, weight = motions.shape # for example, 3, 16, 3, 32, 32 + # 这里rearrange后每隔f是同一个视频 + motions = rearrange(motions, 'b f c h w -> (b f) c h w') + motions = self.x_embedder(motions) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(t) # (N, D) + # timestep_spatial的repeat需要保证每f帧为同一个timesteps + timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.time_embed.shape[1]) # 48, 1152; c=num_frames + timestep_time = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1]) # 768, 1152; c=num tokens + + if self.class_guided: + y = self.y_embedder(y, self.training) + y_spatial = repeat(y, 'n d -> (n c) d', c=self.time_embed.shape[1]) # 48, 1152; c=num_frames + y_time = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1]) # 768, 1152; c=num tokens + + # if self.class_guided: + # y = self.y_embedder(y, self.training) # (N, D) + # c = timestep_spatial + y + # else: + # c = timestep_spatial + + # for block in self.blocks: + # x = block(x, c) # (N, T, D) + + for i in range(0, len(self.blocks), 2): + # print('The {}-th run'.format(i)) + spatial_block, time_block = self.blocks[i:i+2] + # print(spatial_block) + # print(time_block) + # print(x.shape) + + if self.class_guided: + c = timestep_spatial + y_spatial + else: + c = timestep_spatial + x = spatial_block(x, c) + # print(c.shape) + + x = rearrange(x, '(b f) t d -> (b t) f d', b=batches) # t 代表单帧token数; 768, 16, 1152 + # Add Time Embedding + if i == 0: + x = x + self.time_embed # 768, 16, 1152 + + if self.class_guided: + c = timestep_time + y_time + else: + # timestep_time = repeat(t, 'n d -> (n c) d', c=x.shape[0] // batches) # 768, 1152 + # print(timestep_time.shape) + c = timestep_time + + x = time_block(x, c) + # print(x.shape) + x = rearrange(x, '(b t) f d -> (b f) t d', b=batches) + + # x = rearrange(x, '(b t) f d -> (b f) t d', b=batches) + if self.class_guided: + c = timestep_spatial + y_spatial + else: + c = timestep_spatial + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + x = rearrange(x, '(b f) c h w -> b f c h w', b=batches) + # print(x.shape) + return x + + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. 
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_1d_sincos_time_embed(embed_dim, length): + pos = torch.arange(0, length).unsqueeze(1) + return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# DiT Configs # +################################################################################# + +def DiT_XL_2(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + +def DiT_XL_4(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) + +def DiT_XL_8(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) + +def DiT_L_2(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) + +def DiT_L_4(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) + +def DiT_L_8(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) + +def DiT_B_2(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) + +def DiT_B_4(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) + +def DiT_B_8(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) + +def DiT_S_2(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) + +def DiT_S_4(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) + +def DiT_S_8(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) + + +DiT_models = { + 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, + 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, + 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, + 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, +} + +if __name__ == '__main__': + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + img = torch.randn(3, 16, 4, 32, 32).to(device) + t = torch.tensor([1, 2, 3]).to(device) + y = torch.tensor([1, 2, 3]).to(device) + network = DiT_XL_2().to(device) + y_embeder = LabelEmbedder(num_classes=100, hidden_size=768, dropout_prob=0.5).to(device) + # lora.mark_only_lora_as_trainable(network) + out = y_embeder(y, True) + # out = network(img, t, y) + print(out.shape) \ No newline at end of file diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,212 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = 
name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = 
torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/models/unet.py b/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..941b5156cce9775606c15388026a11443f1d9dfb --- /dev/null +++ b/models/unet.py @@ -0,0 +1,721 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import math +import json +import torch +import einops +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + +try: + from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from .resnet import InflatedConv3d +except: + from unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from resnet import InflatedConv3d + +from rotary_embedding_torch import RotaryEmbedding + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): 
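+ # NOTE: this module learns a bias over bucketed relative frame offsets (it appears to follow the T5-style relative-attention scheme);
+ # with num_buckets=32 and heads=8 the lookup table is nn.Embedding(32, 8), and forward(n) returns one (n, n) bias matrix per head,
+ # presumably to be added to the temporal attention logits.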
+ super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, # 64 + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + + # print(use_first_frame) + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks 
= nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # print(only_cross_attention) + # print(type(only_cross_attention)) + # exit() + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + # print(only_cross_attention) + # exit() + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # print(attention_head_dim) + # exit() + + rotary_emb = RotaryEmbedding(32) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + 
use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + # relative time positional embeddings + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
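+ # Illustrative usage (instance name is hypothetical): unet.set_attention_slice("auto") halves each sliceable head dim,
+ # "max" runs only one slice at a time, and an int such as 2 yields attention_head_dim // 2 slices per layer.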
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
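+ # For example, if the model weights are running in fp16, the fp32 t_emb produced by time_proj is cast to
+ # torch.float16 on the next line so that self.time_embedding receives a matching dtype.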
+ t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + # print(emb.shape) # torch.Size([3, 1280]) + # print(class_emb.shape) # torch.Size([3, 1280]) + emb = emb + class_emb + + if self.use_relative_position: + frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device) + else: + frame_rel_pos_bias = None + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + use_image_num=use_image_num, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num, + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + use_image_num=use_image_num, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + # print(sample.shape) + + if not return_dict: + return (sample,) + sample = UNet3DConditionOutput(sample=sample) + return sample + + def forward_with_cfg(self, + x, + t, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = None, + cfg_scale=4.0, + use_fp16=False): + """ + Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + if use_fp16: + combined = combined.to(dtype=torch.float16) + model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. 
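+ # Clarification: as written below, guidance is applied to the first four (latent) channels by default;
+ # the commented-out line underneath is the three-channel variant referred to above.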
+ eps, rest = model_out[:, :4], model_out[:, 4:] + # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + + # the content of the config file + # { + # "_class_name": "UNet2DConditionModel", + # "_diffusers_version": "0.2.2", + # "act_fn": "silu", + # "attention_head_dim": 8, + # "block_out_channels": [ + # 320, + # 640, + # 1280, + # 1280 + # ], + # "center_input_sample": false, + # "cross_attention_dim": 768, + # "down_block_types": [ + # "CrossAttnDownBlock2D", + # "CrossAttnDownBlock2D", + # "CrossAttnDownBlock2D", + # "DownBlock2D" + # ], + # "downsample_padding": 1, + # "flip_sin_to_cos": true, + # "freq_shift": 0, + # "in_channels": 4, + # "layers_per_block": 2, + # "mid_block_scale_factor": 1, + # "norm_eps": 1e-05, + # "norm_num_groups": 32, + # "out_channels": 4, + # "sample_size": 64, + # "up_block_types": [ + # "UpBlock2D", + # "CrossAttnUpBlock2D", + # "CrossAttnUpBlock2D", + # "CrossAttnUpBlock2D" + # ] + # } + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + # config["use_first_frame"] = True + + config["use_first_frame"] = False + if use_concat: + config["in_channels"] = 9 + # config["use_relative_position"] = True + + # # tmp + # config["class_embed_type"] = "timestep" + # config["num_class_embeds"] = 100 + + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + + # {'_class_name': 'UNet3DConditionModel', + # '_diffusers_version': '0.2.2', + # 'act_fn': 'silu', + # 'attention_head_dim': 8, + # 'block_out_channels': [320, 640, 1280, 1280], + # 'center_input_sample': False, + # 'cross_attention_dim': 768, + # 'down_block_types': + # ['CrossAttnDownBlock3D', + # 'CrossAttnDownBlock3D', + # 'CrossAttnDownBlock3D', + # 'DownBlock3D'], + # 'downsample_padding': 1, + # 'flip_sin_to_cos': True, + # 'freq_shift': 0, + # 'in_channels': 4, + # 'layers_per_block': 2, + # 'mid_block_scale_factor': 1, + # 'norm_eps': 1e-05, + # 'norm_num_groups': 32, + # 'out_channels': 4, + # 'sample_size': 64, + # 'up_block_types': + # ['UpBlock3D', + # 'CrossAttnUpBlock3D', + # 'CrossAttnUpBlock3D', + # 'CrossAttnUpBlock3D']} + + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + if use_concat: + new_state_dict = {} + conv_in_weight = state_dict["conv_in.weight"] + new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype) + + for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]): + new_conv_weight[:, j] = conv_in_weight[:, i] + 
new_state_dict["conv_in.weight"] = new_conv_weight + new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"] + for k, v in model.state_dict().items(): + # print(k) + if '_temp.' in k: + new_state_dict.update({k: v}) + if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross + k = k.replace('attn_fcross', 'attn1') + state_dict.update({k: state_dict[k]}) + if 'norm_fcross' in k: + k = k.replace('norm_fcross', 'norm1') + state_dict.update({k: state_dict[k]}) + + if 'conv_in' in k: + continue + else: + new_state_dict[k] = v + # # tmp + # if 'class_embedding' in k: + # state_dict.update({k: v}) + # breakpoint() + model.load_state_dict(new_state_dict) + else: + for k, v in model.state_dict().items(): + # print(k) + if '_temp' in k: + state_dict.update({k: v}) + if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross + k = k.replace('attn_fcross', 'attn1') + state_dict.update({k: state_dict[k]}) + if 'norm_fcross' in k: + k = k.replace('norm_fcross', 'norm1') + state_dict.update({k: state_dict[k]}) + + model.load_state_dict(state_dict) + + return model + +if __name__ == '__main__': + import torch + # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + device = "cuda" if torch.cuda.is_available() else "cpu" + + # pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base/" # p cluster + pretrained_model_path = "/mnt/petrelfs/share_data/zhanglingjun/stable-diffusion-v1-4/" # p cluster + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device) + # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + unet.enable_xformers_memory_efficient_attention() + unet.enable_gradient_checkpointing() + + unet.train() + + use_image_num = 5 + noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device) + bsz = noisy_latents.shape[0] + timesteps = torch.randint(0, 1000, (bsz,)).to(device) + timesteps = timesteps.long() + encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device) + # class_labels = torch.randn((bsz, )).to(device) + + + model_pred = unet(sample=noisy_latents, timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + class_labels=None, + use_image_num=use_image_num).sample + print(model_pred.shape) \ No newline at end of file diff --git a/models/unet_blocks.py b/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..849c10539c7039840c93631c5201069119d3c306 --- /dev/null +++ b/models/unet_blocks.py @@ -0,0 +1,648 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +from torch import nn + +try: + from .attention import Transformer3DModel + from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +except: + from attention import Transformer3DModel + from resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, +): + # print(down_block_type) + # 
print(use_first_frame) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float 
= 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, + ): + super().__init__() + resnets = [] + attentions = [] + + # print(use_first_frame) + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + 
output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def create_custom_forward_attn(module, return_dict=None, use_image_num=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) + else: + return module(*inputs, use_image_num=use_image_num) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + 
else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + use_image_num=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return 
module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def create_custom_forward_attn(module, return_dict=None, use_image_num=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) + else: + return module(*inputs, use_image_num=use_image_num) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91 --- /dev/null +++ b/models/utils.py @@ -0,0 +1,215 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# 
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch + +import numpy as np +import torch.nn as nn + +from einops import repeat + + +################################################################################# +# Unet Utils # +################################################################################# + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
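+ For example, a per-element loss of shape (N, C, F, H, W) is reduced to a length-N vector.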
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# class HybridConditioner(nn.Module): + +# def __init__(self, c_concat_config, c_crossattn_config): +# super().__init__() +# self.concat_conditioner = instantiate_from_config(c_concat_config) +# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + +# def forward(self, c_concat, c_crossattn): +# c_concat = self.concat_conditioner(c_concat) +# c_crossattn = self.crossattn_conditioner(c_crossattn) +# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. 
+ matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params \ No newline at end of file diff --git a/models/uvit.py b/models/uvit.py new file mode 100644 index 0000000000000000000000000000000000000000..9fefd458d5f556ddeb7f3ed74c8ac15870c6c860 --- /dev/null +++ b/models/uvit.py @@ -0,0 +1,310 @@ +import torch +import torch.nn as nn +import math +import timm +from timm.models.layers import trunc_normal_ +from timm.models.vision_transformer import PatchEmbed, Mlp +# assert timm.__version__ == "0.3.2" # version checks +import einops +import torch.utils.checkpoint + +# the xformers lib allows less memory, faster training and inference +try: + import xformers + import xformers.ops +except: + XFORMERS_IS_AVAILBLE = False + # print('xformers disabled') + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def patchify(imgs, patch_size): + x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) + return x + + +def unpatchify(x, channels=3): + patch_size = int((x.shape[2] // channels) ** 0.5) + h = w = int(x.shape[1] ** .5) + assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] + x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, L, C = x.shape + + qkv = self.qkv(x) + if XFORMERS_IS_AVAILBLE: # the xformers lib allows less memory, faster training and inference + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B L H D + x = xformers.ops.memory_efficient_attention(q, k, v) + x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) + else: + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + 
act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x, skip=None): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, skip) + else: + return self._forward(x, skip) + + def _forward(self, x, skip=None): + if self.skip_linear is not None: + # print('x shape', x.shape) + # print('skip shape', skip.shape) + # exit() + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class UViT(nn.Module): + def __init__(self, input_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1, + use_checkpoint=False, conv=True, skip=True, num_frames=16, class_guided=False, use_lora=False): + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_classes = num_classes + self.in_chans = in_chans + + self.patch_embed = PatchEmbed( + img_size=input_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.time_embed = nn.Sequential( + nn.Linear(embed_dim, 4 * embed_dim), + nn.SiLU(), + nn.Linear(4 * embed_dim, embed_dim), + ) if mlp_time_embed else nn.Identity() + + if self.num_classes > 0: + self.label_emb = nn.Embedding(self.num_classes, embed_dim) + self.extras = 2 + else: + self.extras = 1 + + self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) + self.frame_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim)) + + self.in_blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.mid_block = Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, use_checkpoint=use_checkpoint) + + self.out_blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.norm = norm_layer(embed_dim) + self.patch_dim = patch_size ** 2 * in_chans + self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) + self.final_layer = nn.Conv2d(self.in_chans, self.in_chans * 2, 3, padding=1) if conv else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.frame_embed, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed'} + + def forward_(self, x, timesteps, y=None): + x = self.patch_embed(x) # 48, 
256, 1152 + # print(x.shape) + B, L, D = x.shape + + time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) # 3, 1152 + # print(time_token.shape) + time_token = time_token.unsqueeze(dim=1) # 3, 1, 1152 + x = torch.cat((time_token, x), dim=1) + + if y is not None: + label_emb = self.label_emb(y) + label_emb = label_emb.unsqueeze(dim=1) + x = torch.cat((label_emb, x), dim=1) + x = x + self.pos_embed + + skips = [] + for blk in self.in_blocks: + x = blk(x) + skips.append(x) + + x = self.mid_block(x) + + for blk in self.out_blocks: + x = blk(x, skips.pop()) + + x = self.norm(x) + x = self.decoder_pred(x) + assert x.size(1) == self.extras + L + x = x[:, self.extras:, :] + x = unpatchify(x, self.in_chans) + x = self.final_layer(x) + return x + + def forward(self, x, timesteps, y=None): + # print(x.shape) + batch, frame, _, _, _ = x.shape + # after this rearrange, every f consecutive items along the batch dim belong to the same video + x = einops.rearrange(x, 'b f c h w -> (b f) c h w') # 3 16 4 256 256 + x = self.patch_embed(x) # 48, 256, 1152 + B, L, D = x.shape + + time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) # 3, 1152 + # the repeat of timestep_spatial must guarantee that all f frames of a video share the same timestep + time_token_spatial = einops.repeat(time_token, 'n d -> (n c) d', c=frame) # 48, 1152 + time_token_spatial = time_token_spatial.unsqueeze(dim=1) # 48, 1, 1152 + x = torch.cat((time_token_spatial, x), dim=1) # 48, 257, 1152 + + if y is not None: + label_emb = self.label_emb(y) + label_emb = label_emb.unsqueeze(dim=1) + x = torch.cat((label_emb, x), dim=1) + x = x + self.pos_embed + + skips = [] + for i in range(0, len(self.in_blocks), 2): + # print('The {}-th run'.format(i)) + spatial_block, time_block = self.in_blocks[i:i+2] + x = spatial_block(x) + + # add time embeddings and conduct attention as frame. + x = einops.rearrange(x, '(b f) t d -> (b t) f d', b=batch) # t is the number of tokens per frame; 771, 16, 1152; 771 = 3 * 257 + skips.append(x) + # print(x.shape) + + if i == 0: + x = x + self.frame_embed # 771, 16, 1152 + + x = time_block(x) + + x = einops.rearrange(x, '(b t) f d -> (b f) t d', b=batch) # 48, 257, 1152 + skips.append(x) + + x = self.mid_block(x) + + for i in range(0, len(self.out_blocks), 2): + # print('The {}-th run'.format(i)) + spatial_block, time_block = self.out_blocks[i:i+2] + x = spatial_block(x, skips.pop()) + + # add time embeddings and conduct attention as frame.
+ x = einops.rearrange(x, '(b f) t d -> (b t) f d', b=batch) # t is the number of tokens per frame; 771, 16, 1152; 771 = 3 * 257 + + x = time_block(x, skips.pop()) + + x = einops.rearrange(x, '(b t) f d -> (b f) t d', b=batch) # 48, 256, 1152 + + + x = self.norm(x) + x = self.decoder_pred(x) + assert x.size(1) == self.extras + L + x = x[:, self.extras:, :] + x = unpatchify(x, self.in_chans) + x = self.final_layer(x) + x = einops.rearrange(x, '(b f) c h w -> b f c h w', b=batch) + # print(x.shape) + return x + +def UViT_XL_2(**kwargs): + return UViT(patch_size=2, in_chans=4, embed_dim=1152, depth=28, + num_heads=16, mlp_ratio=4, qkv_bias=False, mlp_time_embed=4, + use_checkpoint=True, conv=False, **kwargs) + +def UViT_L_2(**kwargs): + return UViT(patch_size=2, in_chans=4, embed_dim=1024, depth=20, + num_heads=16, mlp_ratio=4, qkv_bias=False, mlp_time_embed=False, + use_checkpoint=True, **kwargs) + +# no variants smaller than L are provided; in UViT, models below L use img_size 64 + +UViT_models = { + 'UViT-XL/2': UViT_XL_2, 'UViT-L/2': UViT_L_2 +} + + +if __name__ == '__main__': + + + nnet = UViT_XL_2().cuda() + + imgs = torch.randn(3, 16, 4, 32, 32).cuda() + timesteps = torch.tensor([1, 2, 3]).cuda() + + outputs = nnet(imgs, timesteps) + print(outputs.shape) + diff --git a/results/i2v/frame, high quality,2.png.mp4 b/results/i2v/frame, high quality,2.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..aa64d00b2c563a1101cb9399c7b61950c7fb4b40 Binary files /dev/null and b/results/i2v/frame, high quality,2.png.mp4 differ diff --git a/results/i2v/frame, high quality,23.png.mp4 b/results/i2v/frame, high quality,23.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f84a82bc27ea0d45d7b74f84da7544d442af7753 Binary files /dev/null and b/results/i2v/frame, high quality,23.png.mp4 differ diff --git a/results/i2v/frame, high quality,35.png.mp4 b/results/i2v/frame, high quality,35.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e25123fdfa1f252424da49556052278797677c92 Binary files /dev/null and b/results/i2v/frame, high quality,35.png.mp4 differ diff --git "a/results/i2v/frame, high quality,6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/frame, high quality,6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" new file mode 100644 index 0000000000000000000000000000000000000000..d32b26d3a4e47bc0ceb79c31e4c4f26688123f84 Binary files /dev/null and "b/results/i2v/frame, high quality,6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ diff --git a/results/i2v/moving, high quality2.png.mp4 b/results/i2v/moving, high quality2.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3da108d3189b011821afa69a398e412b20e770be Binary files /dev/null and b/results/i2v/moving, high quality2.png.mp4 differ diff --git a/results/i2v/moving, high quality23.png.mp4 b/results/i2v/moving, high quality23.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0bf5b5e807d549c367a5c0ffdd90498c756a9c35 Binary files /dev/null and b/results/i2v/moving, high quality23.png.mp4 differ diff --git a/results/i2v/moving, high quality35.png.mp4 b/results/i2v/moving, high quality35.png.mp4 new file mode 100644 index
0000000000000000000000000000000000000000..e2fa1dce476bc330c159d2546c9b78d0125b238b Binary files /dev/null and b/results/i2v/moving, high quality35.png.mp4 differ diff --git "a/results/i2v/moving, high quality6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/moving, high quality6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" new file mode 100644 index 0000000000000000000000000000000000000000..98f4fc364cdfefabb8ba0b9f061fec024f83463d Binary files /dev/null and "b/results/i2v/moving, high quality6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ diff --git a/results/i2v/tiled up, high quali23.png.mp4 b/results/i2v/tiled up, high quali23.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..27ade24afc71b51a1128438465c913a1684754d0 Binary files /dev/null and b/results/i2v/tiled up, high quali23.png.mp4 differ diff --git "a/results/i2v/tiled up, high quali6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/tiled up, high quali6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" new file mode 100644 index 0000000000000000000000000000000000000000..108582182bbb9e6acfacb9a6978d9e89dda8abca Binary files /dev/null and "b/results/i2v/tiled up, high quali6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ diff --git a/results/i2v/tilt up, high qualit2.png.mp4 b/results/i2v/tilt up, high qualit2.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6c4874d02d8a1b5597b8ff120fb69252de448357 Binary files /dev/null and b/results/i2v/tilt up, high qualit2.png.mp4 differ diff --git a/results/i2v/tilt up, high qualit23.png.mp4 b/results/i2v/tilt up, high qualit23.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3421e17df09ff125821d049532a6da4992d01b4e Binary files /dev/null and b/results/i2v/tilt up, high qualit23.png.mp4 differ diff --git a/results/i2v/tilt up, high qualit35.png.mp4 b/results/i2v/tilt up, high qualit35.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b623ffa6779e98f142c44cb0f8fe66f06cd97982 Binary files /dev/null and b/results/i2v/tilt up, high qualit35.png.mp4 differ diff --git "a/results/i2v/tilt up, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/tilt up, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" new file mode 100644 index 0000000000000000000000000000000000000000..7110df9f8ed8f373792ff578053e98bf8be89121 Binary files /dev/null and "b/results/i2v/tilt up, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ diff --git a/results/i2v/tilt up, high quality, stable , slow motion..mp4 b/results/i2v/tilt up, high quality, stable , slow motion..mp4 new 
file mode 100644 index 0000000000000000000000000000000000000000..bb8029275cc447067ba8d126592b463699d088ec Binary files /dev/null and b/results/i2v/tilt up, high quality, stable , slow motion..mp4 differ diff --git a/results/i2v/zoom in, high qualit2.png.mp4 b/results/i2v/zoom in, high qualit2.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a2102ade2bb7d4d6dafe46378f4768ec04d42bc6 Binary files /dev/null and b/results/i2v/zoom in, high qualit2.png.mp4 differ diff --git a/results/i2v/zoom in, high qualit23.png.mp4 b/results/i2v/zoom in, high qualit23.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..397cbeaded68fc0aeeb3ef7b8ba8b7cf3d56af13 Binary files /dev/null and b/results/i2v/zoom in, high qualit23.png.mp4 differ diff --git a/results/i2v/zoom in, high qualit35.png.mp4 b/results/i2v/zoom in, high qualit35.png.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4a968875013b81899b1b4a83d1cce0871505bb86 Binary files /dev/null and b/results/i2v/zoom in, high qualit35.png.mp4 differ diff --git "a/results/i2v/zoom in, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/zoom in, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" new file mode 100644 index 0000000000000000000000000000000000000000..7d6e4e5fabfdebc4aa1b996cb9b197c4ab8a3630 Binary files /dev/null and "b/results/i2v/zoom in, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ diff --git a/results/transition/smooth_transition.mp4 b/results/transition/smooth_transition.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5232357d98a9183c2e681958bf6d45301b2c90da Binary files /dev/null and b/results/transition/smooth_transition.mp4 differ diff --git a/sample_scripts/with_mask_sample.py b/sample_scripts/with_mask_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..bd50246da1c74242834534330e1324834ec0a0c7 --- /dev/null +++ b/sample_scripts/with_mask_sample.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Sample new images from a pre-trained DiT. 
+""" +import os +import sys +import math +import docx +try: + import utils + + from diffusion import create_diffusion + from download import find_model +except: + # sys.path.append(os.getcwd()) + sys.path.append(os.path.split(sys.path[0])[0]) + # sys.path[0] + # os.path.split(sys.path[0]) + + + import utils + + from diffusion import create_diffusion + from download import find_model + +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import argparse +import torchvision + +from einops import rearrange +from models import get_models +from torchvision.utils import save_image +from diffusers.models import AutoencoderKL +from models.clip import TextEmbedder +from omegaconf import OmegaConf +from PIL import Image +import numpy as np +from torchvision import transforms +sys.path.append("..") +from datasets import video_transforms +from utils import mask_generation_before +from natsort import natsorted +from diffusers.utils.import_utils import is_xformers_available + + +doc = docx.Document("/mnt/petrelfs/zhouyan/tmp/星际旅行.docx") +start = 1 +p_dict = {} +for param in doc.paragraphs: + p_dict[start] = param.text + start = start+1 +# def get_input(args): +def get_input(path,args): + input_path = path + # input_path = args.input_path + transform_video = transforms.Compose([ + video_transforms.ToTensorVideo(), # TCHW + video_transforms.ResizeVideo((args.image_h, args.image_w)), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval) + if input_path is not None: + print(f'loading video from {input_path}') + if os.path.isdir(input_path): + file_list = os.listdir(input_path) + video_frames = [] + if args.mask_type.startswith('onelast'): + num = int(args.mask_type.split('onelast')[-1]) + # get first and last frame + first_frame_path = os.path.join(input_path, natsorted(file_list)[0]) + last_frame_path = os.path.join(input_path, natsorted(file_list)[-1]) + first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0) + last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0) + for i in range(num): + video_frames.append(first_frame) + # add zeros to frames + num_zeros = args.num_frames-2*num + for i in range(num_zeros): + zeros = torch.zeros_like(first_frame) + video_frames.append(zeros) + for i in range(num): + video_frames.append(last_frame) + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + else: + for file in file_list: + if file.endswith('jpg') or file.endswith('png'): + image = torch.as_tensor(np.array(Image.open(os.path.join(input_path,file)), dtype=np.uint8, copy=True)).unsqueeze(0) + video_frames.append(image) + else: + continue + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + return video_frames, n + elif os.path.isfile(input_path): + _, full_file_name = os.path.split(input_path) + file_name, extention = os.path.splitext(full_file_name) + if extention == '.jpg' or extention == '.png': + # raise TypeError('a single image is not supported yet!!') + print("reading video from a image") + video_frames = [] + num = int(args.mask_type.split('first')[-1]) + first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0) + for i 
in range(num): + video_frames.append(first_frame) + num_zeros = args.num_frames-num + for i in range(num_zeros): + zeros = torch.zeros_like(first_frame) + video_frames.append(zeros) + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + return video_frames, n + else: + raise TypeError(f'{extention} is not supported !!') + else: + raise ValueError('Please check your path input!!') + else: + # raise ValueError('Need to give a video or some images') + print('given video is None, using text to video') + video_frames = torch.zeros(16,3,args.latent_h,args.latent_w,dtype=torch.uint8) + args.mask_type = 'all' + video_frames = transform_video(video_frames) + n = 0 + return video_frames, n + +def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,): + b,f,c,h,w=video_input.shape + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + + # prepare inputs + if args.use_fp16: + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w + masked_video = masked_video.to(dtype=torch.float16) + mask = mask.to(dtype=torch.float16) + else: + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w + + + masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous() + masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215) + masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous() + mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1) + + # classifier_free_guidance + if args.do_classifier_free_guidance: + masked_video = torch.cat([masked_video] * 2) + mask = torch.cat([mask] * 2) + z = torch.cat([z] * 2) + prompt_all = [prompt] + [args.negative_prompt] + + else: + masked_video = masked_video + mask = mask + z = z + prompt_all = [prompt] + + text_prompt = text_encoder(text_prompts=prompt_all, train=False) + model_kwargs = dict(encoder_hidden_states=text_prompt, + class_labels=None, + cfg_scale=args.cfg_scale, + use_fp16=args.use_fp16,) # tav unet + + # Sample images: + if args.sample_method == 'ddim': + samples = diffusion.ddim_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \ + mask=mask, x_start=masked_video, use_concat=args.use_mask + ) + elif args.sample_method == 'ddpm': + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \ + mask=mask, x_start=masked_video, use_concat=args.use_mask + ) + samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] + if args.use_fp16: + samples = samples.to(dtype=torch.float16) + + video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] + video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] + return video_clip + +def main(args): + # Setup PyTorch: + if args.seed: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + # device = "cpu" + + if args.ckpt is None: + raise ValueError("Please specify a checkpoint path using --ckpt ") + + # Load model: + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + args.image_h = args.image_size[0] + args.image_w = args.image_size[1] + args.latent_h = latent_h + args.latent_w = latent_w 
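+ # latent_h and latent_w follow the Stable Diffusion VAE's 8x spatial downsampling,
+ # so e.g. a 256x256 conditioning image corresponds to a 32x32 latent grid.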
+ print('loading model') + model = get_models(args).to(device) + + if args.use_compile: + model = torch.compile(model) + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + model.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # load model + ckpt_path = args.ckpt + state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema'] + model.load_state_dict(state_dict) + print('loading succeeded') + + model.eval() # important! + pretrained_model_path = args.pretrained_model_path + diffusion = create_diffusion(str(args.num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device) + text_encoder = TextEmbedder(pretrained_model_path).to(device) + if args.use_fp16: + print('Warning: using half precision for inference!') + vae.to(dtype=torch.float16) + model.to(dtype=torch.float16) + text_encoder.to(dtype=torch.float16) + + # Labels to condition the model with (feel free to change): + prompt = args.text_prompt + if prompt ==[]: + prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ') + else: + prompt = prompt[0] + prompt_base = prompt.replace(' ','_') + prompt = prompt + args.additional_prompt + + + + if not os.path.exists(os.path.join(args.save_img_path)): + os.makedirs(os.path.join(args.save_img_path)) + for file in os.listdir(args.img_path): + video_input, reserve_frames = get_input(os.path.join(args.img_path,file),args) + video_input = video_input.to(device).unsqueeze(0) + mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) + masked_video = video_input * (mask == 0) + prompt = "tilt up, high quality, stable " + prompt = prompt + args.additional_prompt + video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,) + video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) + torchvision.io.write_video(os.path.join(args.save_img_path, prompt[0:20]+file+ '.mp4'), video_, fps=8) + # video_input, researve_frames = get_input(args) # f,c,h,w + # video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w + # mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w + # # TODO: change the first3 to last3 + # masked_video = video_input * (mask == 0) + + # video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,) + # video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) + # torchvision.io.write_video(os.path.join(args.save_img_path, prompt_base+ '.mp4'), video_, fps=8) + print(f'saved in {args.save_img_path}') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="./configs/sample_mask.yaml") + parser.add_argument("--run-time", type=int, default=0) + args = parser.parse_args() + omega_conf = OmegaConf.load(args.config) + omega_conf.run_time = args.run_time + main(omega_conf) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c116b19946e3e53dbcd66e49b4d4aba98920187a --- /dev/null +++ b/utils.py @@ -0,0 +1,374 @@ +import os +import math +import torch +import logging +import subprocess +import numpy as np +import torch.distributed as dist + +# from torch._six
import inf +from torch import inf +from PIL import Image +from typing import Union, Iterable +from collections import OrderedDict +from torch.utils.tensorboard import SummaryWriter +_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] + +################################################################################# +# Training Helper Functions # +################################################################################# +def fetch_files_by_numbers(start_number, count, file_list): + file_numbers = range(start_number, start_number + count) + found_files = [] + for file_number in file_numbers: + file_number_padded = str(file_number).zfill(2) + for file_name in file_list: + if file_name.endswith(file_number_padded + '.csv'): + found_files.append(file_name) + break # Stop searching once a file is found for the current number + return found_files + +################################################################################# +# Training Clip Gradients # +################################################################################# + +def get_grad_norm( + parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + return total_norm + +def clip_grad_norm_( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, + error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). 
+ """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(total_norm) + + if clip_grad: + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f'The total norm of order {norm_type} for gradients from ' + '`parameters` is non-finite, so it cannot be clipped. To disable ' + 'this error and scale the gradients by the non-finite norm anyway, ' + 'set `error_if_nonfinite=False`') + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for g in grads: + g.detach().mul_(clip_coef_clamped.to(g.device)) + # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(gradient_cliped) + return total_norm + +def separation_content_motion(video_clip): + """ + separate content and motion in a given video + Args: + video_clip, a given video clip, [B F C H W] + + Return: + base frame, [B, 1, C, H, W] + motions, [B, F-1, C, H, W], + the first is the base frame, + the second is the motions relative to the base frame + """ + total_frames = video_clip.shape[1] + base_frame = video_clip[0] + motions = [video_clip[i] - base_frame for i in range(1, total_frames)] + motions = torch.cat(motions, dim=1) + return base_frame, motions + +def get_experiment_dir(root_dir, args): + if args.use_compile: + root_dir += '-Compile' # speedup by torch compile + if args.fixed_spatial: + root_dir += '-FixedSpa' + if args.enable_xformers_memory_efficient_attention: + root_dir += '-Xfor' + if args.gradient_checkpointing: + root_dir += '-Gc' + if args.mixed_precision: + root_dir += '-Amp' + if args.image_size == 512: + root_dir += '-512' + return root_dir + +################################################################################# + # Training Logger # +################################################################################# + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + +def create_accelerate_logger(logging_dir, is_main_process=False): + """ + Create a logger that writes to a log file and stdout.
+ """ + if is_main_process: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def create_tensorboard(tensorboard_dir): + """ + Create a tensorboard that saves losses. + """ + if dist.get_rank() == 0: # real tensorboard + # tensorboard + writer = SummaryWriter(tensorboard_dir) + + return writer + +def write_tensorboard(writer, *args): + ''' + write the loss information to a tensorboard file. + Only for pytorch DDP mode. + ''' + if dist.get_rank() == 0: # real tensorboard + writer.add_scalar(args[0], args[1], args[2]) + +################################################################################# +# EMA Update/ DDP Training Utils # +################################################################################# + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + if param.requires_grad: + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def setup_distributed(backend="nccl", port=None): + """Initialize distributed training environment. 
+ support both slurm and torch.distributed.launch + see torch.distributed.init_process_group() for more details + """ + num_gpus = torch.cuda.device_count() + + if "SLURM_JOB_ID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + # specify master port + if port is not None: + os.environ["MASTER_PORT"] = str(port) + elif "MASTER_PORT" not in os.environ: + # os.environ["MASTER_PORT"] = "29566" + os.environ["MASTER_PORT"] = str(29566 + num_gpus) + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank % num_gpus) + os.environ["RANK"] = str(rank) + else: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # torch.cuda.set_device(rank % num_gpus) + + dist.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + ) + +################################################################################# +# Testing Utils # +################################################################################# + +def save_video_grid(video, nrow=None): + b, t, h, w, c = video.shape + + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = torch.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype=torch.uint8) + + print(video_grid.shape) + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + + return video_grid + +def save_videos_grid_tav(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): + from einops import rearrange + import imageio + import torchvision + + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + # os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps) + + +################################################################################# +# MMCV Utils # +################################################################################# + + +def collect_env(): + # Copyright (c) OpenMMLab. All rights reserved. 
+ from mmcv.utils import collect_env as collect_base_env + from mmcv.utils import get_git_hash + """Collect the information of the running environments.""" + + env_info = collect_base_env() + env_info['MMClassification'] = get_git_hash()[:7] + + for name, val in env_info.items(): + print(f'{name}: {val}') + + print(torch.cuda.get_arch_list()) + print(torch.version.cuda) + + +################################################################################# +# Long video generation Utils # +################################################################################# + +def mask_generation_before(mask_type, shape, dtype, device, dropout_prob=0.0, use_image_num=0): + b, f, c, h, w = shape + if mask_type.startswith('first'): + num = int(mask_type.split('first')[-1]) + mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device), + torch.ones(1, f-num, 1, 1, 1, dtype=dtype, device=device)], dim=1) + mask = mask_f.expand(b, -1, c, h, w) + elif mask_type.startswith('all'): + mask = torch.ones(b,f,c,h,w,dtype=dtype,device=device) + elif mask_type.startswith('onelast'): + num = int(mask_type.split('onelast')[-1]) + mask_one = torch.zeros(1,1,1,1,1, dtype=dtype, device=device) + mask_mid = torch.ones(1,f-2*num,1,1,1,dtype=dtype, device=device) + mask_last = torch.zeros_like(mask_one) + mask = torch.cat([mask_one]*num + [mask_mid] + [mask_last]*num, dim=1) + mask = mask.expand(b, -1, c, h, w) + else: + raise ValueError(f"Invalid mask type: {mask_type}") + return mask
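For reference, the masking convention implemented by `mask_generation_before` above is the one `sample_scripts/with_mask_sample.py` relies on: entries equal to 0 mark conditioning frames that are kept, entries equal to 1 mark frames to be generated, and the sampler preserves the conditioning pixels via `video_input * (mask == 0)`. Below is a minimal sanity-check sketch, not part of the repository; it assumes the repository root is on `PYTHONPATH` (so `utils.py` is importable) and that the dependencies from `env.yaml` are installed, since `utils.py` imports `torch.utils.tensorboard` at module level. Tensor sizes are illustrative only.

```python
import torch

from utils import mask_generation_before

# Illustrative shape: batch=1, 16 frames, 3 channels, 256x256 pixels.
shape = (1, 16, 3, 256, 256)

# 'first2' keeps the first two frames as conditioning (mask == 0) and marks
# the remaining 14 frames for generation (mask == 1), as in the I2V setting.
mask = mask_generation_before("first2", shape, torch.float32, "cpu")
video = torch.randn(shape)
masked_video = video * (mask == 0)  # same operation as in with_mask_sample.py

print(mask[0, :, 0, 0, 0])                    # first two entries are 0., the rest are 1.
print(masked_video[0, 0].equal(video[0, 0]))  # True: conditioning frame is kept
print(masked_video[0, 5].abs().sum())         # tensor(0.): frames to be generated are zeroed
```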