diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f5ba256c2767b7b8dabdbeee17ff205b5dec74e3 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+input/transition/1/2-Wide[[:space:]]angle[[:space:]]shot[[:space:]]of[[:space:]]an[[:space:]]alien[[:space:]]planet[[:space:]]with[[:space:]]cherry[[:space:]]blossom[[:space:]]forest-2.png filter=lfs diff=lfs merge=lfs -text
diff --git a/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4 b/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4597d28c2c40c9453bb7ad0d50e97c941256db2a
Binary files /dev/null and b/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4 differ
diff --git a/Close-up_essence_is_poured_from_bottleKodak_Vision.png b/Close-up_essence_is_poured_from_bottleKodak_Vision.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b8b907e01d44e7384bd405fcc33738e03743c0a
Binary files /dev/null and b/Close-up_essence_is_poured_from_bottleKodak_Vision.png differ
diff --git a/README.md b/README.md
index 2b1b9352da3370aa6a359f52027aaa0c1cbd42cf..3d44a9eb6f817c0400a464757222b47bb91e6382 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,65 @@
----
-title: Seine Gradio
-emoji: 🦀
-colorFrom: pink
-colorTo: blue
-sdk: gradio
-sdk_version: 4.7.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# SEINE
+This repository is the official implementation of [SEINE](https://arxiv.org/abs/2310.20700).
+
+**[SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction](https://arxiv.org/abs/2310.20700)**
+
+[Arxiv Report](https://arxiv.org/abs/2310.20700) | [Project Page](https://vchitect.github.io/SEINE-project/)
+
+
+
+
+## Setup for Inference
+
+### Prepare Environment
+```
+conda env create -f env.yaml
+conda activate seine
+```
+
+### Download our model and the T2I base model
+Download our model checkpoint from [Google Drive](https://drive.google.com/drive/folders/1cWfeDzKJhpb0m6HA5DoMOH0_ItuUY95b?usp=sharing) and save it to the ```pre-trained``` directory.
+
+
+Our model is based on Stable Diffusion v1.4; you may download [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) to the ```pre-trained``` directory as well.
+
+Now under `./pre-trained`, you should be able to see the following:
+```
+pre-trained/
+├── seine.pt
+└── stable-diffusion-v1-4/
+    ├── ...
+    └── ...
+```
+
+#### Inference for I2V
+```bash
+python sample_scripts/with_mask_sample.py --config configs/sample_i2v.yaml
+```
+The generated video will be saved in ```./results/i2v```.
+
+#### Inference for Transition
+```bash
+python sample_scripts/with_mask_sample.py --config configs/sample_transition.yaml
+```
+The generated video will be saved in ```./results/transition```.
+
+
+
+#### More Details
+You can modify ```configs/sample_i2v.yaml``` or ```configs/sample_transition.yaml``` to change the generation conditions.
+For example:
+```ckpt``` specifies the model checkpoint.
+```text_prompt``` describes the content of the video.
+```input_path``` specifies the path to the input image(s).
+
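+If you prefer to adjust these fields programmatically, the sketch below shows one way to do it with OmegaConf (the same config loader used in `app.py`). The output file name `configs/my_sample.yaml`, the example prompt, and the image path are illustrative only, not files shipped with this repository:
+```python
+from omegaconf import OmegaConf
+
+# Load the sampling config and override a few generation conditions.
+args = OmegaConf.load("configs/sample_i2v.yaml")
+args.text_prompt = ["a sailing boat on the sea, slow motion"]  # content of the video
+args.input_path = "path/to/your_image.png"                     # illustrative input image path
+args.num_sampling_steps = 250
+
+# Save the modified config and pass it to the sampling script via --config.
+OmegaConf.save(args, "configs/my_sample.yaml")
+```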
+
+## BibTeX
+```bibtex
+@article{chen2023seine,
+title={SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction},
+author={Chen, Xinyuan and Wang, Yaohui and Zhang, Lingjun and Zhuang, Shaobin and Ma, Xin and Yu, Jiashuo and Wang, Yali and Lin, Dahua and Qiao, Yu and Liu, Ziwei},
+journal={arXiv preprint arXiv:2310.20700},
+year={2023}
+}
+```
\ No newline at end of file
diff --git a/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4 b/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..26919637200af54d10dc81edc3aee61db78b975d
Binary files /dev/null and b/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4 differ
diff --git a/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4 b/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..465b1835c7ae07c6fe7d5b00b30860fdd115578e
Binary files /dev/null and b/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4 differ
diff --git a/The_picture_shows_the_beauty_of_the_sea.png b/The_picture_shows_the_beauty_of_the_sea.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f5847286d1507fe4c66f218755eff2d5d154aca
Binary files /dev/null and b/The_picture_shows_the_beauty_of_the_sea.png differ
diff --git a/The_picture_shows_the_beauty_of_the_sea_.jpg b/The_picture_shows_the_beauty_of_the_sea_.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cd2b1e2b31adf8f5e150df283d2129f78a64308e
Binary files /dev/null and b/The_picture_shows_the_beauty_of_the_sea_.jpg differ
diff --git a/__pycache__/download.cpython-310.pyc b/__pycache__/download.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38c01e289a088357ff4d3badca50d3d4aa6b57ca
Binary files /dev/null and b/__pycache__/download.cpython-310.pyc differ
diff --git a/__pycache__/download.cpython-311.pyc b/__pycache__/download.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7eede41205a3440a8bfd431afbcdcf884908cf1
Binary files /dev/null and b/__pycache__/download.cpython-311.pyc differ
diff --git a/__pycache__/download.cpython-39.pyc b/__pycache__/download.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c313425340f296f3ffcbc19e7d4d8f85324b0a02
Binary files /dev/null and b/__pycache__/download.cpython-39.pyc differ
diff --git a/__pycache__/utils.cpython-310.pyc b/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26147e0d3fbc6bd37be6964535a3efb5bf018cd2
Binary files /dev/null and b/__pycache__/utils.cpython-310.pyc differ
diff --git a/__pycache__/utils.cpython-311.pyc b/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..caa8846ff5a21ce40d0485a2dee58476f3a76b81
Binary files /dev/null and b/__pycache__/utils.cpython-311.pyc differ
diff --git a/__pycache__/utils.cpython-39.pyc b/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..616aae63f41717e7b04df098aed77a775746acfa
Binary files /dev/null and b/__pycache__/utils.cpython-39.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..70a69d821d6526af85a6cdd1b44f402c4b0b4c07
--- /dev/null
+++ b/app.py
@@ -0,0 +1,183 @@
+import gradio as gr
+from image_to_video import model_i2v_fun, get_input, auto_inpainting, setup_seed
+from omegaconf import OmegaConf
+import torch
+from diffusers.utils.import_utils import is_xformers_available
+import torchvision
+from utils import mask_generation_before
+import os
+import cv2
+
+config_path = "/mnt/petrelfs/zhouyan/project/i2v/configs/sample_i2v.yaml"
+args = OmegaConf.load(config_path)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# ------- get model ---------------
+# model_i2V = model_i2v_fun()
+# model_i2V.to("cuda")
+
+# vae, model, text_encoder, diffusion = model_i2v_fun(args)
+# vae.to(device)
+# model.to(device)
+# text_encoder.to(device)
+
+# if args.use_fp16:
+# vae.to(dtype=torch.float16)
+# model.to(dtype=torch.float16)
+# text_encoder.to(dtype=torch.float16)
+
+# if args.enable_xformers_memory_efficient_attention and device=="cuda":
+# if is_xformers_available():
+# model.enable_xformers_memory_efficient_attention()
+# else:
+# raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+
+css = """
+h1 {
+ text-align: center;
+}
+#component-0 {
+ max-width: 730px;
+ margin: auto;
+}
+"""
+
+def infer(prompt, image_inp, seed_inp, ddim_steps):
+ setup_seed(seed_inp)
+ args.num_sampling_steps = ddim_steps
+    # First, check the return type of the image input
+ print(prompt, seed_inp, ddim_steps, type(image_inp))
+ img = cv2.imread(image_inp)
+ new_size = [img.shape[0],img.shape[1]]
+ # if(img.shape[0]==512 and img.shape[1]==512):
+ # args.image_size = [512,512]
+ # elif(img.shape[0]==320 and img.shape[1]==512):
+ # args.image_size = [320, 512]
+ # elif(img.shape[0]==292 and img.shape[1]==512):
+ # args.image_size = [292,512]
+ # else:
+ # raise ValueError("Please enter image of right size")
+ # print(args.image_size)
+ args.image_size = new_size
+
+ vae, model, text_encoder, diffusion = model_i2v_fun(args)
+ vae.to(device)
+ model.to(device)
+ text_encoder.to(device)
+
+ if args.use_fp16:
+ vae.to(dtype=torch.float16)
+ model.to(dtype=torch.float16)
+ text_encoder.to(dtype=torch.float16)
+
+ if args.enable_xformers_memory_efficient_attention and device=="cuda":
+ if is_xformers_available():
+ model.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+
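+    # Build the conditioning: load the input image into a video tensor, generate the frame mask
+    # for the configured mask_type, and keep only the unmasked (conditioning) frames.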
+ video_input, reserve_frames = get_input(image_inp, args)
+ video_input = video_input.to(device).unsqueeze(0)
+ mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
+ masked_video = video_input * (mask == 0)
+ prompt = "tilt up, high quality, stable "
+ prompt = prompt + args.additional_prompt
+ video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
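+    # Map the sampled clip from [-1, 1] to uint8 RGB frames (T, H, W, C) and write it as an 8-fps mp4.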
+ video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ torchvision.io.write_video(os.path.join(args.save_img_path, prompt+ '.mp4'), video_, fps=8)
+
+
+
+ # video = model_i2V(prompt, image_inp, seed_inp, ddim_steps)
+
+ return os.path.join(args.save_img_path, prompt+ '.mp4')
+
+
+
+def clean():
+ # return gr.Image.update(value=None, visible=False), gr.Video.update(value=None)
+ return gr.Video.update(value=None)
+
+
+title = """
+
+
+
+ SEINE: Image-to-Video generation
+
+
+
+ Apply SEINE to generate a video
+
+
+"""
+
+
+
+with gr.Blocks(css='style.css') as demo:
+ gr.Markdown("SEINE: Image-to-Video generation")
+ with gr.Column(elem_id="col-container"):
+ # gr.HTML(title)
+
+ with gr.Row():
+ with gr.Column():
+ image_inp = gr.Image(type='filepath')
+
+ with gr.Column():
+
+ prompt = gr.Textbox(label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in")
+
+ with gr.Row():
+ # control_task = gr.Dropdown(label="Task", choices=["Text-2-video", "Image-2-video"], value="Text-2-video", multiselect=False, elem_id="controltask-in")
+ ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
+ seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=250, elem_id="seed-in")
+
+ # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
+
+
+
+ submit_btn = gr.Button("Generate video")
+ clean_btn = gr.Button("Clean video")
+
+ video_out = gr.Video(label="Video result", elem_id="video-output", width = 800)
+ inputs = [prompt,image_inp, seed_inp, ddim_steps]
+ outputs = [video_out]
+ ex = gr.Examples(
+ examples = [["/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea_.jpg","A video of the beauty of the sea",123,50],
+ ["/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea.png","A video of the beauty of the sea",123,50],
+ ["/mnt/petrelfs/zhouyan/project/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png","A video of close-up essence is poured from bottleKodak Vision",123,50]],
+ fn = infer,
+ inputs = [image_inp, prompt, seed_inp, ddim_steps],
+ outputs=[video_out],
+ cache_examples=False
+
+
+ )
+ ex.dataset.headers = [""]
+ # gr.Markdown("some examples")
+ # with gr.Row():
+ # gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea_.jpg")
+ # gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea.png")
+ # gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png")
+ # with gr.Row():
+ # gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4")
+ # gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4")
+ # gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4")
+ # control_task.change(change_task_options, inputs=[control_task], outputs=[canny_opt, hough_opt, normal_opt], queue=False)
+ clean_btn.click(clean, inputs=[], outputs=[video_out], queue=False)
+ submit_btn.click(infer, inputs, outputs)
+ # share_button.click(None, [], [], _js=share_js)
+
+
+demo.queue(max_size=12).launch(server_name="0.0.0.0",server_port=7861)
+
+
diff --git a/configs/sample_i2v.yaml b/configs/sample_i2v.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1adc88be3d46317bc3abdeba0f428d371e05319e
--- /dev/null
+++ b/configs/sample_i2v.yaml
@@ -0,0 +1,36 @@
+
+ckpt: "/mnt/petrelfs/share_data/chenxinyuan/code/SEINE-release/pre-trained/seine.pt"
+# save_img_path: "./results/i2v/"
+save_img_path: "/mnt/petrelfs/share_data/zhouyan/gradio_i2v/"
+pretrained_model_path: "pre-trained/stable-diffusion-v1-4/"
+
+# model config:
+model: TAVU
+num_frames: 16
+frame_interval: 1
+image_size: [512, 512]
+#image_size: [320, 512]
+# image_size: [512, 512]
+
+# model speedup
+use_compile: False
+use_fp16: True
+enable_xformers_memory_efficient_attention: True
+img_path: "/mnt/petrelfs/zhouyan/tmp/last"
+# sample config:
+seed:
+run_time: 13
+cfg_scale: 8.0
+sample_method: 'ddpm'
+num_sampling_steps: 250
+text_prompt: ["slow motion"]
+additional_prompt: ", slow motion."
+negative_prompt: ""
+do_classifier_free_guidance: True
+
+# autoregressive config:
+# input_path: "/mnt/petrelfs/zhouyan/tmp/未来上海/WechatIMG9434.jpg"
+input_path: "/mnt/petrelfs/zhouyan/tmp/last"
+researve_frame: 1
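+# "first1" keeps the first frame as the condition; the remaining frames are generated (see mask_generation_before in utils.py)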
+mask_type: "first1"
+use_mask: True
diff --git a/configs/sample_transition.yaml b/configs/sample_transition.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f656444a90e63b08c908cb8736cca94774ccfd7
--- /dev/null
+++ b/configs/sample_transition.yaml
@@ -0,0 +1,33 @@
+
+ckpt: "pre-trained/0020000.pt"
+save_img_path: "./results/transition/"
+pretrained_model_path: "pre-trained/stable-diffusion-v1-4/"
+
+# model config:
+model: TAVU
+num_frames: 16
+frame_interval: 1
+#image_size: [240, 560]
+#image_size: [320, 512]
+image_size: [512, 512]
+
+# model speedup
+use_compile: False
+use_fp16: True
+enable_xformers_memory_efficient_attention: True
+
+# sample config:
+seed:
+run_time: 13
+cfg_scale: 8.0
+sample_method: 'ddpm'
+num_sampling_steps: 250
+text_prompt: ['smooth transition']
+additional_prompt: "smooth transition."
+negative_prompt: ""
+do_classifier_free_guidance: True
+
+# autoregressive config:
+input_path: 'input/transition/1'
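+# "onelast1" conditions on the first and last frames found in input_path; the model generates the transition in between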
+mask_type: "onelast1"
+use_mask: True
diff --git a/datasets/__pycache__/video_transforms.cpython-311.pyc b/datasets/__pycache__/video_transforms.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e8defe15cd33423279052745be6019274e7ca01
Binary files /dev/null and b/datasets/__pycache__/video_transforms.cpython-311.pyc differ
diff --git a/datasets/__pycache__/video_transforms.cpython-39.pyc b/datasets/__pycache__/video_transforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfc7a93b58f4080c24a9bc2765d04386a75dbd6f
Binary files /dev/null and b/datasets/__pycache__/video_transforms.cpython-39.pyc differ
diff --git a/datasets/video_transforms.py b/datasets/video_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..30607a68859fb804370e4bfe8664eada2a840743
--- /dev/null
+++ b/datasets/video_transforms.py
@@ -0,0 +1,472 @@
+import torch
+import random
+import numbers
+import numpy as np
+from torchvision.transforms import RandomCrop, RandomResizedCrop
+from PIL import Image
+
+def _is_tensor_video_clip(clip):
+ if not torch.is_tensor(clip):
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+ if not clip.ndimension() == 4:
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+ return True
+
+
+def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+
+def crop(clip, i, j, h, w):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ """
+ if len(clip.size()) != 4:
+ raise ValueError("clip should be a 4D tensor")
+ return clip[..., i : i + h, j : j + w]
+
+
+def resize(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
+
+def resize_scale(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ H, W = clip.size(-2), clip.size(-1)
+ scale_ = target_size[0] / min(H, W)
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
+
+def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)
+
+def resize_scale_with_height(clip, target_size, interpolation_mode):
+ H, W = clip.size(-2), clip.size(-1)
+ scale_ = target_size / H
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
+
+def resize_scale_with_weight(clip, target_size, interpolation_mode):
+ H, W = clip.size(-2), clip.size(-1)
+ scale_ = target_size / W
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
+
+
+def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
+ """
+ Do spatial cropping and resizing to the video clip
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
+ h (int): Height of the cropped region.
+ w (int): Width of the cropped region.
+ size (tuple(int, int)): height and width of resized clip
+ Returns:
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ clip = crop(clip, i, j, h, w)
+ clip = resize(clip, size, interpolation_mode)
+ return clip
+
+
+def center_crop(clip, crop_size):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ # print(clip.shape)
+ th, tw = crop_size
+ if h < th or w < tw:
+ # print(h, w)
+ raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
+
+ i = int(round((h - th) / 2.0))
+ j = int(round((w - tw) / 2.0))
+ return crop(clip, i, j, th, tw)
+
+
+def center_crop_using_short_edge(clip):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ if h < w:
+ th, tw = h, h
+ i = 0
+ j = int(round((w - tw) / 2.0))
+ else:
+ th, tw = w, w
+ i = int(round((h - th) / 2.0))
+ j = 0
+ return crop(clip, i, j, th, tw)
+
+
+def random_shift_crop(clip):
+ '''
+ Slide along the long edge, with the short edge as crop size
+ '''
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+
+ if h <= w:
+ long_edge = w
+ short_edge = h
+ else:
+ long_edge = h
+        short_edge = w
+
+ th, tw = short_edge, short_edge
+
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
+ return crop(clip, i, j, th, tw)
+
+
+def to_tensor(clip):
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ _is_tensor_video_clip(clip)
+ if not clip.dtype == torch.uint8:
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
+ return clip.float() / 255.0
+
+
+def normalize(clip, mean, std, inplace=False):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
+ mean (tuple): pixel RGB mean. Size is (3)
+ std (tuple): pixel standard deviation. Size is (3)
+ Returns:
+ normalized clip (torch.tensor): Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ if not inplace:
+ clip = clip.clone()
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
+ # print(mean)
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+ return clip
+
+
+def hflip(clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
+ Returns:
+ flipped clip (torch.tensor): Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ return clip.flip(-1)
+
+
+class RandomCropVideo:
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: randomly cropped video clip.
+ size is (T, C, OH, OW)
+ """
+ i, j, h, w = self.get_params(clip)
+ return crop(clip, i, j, h, w)
+
+ def get_params(self, clip):
+ h, w = clip.shape[-2:]
+ th, tw = self.size
+
+ if h < th or w < tw:
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
+
+ if w == tw and h == th:
+ return 0, 0, h, w
+
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
+
+ return i, j, th, tw
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size})"
+
+class CenterCropResizeVideo:
+ '''
+ First use the short side for cropping length,
+ center crop video, then resize to the specified size
+ '''
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ # print(clip.shape)
+ clip_center_crop = center_crop_using_short_edge(clip)
+ # print(clip_center_crop.shape) 320 512
+ clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
+ return clip_center_crop_resize
+
+ def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
+
+class WebVideo320512:
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+        # add one additional pixel to avoid errors in the center crop
+ h, w = clip.size(-2), clip.size(-1)
+ # print('before resize', clip.shape)
+ if h < 320:
+ clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode)
+ # print('after h resize', clip.shape)
+ if w < 512:
+ clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode)
+ # print('after w resize', clip.shape)
+ clip_center_crop = center_crop(clip, self.size)
+ # print(clip_center_crop.shape)
+ return clip_center_crop
+
+ def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
+
+class UCFCenterCropVideo:
+ '''
+ First scale to the specified size in equal proportion to the short edge,
+ then center cropping
+ '''
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
+ clip_center_crop = center_crop(clip_resize, self.size)
+ return clip_center_crop
+
+ def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
+
+
+class CenterCropVideo:
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_center_crop = center_crop(clip, self.size)
+ return clip_center_crop
+
+ def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
+
+
+class NormalizeVideo:
+ """
+ Normalize the video clip by mean subtraction and division by standard deviation
+ Args:
+ mean (3-tuple): pixel RGB mean
+ std (3-tuple): pixel RGB standard deviation
+ inplace (boolean): whether do in-place normalization
+ """
+
+ def __init__(self, mean, std, inplace=False):
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
+ """
+ return normalize(clip, self.mean, self.std, self.inplace)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
+
+
+class ToTensorVideo:
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ return to_tensor(clip)
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__
+
+
+class ResizeVideo():
+ '''
+ First use the short side for cropping length,
+ center crop video, then resize to the specified size
+ '''
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
+ return clip_resize
+
+ def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
+
+# ------------------------------------------------------------
+# --------------------- Sampling ---------------------------
+# ------------------------------------------------------------
+class TemporalRandomCrop(object):
+ """Temporally crop the given frame indices at a random location.
+
+ Args:
+ size (int): Desired length of frames will be seen in the model.
+ """
+
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, total_frames):
+ rand_end = max(0, total_frames - self.size - 1)
+ begin_index = random.randint(0, rand_end)
+ end_index = min(begin_index + self.size, total_frames)
+ return begin_index, end_index
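+
+# Example (sketch): a typical preprocessing chain for a uint8 clip of shape (T, C, H, W):
+#   clip = torch.randint(0, 256, (16, 3, 720, 1280), dtype=torch.uint8)
+#   clip = ToTensorVideo()(clip)                    # float in [0, 1]
+#   clip = CenterCropResizeVideo((320, 512))(clip)  # center crop on the short edge, then resize to 320x512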
diff --git a/diffusion/__init__.py b/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9dbf6cf0bd6b9d1a8f65e0a31e9a84cacc03189
--- /dev/null
+++ b/diffusion/__init__.py
@@ -0,0 +1,47 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
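+# Typical usage (sketch): create_diffusion(timestep_respacing=str(num_sampling_steps)) returns a
+# SpacedDiffusion configured for epsilon prediction with fixed-large variance and the linear beta schedule.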
+def create_diffusion(
+ timestep_respacing,
+ noise_schedule="linear",
+ use_kl=False,
+ sigma_small=False,
+ predict_xstart=False,
+ # learn_sigma=True,
+ learn_sigma=False, # for unet
+ rescale_learned_sigmas=False,
+ diffusion_steps=1000
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if timestep_respacing is None or timestep_respacing == "":
+ timestep_respacing = [diffusion_steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+ ),
+ model_var_type=(
+ (
+ gd.ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else gd.ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type
+ # rescale_timesteps=rescale_timesteps,
+ )
diff --git a/diffusion/__pycache__/__init__.cpython-310.pyc b/diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbe0bf5ae45331b0b01b6e2be8d6bf98856616b9
Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusion/__pycache__/__init__.cpython-311.pyc b/diffusion/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71cfaa457c807a557188101f385fbd4897dd5e7d
Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-311.pyc differ
diff --git a/diffusion/__pycache__/__init__.cpython-38.pyc b/diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a48b1fe9b85aa42792e946f74fae965d2f650f8d
Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusion/__pycache__/__init__.cpython-39.pyc b/diffusion/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22042b311c16bebb698b2bcea485fb1d8a764288
Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-39.pyc differ
diff --git a/diffusion/__pycache__/diffusion_utils.cpython-310.pyc b/diffusion/__pycache__/diffusion_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f767c0cec4a07a8fc7514932f33cb098d0d50b06
Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-310.pyc differ
diff --git a/diffusion/__pycache__/diffusion_utils.cpython-311.pyc b/diffusion/__pycache__/diffusion_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a027d079ff091e66ede1025bdd7f0b55c3ada2e
Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-311.pyc differ
diff --git a/diffusion/__pycache__/diffusion_utils.cpython-38.pyc b/diffusion/__pycache__/diffusion_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a34e84da7fa4e0d9308261b7d36086bc937e12c
Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-38.pyc differ
diff --git a/diffusion/__pycache__/diffusion_utils.cpython-39.pyc b/diffusion/__pycache__/diffusion_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91a512ba1145dcdf80d9657678b5863086d05530
Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-39.pyc differ
diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..498d9cef36b220c5a5e114f669526f63c9bd8c24
Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc differ
diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15b8678feb0ae06f08579a74a4f4a3a8c89ce9f4
Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc differ
diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9001e55c9a831fe3d90d37eddd22ff76638d8a91
Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc differ
diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a2b448aed2786f0987d1dece996873d25f4b513b
Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc differ
diff --git a/diffusion/__pycache__/respace.cpython-310.pyc b/diffusion/__pycache__/respace.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4dff3f1610d0085135fe532a76001f486031a8bf
Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-310.pyc differ
diff --git a/diffusion/__pycache__/respace.cpython-311.pyc b/diffusion/__pycache__/respace.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99072dc324cfa6cc9cb0b4ad6f4ca2e388a0a6a2
Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-311.pyc differ
diff --git a/diffusion/__pycache__/respace.cpython-38.pyc b/diffusion/__pycache__/respace.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3fd2fd7fc04c8c2e92cb1c1be53610014f606bf6
Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-38.pyc differ
diff --git a/diffusion/__pycache__/respace.cpython-39.pyc b/diffusion/__pycache__/respace.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..901b269c5cd95268c3d3d06ff9a4a7835dbd8776
Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-39.pyc differ
diff --git a/diffusion/diffusion_utils.py b/diffusion/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060
--- /dev/null
+++ b/diffusion/diffusion_utils.py
@@ -0,0 +1,88 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import torch as th
+import numpy as np
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a continuous Gaussian distribution.
+ :param x: the targets
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ normalized_x = centered_x * inv_stdv
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
+ return log_probs
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe1b571e088fa245a072d2bb4320eea3d240df02
--- /dev/null
+++ b/diffusion/gaussian_diffusion.py
@@ -0,0 +1,931 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+
+import math
+
+import numpy as np
+import torch as th
+import enum
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * 0.0001,
+ beta_end=scale * 0.02,
+ # diffuser stable diffusion
+ # beta_start=scale * 0.00085,
+ # beta_end=scale * 0.012,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+ Original ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type
+ ):
+
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ ) if len(self.posterior_variance) > 1 else np.array([])
+
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+ )
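+        # Posterior mean of q(x_{t-1} | x_t, x_0) is coef1 * x_0 + coef2 * x_t (Eq. 7 in Ho et al., 2020).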
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
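+        # Closed-form forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise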
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
+ mask=None, x_start=None, use_concat=False):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, F, C = x.shape[:3]
+ assert t.shape == (B,)
+ if use_concat:
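+            # concatenate the noisy latents with the mask and the masked (conditioning) latents before the model call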
+ model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs)
+ else:
+ model_output = model(x, t, **model_kwargs)
+        try:
+            model_output = model_output.sample  # for tav unet (diffusers-style output object)
+        except AttributeError:
+            pass
+ # model_output = model(x, t, **model_kwargs)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, F, C * 2, *x.shape[3:])
+ model_output, model_var_values = th.split(model_output, C, dim=2)
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ "extra": extra,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ mask=None,
+ x_start=None,
+ use_concat=False
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ mask=None,
+ x_start=None,
+ use_concat=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ mask=None,
+ x_start=None,
+ use_concat=False
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ mask=None,
+ x_start=None,
+ use_concat=False
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ mask=None,
+ x_start=None,
+ use_concat=False
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ mask=None,
+ x_start=None,
+ use_concat=False
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, use_mask=False):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+ if use_mask:
+ x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1)
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_output = model(x_t, t, **model_kwargs)
+ try:
+ # diffusers-style UNets (e.g. the tav unet) wrap the prediction in an
+ # output dataclass; unwrap it, otherwise keep the raw tensor output.
+ model_output = model_output.sample
+ except AttributeError:
+ pass
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, F, C = x_t.shape[:3]
+ assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
+ model_output, model_var_values = th.split(model_output, C, dim=2)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ # assert model_output.shape == target.shape == x_start.shape
+ if use_mask:
+ terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2)
+ else:
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
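+
+
+# Shape sketch for _extract_into_tensor: with arr.shape == (T,), timesteps.shape == (N,)
+# and broadcast_shape == (N, 4, F, H, W), the gathered per-sample coefficients are viewed
+# as (N, 1, 1, 1, 1) and then broadcast against zeros to a full (N, 4, F, H, W) tensor,
+# so every element of a sample is scaled by that sample's schedule coefficient.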
diff --git a/diffusion/respace.py b/diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4
--- /dev/null
+++ b/diffusion/respace.py
@@ -0,0 +1,130 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+import torch
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there are 300 timesteps and the section counts are [10,15,20],
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
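+
+ Example (illustrative): respacing 1000 original steps with the "ddim50"
+ shorthand keeps every 20th step.
+ >>> sorted(space_timesteps(1000, "ddim50"))[:3]
+ [0, 20, 40]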
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ # @torch.compile
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ # self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ # if self.rescale_timesteps:
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
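+
+
+# A minimal construction sketch (comments only, not executed). It assumes the
+# usual gaussian_diffusion helpers from the OpenAI/DiT codebases this file is
+# modified from (get_named_beta_schedule plus the ModelMeanType / ModelVarType /
+# LossType enums):
+#
+# from . import gaussian_diffusion as gd
+# betas = gd.get_named_beta_schedule("linear", 1000)
+# diffusion = SpacedDiffusion(
+# use_timesteps=space_timesteps(1000, "ddim50"),
+# betas=betas,
+# model_mean_type=gd.ModelMeanType.EPSILON,
+# model_var_type=gd.ModelVarType.LEARNED_RANGE,
+# loss_type=gd.LossType.MSE,
+# )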
diff --git a/diffusion/timestep_sampler.py b/diffusion/timestep_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f
--- /dev/null
+++ b/diffusion/timestep_sampler.py
@@ -0,0 +1,150 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to place the sampled timesteps on.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
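+
+ # Worked example of the reweighting above (illustrative): with unnormalized
+ # weights w = [1, 3], p = [0.25, 0.75]; a draw of index 1 gets loss weight
+ # 1 / (2 * 0.75) = 2/3, so E_p[weight * loss] equals the plain average of the
+ # per-timestep losses and the estimate stays unbiased.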
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
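+
+
+# Sketch of the intended single-process update loop (illustrative; in the
+# distributed case update_with_local_losses all-gathers the losses first):
+#
+# sampler = LossSecondMomentResampler(diffusion)
+# t, weights = sampler.sample(batch_size, device)
+# losses = diffusion.training_losses(model, x_start, t)["loss"]
+# sampler.update_with_all_losses(t.tolist(), losses.detach().tolist())
+# loss = (losses * weights).mean()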
diff --git a/download.py b/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..56602603b15868f2533ca0d083f61ee10f82e72f
--- /dev/null
+++ b/download.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Functions for downloading pre-trained DiT models
+"""
+from torchvision.datasets.utils import download_url
+import torch
+import os
+
+
+# Names of the publicly released DiT checkpoints hosted at
+# https://dl.fbaipublicfiles.com/DiT/models/ (used by download_model below).
+pretrained_models = {'DiT-XL-2-256x256.pt', 'DiT-XL-2-512x512.pt'}
+
+
+def find_model(model_name):
+ """
+ Load a checkpoint from a local path, preferring the EMA weights when the
+ checkpoint contains them.
+ """
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
+
+ if "ema" in checkpoint: # supports checkpoints from train.py
+ print('Using EMA weights from the checkpoint.')
+ checkpoint = checkpoint["ema"]
+ return checkpoint
+
+
+def download_model(model_name):
+ """
+ Downloads a pre-trained DiT model from the web.
+ """
+ assert model_name in pretrained_models
+ local_path = f'pretrained_models/{model_name}'
+ if not os.path.isfile(local_path):
+ os.makedirs('pretrained_models', exist_ok=True)
+ web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
+ download_url(web_path, 'pretrained_models')
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
+ return model
+
+
+if __name__ == "__main__":
+ # Download all DiT checkpoints
+ for model in pretrained_models:
+ download_model(model)
+ print('Done.')
diff --git a/env.yaml b/env.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..50ed00de7e2555d5b73d953adaad552c81d36639
--- /dev/null
+++ b/env.yaml
@@ -0,0 +1,20 @@
+name: seine
+channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.9.16
+ - pytorch=2.0.1
+ - pytorch-cuda=11.7
+ - torchvision=0.15.2
+ - pip
+ - pip:
+ - decord==0.6.0
+ - diffusers==0.15.0
+ - imageio==2.29.0
+ - transformers==4.29.2
+ - xformers==0.0.20
+ - einops
+ - omegaconf
diff --git a/huggingface-i2v/__init__.py b/huggingface-i2v/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/huggingface-i2v/requirements.txt b/huggingface-i2v/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/image_to_video/__init__.py b/image_to_video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aaa4fe07d2c1869909ed060f6b5545490d832b5
--- /dev/null
+++ b/image_to_video/__init__.py
@@ -0,0 +1,221 @@
+import os
+import sys
+import math
+try:
+ import utils
+
+ from diffusion import create_diffusion
+ from download import find_model
+except ImportError:
+ # fall back to the repository root so the sibling packages can be resolved
+ sys.path.append(os.path.split(sys.path[0])[0])
+
+
+ import utils
+
+ from diffusion import create_diffusion
+ from download import find_model
+
+import torch
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+import argparse
+import torchvision
+
+from einops import rearrange
+from models import get_models
+from torchvision.utils import save_image
+from diffusers.models import AutoencoderKL
+from models.clip import TextEmbedder
+from omegaconf import OmegaConf
+from PIL import Image
+import numpy as np
+from torchvision import transforms
+sys.path.append("..")
+from datasets import video_transforms
+from utils import mask_generation_before
+from natsort import natsorted
+from diffusers.utils.import_utils import is_xformers_available
+
+config_path = "configs/sample_i2v.yaml"
+args = OmegaConf.load(config_path)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(args)
+
+def model_i2v_fun(args):
+ if args.seed:
+ torch.manual_seed(args.seed)
+ torch.set_grad_enabled(False)
+ if args.ckpt is None:
+ raise ValueError("Please specify a checkpoint path using --ckpt ")
+ latent_h = args.image_size[0] // 8
+ latent_w = args.image_size[1] // 8
+ args.image_h = args.image_size[0]
+ args.image_w = args.image_size[1]
+ args.latent_h = latent_h
+ args.latent_w = latent_w
+ print("loading model")
+ model = get_models(args).to(device)
+
+ if args.use_compile:
+ model = torch.compile(model)
+ ckpt_path = args.ckpt
+ state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
+ model.load_state_dict(state_dict)
+
+ print('model loaded successfully')
+
+ model.eval()
+ pretrained_model_path = args.pretrained_model_path
+ diffusion = create_diffusion(str(args.num_sampling_steps))
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+ text_encoder = TextEmbedder(pretrained_model_path).to(device)
+ # if args.use_fp16:
+ # print('Warning: using half precision for inference')
+ # vae.to(dtype=torch.float16)
+ # model.to(dtype=torch.float16)
+ # text_encoder.to(dtype=torch.float16)
+
+ return vae, model, text_encoder, diffusion
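+
+# Typical use of the returned pipeline (sketch): `prompt` is the text prompt
+# string, and the mask_generation_before signature is assumed from the sampling
+# scripts rather than shown here.
+#
+# vae, model, text_encoder, diffusion = model_i2v_fun(args)
+# video_frames, _ = get_input(args.input_path, args)
+# video_input = video_frames.unsqueeze(0).to(device) # b,f,c,h,w
+# mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
+# masked_video = video_input * (mask == 0)
+# video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt,
+# vae, text_encoder, diffusion, model, device)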
+
+
+def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,):
+ b,f,c,h,w=video_input.shape
+ latent_h = args.image_size[0] // 8
+ latent_w = args.image_size[1] // 8
+
+ # prepare inputs
+ if args.use_fp16:
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
+ masked_video = masked_video.to(dtype=torch.float16)
+ mask = mask.to(dtype=torch.float16)
+ else:
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
+
+
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
+
+ # classifier_free_guidance
+ if args.do_classifier_free_guidance:
+ masked_video = torch.cat([masked_video] * 2)
+ mask = torch.cat([mask] * 2)
+ z = torch.cat([z] * 2)
+ prompt_all = [prompt] + [args.negative_prompt]
+
+ else:
+ prompt_all = [prompt]
+
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
+ model_kwargs = dict(encoder_hidden_states=text_prompt,
+ class_labels=None,
+ cfg_scale=args.cfg_scale,
+ use_fp16=args.use_fp16,) # tav unet
+
+ # Sample images:
+ if args.sample_method == 'ddim':
+ samples = diffusion.ddim_sample_loop(
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
+ mask=mask, x_start=masked_video, use_concat=args.use_mask
+ )
+ elif args.sample_method == 'ddpm':
+ samples = diffusion.p_sample_loop(
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
+ mask=mask, x_start=masked_video, use_concat=args.use_mask
+ )
+ if args.do_classifier_free_guidance:
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
+ if args.use_fp16:
+ samples = samples.to(dtype=torch.float16)
+
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
+ return video_clip
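+
+# The decoded clip is roughly in [-1, 1]; a minimal way to write it to disk
+# (sketch, reusing the torch / torchvision imports above):
+#
+# video = ((video_clip * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8).cpu()
+# torchvision.io.write_video("sample.mp4", video.permute(0, 2, 3, 1), fps=8)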
+
+def get_input(path,args):
+ input_path = path
+ # input_path = args.input_path
+ transform_video = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.ResizeVideo((args.image_h, args.image_w)),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval)
+ if input_path is not None:
+ print(f'loading video from {input_path}')
+ if os.path.isdir(input_path):
+ file_list = os.listdir(input_path)
+ video_frames = []
+ if args.mask_type.startswith('onelast'):
+ num = int(args.mask_type.split('onelast')[-1])
+ # get first and last frame
+ first_frame_path = os.path.join(input_path, natsorted(file_list)[0])
+ last_frame_path = os.path.join(input_path, natsorted(file_list)[-1])
+ first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ for i in range(num):
+ video_frames.append(first_frame)
+ # add zeros to frames
+ num_zeros = args.num_frames-2*num
+ for i in range(num_zeros):
+ zeros = torch.zeros_like(first_frame)
+ video_frames.append(zeros)
+ for i in range(num):
+ video_frames.append(last_frame)
+ n = 0
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
+ video_frames = transform_video(video_frames)
+ else:
+ for file in file_list:
+ if file.endswith('jpg') or file.endswith('png'):
+ image = torch.as_tensor(np.array(Image.open(os.path.join(input_path,file)), dtype=np.uint8, copy=True)).unsqueeze(0)
+ video_frames.append(image)
+ else:
+ continue
+ n = 0
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
+ video_frames = transform_video(video_frames)
+ return video_frames, n
+ elif os.path.isfile(input_path):
+ _, full_file_name = os.path.split(input_path)
+ file_name, extension = os.path.splitext(full_file_name)
+ if extension == '.jpg' or extension == '.png':
+ print("reading video from a single image")
+ video_frames = []
+ num = int(args.mask_type.split('first')[-1])
+ first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ for i in range(num):
+ video_frames.append(first_frame)
+ num_zeros = args.num_frames-num
+ for i in range(num_zeros):
+ zeros = torch.zeros_like(first_frame)
+ video_frames.append(zeros)
+ n = 0
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
+ video_frames = transform_video(video_frames)
+ return video_frames, n
+ else:
+ raise TypeError(f'{extension} is not supported!')
+ else:
+ raise ValueError('Please check your path input!!')
+ else:
+ print('input path is None, falling back to text-to-video')
+ video_frames = torch.zeros(16,3,args.latent_h,args.latent_w,dtype=torch.uint8)
+ args.mask_type = 'all'
+ video_frames = transform_video(video_frames)
+ n = 0
+ return video_frames, n
+
+def setup_seed(seed):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
\ No newline at end of file
diff --git a/image_to_video/__pycache__/__init__.cpython-311.pyc b/image_to_video/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46430872ea451064ec2f7ed3af75ef5b30bde4eb
Binary files /dev/null and b/image_to_video/__pycache__/__init__.cpython-311.pyc differ
diff --git a/input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png b/input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b8b907e01d44e7384bd405fcc33738e03743c0a
Binary files /dev/null and b/input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png differ
diff --git a/input/i2v/The_picture_shows_the_beauty_of_the_sea.png b/input/i2v/The_picture_shows_the_beauty_of_the_sea.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f5847286d1507fe4c66f218755eff2d5d154aca
Binary files /dev/null and b/input/i2v/The_picture_shows_the_beauty_of_the_sea.png differ
diff --git a/input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png b/input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png
new file mode 100644
index 0000000000000000000000000000000000000000..9fc7d9c55d24fd1c167ba02facb016014806bd60
Binary files /dev/null and b/input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png differ
diff --git a/input/transition/1/1-Close-up shot of a blooming cherry tree, realism-1.png b/input/transition/1/1-Close-up shot of a blooming cherry tree, realism-1.png
new file mode 100644
index 0000000000000000000000000000000000000000..25f5a3408bf8d590347732bf0ee96d395aba4807
Binary files /dev/null and b/input/transition/1/1-Close-up shot of a blooming cherry tree, realism-1.png differ
diff --git a/input/transition/1/2-Wide angle shot of an alien planet with cherry blossom forest-2.png b/input/transition/1/2-Wide angle shot of an alien planet with cherry blossom forest-2.png
new file mode 100644
index 0000000000000000000000000000000000000000..d36bae5d3d0551aeacf73999769e8e87031a0540
--- /dev/null
+++ b/input/transition/1/2-Wide angle shot of an alien planet with cherry blossom forest-2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e98a00408112aee3d726be231a2c1f44c3ac58e3f2f54b4bded2e0a183f66529
+size 1005564
diff --git a/input/transition/2/1-Overhead view of a bustling city street at night, realism-1.png b/input/transition/2/1-Overhead view of a bustling city street at night, realism-1.png
new file mode 100644
index 0000000000000000000000000000000000000000..95e9e244c4d7422fefb14d6c1a5947e90a83c430
Binary files /dev/null and b/input/transition/2/1-Overhead view of a bustling city street at night, realism-1.png differ
diff --git a/input/transition/2/2-Aerial view of a futuristic city bathed in neon lights-2.png b/input/transition/2/2-Aerial view of a futuristic city bathed in neon lights-2.png
new file mode 100644
index 0000000000000000000000000000000000000000..27c3776798133990666d421e50c8b0139f514121
Binary files /dev/null and b/input/transition/2/2-Aerial view of a futuristic city bathed in neon lights-2.png differ
diff --git a/input/transition/3/1-Close-up shot of a candle lit in the darkness, realism-1png.png b/input/transition/3/1-Close-up shot of a candle lit in the darkness, realism-1png.png
new file mode 100644
index 0000000000000000000000000000000000000000..54bdd1da39818d5c1b2274d5662b7447bb4d900c
Binary files /dev/null and b/input/transition/3/1-Close-up shot of a candle lit in the darkness, realism-1png.png differ
diff --git a/input/transition/3/2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png b/input/transition/3/2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png
new file mode 100644
index 0000000000000000000000000000000000000000..17ba096583432c67b2011afcbec5e13fec259a96
Binary files /dev/null and b/input/transition/3/2-Wide shot of a mystical land illuminated by giant glowing flowers-2.png differ
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69577f55b92072e09420e545df6d82c07212408f
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,54 @@
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+from .dit import DiT_models
+from .uvit import UViT_models
+from .unet import UNet3DConditionModel
+from torch.optim.lr_scheduler import LambdaLR
+
+def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
+ def fn(step):
+ if warmup_steps > 0:
+ return min(step / warmup_steps, 1)
+ else:
+ return 1
+ return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+ if name == 'warmup':
+ return customized_lr_scheduler(optimizer, **kwargs)
+ elif name == 'cosine':
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ return CosineAnnealingLR(optimizer, **kwargs)
+ else:
+ raise NotImplementedError(name)
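+
+# Example (illustrative): linear warmup to the optimizer's base learning rate
+# over 5000 steps, stepping the scheduler once per optimizer step:
+#
+# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+# scheduler = get_lr_scheduler(optimizer, 'warmup', warmup_steps=5000)
+# for step in range(total_steps):
+# ... # forward / backward / optimizer.step()
+# scheduler.step()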
+
+def get_models(args):
+
+ if 'DiT' in args.model:
+ return DiT_models[args.model](
+ input_size=args.latent_size,
+ num_classes=args.num_classes,
+ class_guided=args.class_guided,
+ num_frames=args.num_frames,
+ use_lora=args.use_lora,
+ attention_mode=args.attention_mode
+ )
+ elif 'UViT' in args.model:
+ return UViT_models[args.model](
+ input_size=args.latent_size,
+ num_classes=args.num_classes,
+ class_guided=args.class_guided,
+ num_frames=args.num_frames,
+ use_lora=args.use_lora,
+ attention_mode=args.attention_mode
+ )
+ elif 'TAV' in args.model:
+ pretrained_model_path = args.pretrained_model_path
+ return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask)
+ else:
+ raise NotImplementedError('{} Model Not Supported!'.format(args.model))
+
\ No newline at end of file
diff --git a/models/__pycache__/__init__.cpython-310.pyc b/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e38774aa7011ed9e7cb03f3530bfe238eee0e3d
Binary files /dev/null and b/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/__pycache__/__init__.cpython-311.pyc b/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5921f914f6b92591485258426a52502a98ec31b2
Binary files /dev/null and b/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28ca91347df8cc40ce3df1ce8358596454159992
Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c853d1fc24036b6d7d63f1135188773934b60261
Binary files /dev/null and b/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/models/__pycache__/attention.cpython-310.pyc b/models/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0dae6ab0508a6a216cbcdfaf1e5dc9d59edd9400
Binary files /dev/null and b/models/__pycache__/attention.cpython-310.pyc differ
diff --git a/models/__pycache__/attention.cpython-311.pyc b/models/__pycache__/attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28861d5cbdedd15d1e1eec858106a4662cf1b544
Binary files /dev/null and b/models/__pycache__/attention.cpython-311.pyc differ
diff --git a/models/__pycache__/attention.cpython-38.pyc b/models/__pycache__/attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ca1ff5762230c5921687794783619636ac0f36c
Binary files /dev/null and b/models/__pycache__/attention.cpython-38.pyc differ
diff --git a/models/__pycache__/attention.cpython-39.pyc b/models/__pycache__/attention.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21888f95c7191e914acd8084b4a4ffbb1abc67f4
Binary files /dev/null and b/models/__pycache__/attention.cpython-39.pyc differ
diff --git a/models/__pycache__/clip.cpython-310.pyc b/models/__pycache__/clip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..620862b1a6897caeb8178546df6fcdbb22464447
Binary files /dev/null and b/models/__pycache__/clip.cpython-310.pyc differ
diff --git a/models/__pycache__/clip.cpython-311.pyc b/models/__pycache__/clip.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cef805947a4de50152b8361dc68996458471e406
Binary files /dev/null and b/models/__pycache__/clip.cpython-311.pyc differ
diff --git a/models/__pycache__/clip.cpython-38.pyc b/models/__pycache__/clip.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b951c75c84546c47051527873cc5b76cf4c18bd
Binary files /dev/null and b/models/__pycache__/clip.cpython-38.pyc differ
diff --git a/models/__pycache__/clip.cpython-39.pyc b/models/__pycache__/clip.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..139b4d17dd6b6e6f7cfd0d6a28276b9b95721e66
Binary files /dev/null and b/models/__pycache__/clip.cpython-39.pyc differ
diff --git a/models/__pycache__/dit.cpython-310.pyc b/models/__pycache__/dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5435bde751ef3af56297c7e7943baed748722a4d
Binary files /dev/null and b/models/__pycache__/dit.cpython-310.pyc differ
diff --git a/models/__pycache__/dit.cpython-311.pyc b/models/__pycache__/dit.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce2ee43d6f041740f11adf6828786e73b658afca
Binary files /dev/null and b/models/__pycache__/dit.cpython-311.pyc differ
diff --git a/models/__pycache__/dit.cpython-38.pyc b/models/__pycache__/dit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7dfaaf0356dff9964c21916d5bb1576668e699c7
Binary files /dev/null and b/models/__pycache__/dit.cpython-38.pyc differ
diff --git a/models/__pycache__/dit.cpython-39.pyc b/models/__pycache__/dit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..083eabc53fa89168564d9ffa335998b839ee11ee
Binary files /dev/null and b/models/__pycache__/dit.cpython-39.pyc differ
diff --git a/models/__pycache__/resnet.cpython-310.pyc b/models/__pycache__/resnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a538bd768a726c9b895b450a23562eb80192570
Binary files /dev/null and b/models/__pycache__/resnet.cpython-310.pyc differ
diff --git a/models/__pycache__/resnet.cpython-311.pyc b/models/__pycache__/resnet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bae6d6fd9729b1a61c3dbd4fa73d0a9b2da2ddc
Binary files /dev/null and b/models/__pycache__/resnet.cpython-311.pyc differ
diff --git a/models/__pycache__/resnet.cpython-38.pyc b/models/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc7845c892b85f62b7b9b9788a5aacfe858f9f55
Binary files /dev/null and b/models/__pycache__/resnet.cpython-38.pyc differ
diff --git a/models/__pycache__/resnet.cpython-39.pyc b/models/__pycache__/resnet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c0b02b1579bbdec188f32ba8687e1fdcf7ab9acd
Binary files /dev/null and b/models/__pycache__/resnet.cpython-39.pyc differ
diff --git a/models/__pycache__/unet.cpython-310.pyc b/models/__pycache__/unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25bbfb2f81f9166c59a3530ee0343739896c37df
Binary files /dev/null and b/models/__pycache__/unet.cpython-310.pyc differ
diff --git a/models/__pycache__/unet.cpython-311.pyc b/models/__pycache__/unet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0af6d539b3bef28575dd35bc33422b2f6c4cf896
Binary files /dev/null and b/models/__pycache__/unet.cpython-311.pyc differ
diff --git a/models/__pycache__/unet.cpython-38.pyc b/models/__pycache__/unet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55ac6ce89e455d280789c3b2921ffe4e9d737301
Binary files /dev/null and b/models/__pycache__/unet.cpython-38.pyc differ
diff --git a/models/__pycache__/unet.cpython-39.pyc b/models/__pycache__/unet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3ae31907c6dd8d167137f2d63282a1f06029ee7
Binary files /dev/null and b/models/__pycache__/unet.cpython-39.pyc differ
diff --git a/models/__pycache__/unet_blocks.cpython-310.pyc b/models/__pycache__/unet_blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2aa1f327d934ad77e0ea205b6ab0587190acff7a
Binary files /dev/null and b/models/__pycache__/unet_blocks.cpython-310.pyc differ
diff --git a/models/__pycache__/unet_blocks.cpython-311.pyc b/models/__pycache__/unet_blocks.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ffef0cc9e7a3aef69839bed4899f52763d90e92
Binary files /dev/null and b/models/__pycache__/unet_blocks.cpython-311.pyc differ
diff --git a/models/__pycache__/unet_blocks.cpython-38.pyc b/models/__pycache__/unet_blocks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49be4a6652f1971787067f57bccff3720ae3477e
Binary files /dev/null and b/models/__pycache__/unet_blocks.cpython-38.pyc differ
diff --git a/models/__pycache__/unet_blocks.cpython-39.pyc b/models/__pycache__/unet_blocks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1bb1a58be97d8841d1476fb68629d3a8cb4aaa6a
Binary files /dev/null and b/models/__pycache__/unet_blocks.cpython-39.pyc differ
diff --git a/models/__pycache__/uvit.cpython-310.pyc b/models/__pycache__/uvit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fd6de89d77f345f7c2286c7f99979953b1b25b9
Binary files /dev/null and b/models/__pycache__/uvit.cpython-310.pyc differ
diff --git a/models/__pycache__/uvit.cpython-311.pyc b/models/__pycache__/uvit.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83671c53187e3aec10eb42a62d68e8c3ce723439
Binary files /dev/null and b/models/__pycache__/uvit.cpython-311.pyc differ
diff --git a/models/__pycache__/uvit.cpython-38.pyc b/models/__pycache__/uvit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..502d94ca74801f895cc9d8d1f8daf33c12d67020
Binary files /dev/null and b/models/__pycache__/uvit.cpython-38.pyc differ
diff --git a/models/__pycache__/uvit.cpython-39.pyc b/models/__pycache__/uvit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49767875b4891eb4ddf002df80a2dee370879e26
Binary files /dev/null and b/models/__pycache__/uvit.cpython-39.pyc differ
diff --git a/models/attention.py b/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..689f017d36b431f730257d3c94413a0eddc7287c
--- /dev/null
+++ b/models/attention.py
@@ -0,0 +1,968 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+from dataclasses import dataclass
+from typing import Optional
+
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward, AdaLayerNorm
+from rotary_embedding_torch import RotaryEmbedding
+from typing import Callable, Optional
+from einops import rearrange, repeat
+
+try:
+ from diffusers.models.modeling_utils import ModelMixin
+except ImportError:
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+def exists(x):
+ return x is not None
+
+
+class CrossAttention(nn.Module):
+ r"""
+ Copied from diffusers 0.11.1.
+ A cross attention layer.
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+ # print('num head', heads)
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ self.dim_head = dim_head
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ # print(use_relative_position)
+ self.use_relative_position = use_relative_position
+ if self.use_relative_position:
+ self.rotary_emb = RotaryEmbedding(min(32, dim_head))
+ # # print(dim_head)
+ # # print(heads)
+ # # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265
+ # self.max_position_embeddings = 32
+ # self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head)
+
+ # self.dropout = nn.Dropout(dropout)
+
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
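+
+ # Shape round-trip (illustrative): with heads=8 and inner_dim=512, a
+ # (B, L, 512) tensor becomes (B*8, L, 64) via reshape_heads_to_batch_dim and
+ # is restored to (B, L, 512) by reshape_batch_dim_to_heads after attention.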
+
+ def reshape_for_scores(self, tensor):
+ # split heads and dims
+ # tensor should be [b (h w)] f (d nd)
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ return tensor
+
+ def same_batch_dim_to_heads(self, tensor):
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+
+ # print('before reshape query shape', query.shape)
+ dim = query.shape[-1]
+ if not self.use_relative_position:
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
+ # print('after reshape query shape', query.shape)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if not self.use_relative_position:
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
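+        # baddbmm computes beta * input + alpha * (query @ key^T); with beta=0 the empty tensor is ignored,
+        # so this is simply the scaled attention logits computed in a single fused op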
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ # print('query shape', query.shape)
+ # print('key shape', key.shape)
+ # print('value shape', value.shape)
+
+ if attention_mask is not None:
+ # print('attention_mask', attention_mask.shape)
+ # print('attention_scores', attention_scores.shape)
+ # exit()
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ # print(attention_probs.shape)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+ # print(attention_probs.shape)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+ # print(hidden_states.shape)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ # print(hidden_states.shape)
+ # exit()
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
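+        # compute attention in chunks of `_slice_size` along the (batch * heads) dimension to limit peak memory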
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
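+        # query/key/value are already (batch * heads, seq_len, dim_head); xformers treats dim 0 as the batch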
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ rotary_emb: bool = None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True):
+ # Input
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ if self.training:
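+            # during training the last `use_image_num` entries along the frame axis are independent images
+            # appended after the video frames; the text embeddings are split and repeated to match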
+ video_length = hidden_states.shape[2] - use_image_num
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
+ encoder_hidden_states_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
+ encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
+ encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
+ else:
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length,
+ use_image_num=use_image_num,
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ rotary_emb: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ # print(only_cross_attention)
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+ # print(self.use_ada_layer_norm)
+ self.use_first_frame = use_first_frame
+
+ # Spatial-Attn
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ upcast_attention=upcast_attention,
+ )
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # # SC-Attn
+ # self.attn1 = SparseCausalAttention(
+ # query_dim=dim,
+ # heads=num_attention_heads,
+ # dim_head=attention_head_dim,
+ # dropout=dropout,
+ # bias=attention_bias,
+ # cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ # upcast_attention=upcast_attention,
+ # )
+ # self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # Text Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn2 = None
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+ # # Temp Frame-Cross-Attn; add tahn scale factor
+ # self.attn_fcross = SparseCausalAttention(
+ # query_dim=dim,
+ # heads=num_attention_heads,
+ # dim_head=attention_head_dim,
+ # dropout=dropout,
+ # bias=attention_bias,
+ # cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ # upcast_attention=upcast_attention,
+ # )
+ # self.norm_fcross = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ # nn.init.zeros_(self.attn_fcross.to_out[0].weight.data)
+
+        # Temporal Attention (output projection zero-initialized below so the block starts as a no-op)
+ self.attn_temp = TemporalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ upcast_attention=upcast_attention,
+ rotary_emb=rotary_emb,
+ )
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
+
+ if not is_xformers_available():
+ print("Here is how to install it")
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ # self.attn_fcross._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None):
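+        # block order: spatial self-attention -> text cross-attention -> temporal attention -> feed-forward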
+        # Spatial Self-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ )
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states
+
+ # # SparseCausal-Attention
+ # norm_hidden_states = (
+ # self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ # )
+
+ # if self.only_cross_attention:
+ # hidden_states = (
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ # )
+ # else:
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+ # # Temporal FrameCross Attention
+ # norm_hidden_states = (
+ # self.norm_fcross(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_fcross(hidden_states)
+ # )
+ # hidden_states = self.attn_fcross(
+ # norm_hidden_states, attention_mask=attention_mask, video_length=video_length, use_image_num=use_image_num) + hidden_states
+
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # Temporal Attention
+ if self.training:
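+            # only the video frames go through temporal attention; the appended image frames bypass it unchanged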
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
+ hidden_states_video = hidden_states[:, :video_length, :]
+ hidden_states_image = hidden_states[:, video_length:, :]
+ # print(hidden_states_video.shape)
+ # print(hidden_states_image.shape)
+ # if self.training:
+ # # prepare attention mask; mask images in temporal attention
+ # attention_mask_shape = (video_length + use_image_num) // 8 + 1
+ # video_image_length = video_length + use_image_num
+ # attention_mask = torch.zeros([8 * attention_mask_shape, 8 * attention_mask_shape],
+ # dtype=hidden_states.dtype, device=hidden_states.device)[:video_image_length, :video_image_length]
+ # attention_mask[:, video_length:] = -math.inf
+ norm_hidden_states_video = (
+ self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video)
+ )
+ # print(norm_hidden_states.shape)
+ hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
+ else:
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+ )
+ # print(norm_hidden_states.shape)
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ return hidden_states
+
+
+class SparseCausalAttention(CrossAttention):
+ def forward_video(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ former_frame_index = torch.arange(video_length) - 1
+ former_frame_index[0] = 0
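+        # sparse causal attention: every frame attends to the first frame and to its previous frame
+        # (their keys/values are concatenated along the token dimension below)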
+
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length).contiguous()
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
+ key = rearrange(key, "b f d c -> (b f) d c").contiguous()
+
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length).contiguous()
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
+ value = rearrange(value, "b f d c -> (b f) d c").contiguous()
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def forward_image(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+ # if self.use_relative_position:
+ # print('before attention query shape', query.shape)
+ dim = query.shape[-1]
+ if not self.use_relative_position:
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
+ # if self.use_relative_position:
+ # print('before attention query shape', query.shape)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if not self.use_relative_position:
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_image_num=None):
+ if self.training:
+ # print(use_image_num)
+ hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
+ hidden_states_video = hidden_states[:, :video_length, ...]
+ hidden_states_image = hidden_states[:, video_length:, ...]
+ hidden_states_video = rearrange(hidden_states_video, 'b f d c -> (b f) d c').contiguous()
+ hidden_states_image = rearrange(hidden_states_image, 'b f d c -> (b f) d c').contiguous()
+ hidden_states_video = self.forward_video(hidden_states=hidden_states_video,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ video_length=video_length)
+ # print('hidden_states_video', hidden_states_video.shape)
+ hidden_states_image = self.forward_image(hidden_states=hidden_states_image,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask)
+ # print('hidden_states_image', hidden_states_image.shape)
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=0)
+ return hidden_states
+ # exit()
+ else:
+ return self.forward_video(hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ video_length=video_length)
+
+class TemporalAttention(CrossAttention):
+ def __init__(self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ rotary_emb=None):
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
+ # relative time positional embeddings
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
+ self.rotary_emb = rotary_emb
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+ dim = query.shape[-1]
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ # print('query shape', query.shape)
+ # print('key shape', key.shape)
+ # print('value shape', value.shape)
+        # reshape for adding the temporal positional bias
+ query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
+ key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
+ value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
+ # print('query shape', query.shape)
+ # print('key shape', key.shape)
+ # print('value shape', value.shape)
+
+        # torch.baddbmm only accepts 3-D tensors
+ # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm
+ # attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
+ if exists(self.rotary_emb):
+ query = self.rotary_emb.rotate_queries_or_keys(query)
+ key = self.rotary_emb.rotate_queries_or_keys(key)
+
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
+ # print('attention_scores shape', attention_scores.shape)
+ # print('time_rel_pos_bias shape', time_rel_pos_bias.shape)
+ # print('attention_mask shape', attention_mask.shape)
+
+ attention_scores = attention_scores + time_rel_pos_bias
+ # print(attention_scores.shape)
+
+ # bert from huggin face
+ # attention_scores = attention_scores / math.sqrt(self.dim_head)
+
+ # # Normalize the attention scores to probabilities.
+ # attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ if attention_mask is not None:
+ # add attention mask
+ attention_scores = attention_scores + attention_mask
+
+        # stable softmax (as in video diffusion models): subtract the per-row max before softmax
+        attention_scores = attention_scores - attention_scores.amax(dim=-1, keepdim=True).detach()
+
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ # print(attention_probs[0][0])
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ # hidden_states = torch.matmul(attention_probs, value)
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
+ # print(hidden_states.shape)
+ # hidden_states = self.same_batch_dim_to_heads(hidden_states)
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
+ # print(hidden_states.shape)
+ # exit()
+ return hidden_states
+
+class RelativePositionBias(nn.Module):
+ def __init__(
+ self,
+ heads=8,
+ num_buckets=32,
+ max_distance=128,
+ ):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
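+        # T5-style relative position bucketing: half the buckets encode the sign of the offset;
+        # small offsets get exact buckets, larger offsets share logarithmically spaced buckets up to max_distance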
+ ret = 0
+ n = -relative_position
+
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, n, device):
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+ values = self.relative_attention_bias(rp_bucket)
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
\ No newline at end of file
diff --git a/models/clip.py b/models/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..249ff2a1680d580d94d4a18c1db5f538a81c043d
--- /dev/null
+++ b/models/clip.py
@@ -0,0 +1,123 @@
+import numpy
+import torch.nn as nn
+from transformers import CLIPTokenizer, CLIPTextModel
+
+import transformers
+transformers.logging.set_verbosity_error()
+
+"""
+Will encounter following warning:
+- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
+or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
+- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
+that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
+
+https://github.com/CompVis/stable-diffusion/issues/97
+according to this issue, this warning is safe.
+
+This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
+You can safely ignore the warning, it is not an error.
+
+This clip usage is from U-ViT and same with Stable Diffusion.
+"""
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
+ def __init__(self, path, device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
+ self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class TextEmbedder(nn.Module):
+ """
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
+ """
+ def __init__(self, path, dropout_prob=0.1):
+ super().__init__()
+ self.text_encodder = FrozenCLIPEmbedder(path=path)
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, text_prompts, force_drop_ids=None):
+ """
+ Drops text to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
+ else:
+ # TODO
+ drop_ids = force_drop_ids == 1
+ labels = list(numpy.where(drop_ids, "", text_prompts))
+ # print(labels)
+ return labels
+
+ def forward(self, text_prompts, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
+ embeddings = self.text_encodder(text_prompts)
+ return embeddings
+
+
+if __name__ == '__main__':
+
+ r"""
+ Returns:
+
+ Examples from CLIPTextModel:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModel
+
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
+ dropout_prob=0.00001).to(device)
+
+ text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
+ # text_prompt = ('None', 'None', 'None')
+ output = text_encoder(text_prompts=text_prompt, train=False)
+ # print(output)
+ print(output.shape)
+ # print(output.shape)
\ No newline at end of file
diff --git a/models/dit.py b/models/dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ffebb93dc4acd3da5b876fa75decb32fbc5b559
--- /dev/null
+++ b/models/dit.py
@@ -0,0 +1,617 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+from einops import rearrange, repeat
+from timm.models.vision_transformer import Mlp, PatchEmbed
+
+
+
+import os
+import sys
+# sys.path.append(os.getcwd())
+sys.path.append(os.path.split(sys.path[0])[0])
+# Explanation of the path manipulation above:
+# sys.path[0]                 -> the directory containing this file, e.g. C:\Users\maxu\Desktop\blog_test\pakage2
+# os.path.split(sys.path[0])  -> ('C:\Users\maxu\Desktop\blog_test', 'pakage2'), so the parent directory is appended
+# cross-package imports such as mmcls work because mmcls is installed as a package
+
+
+# for i in sys.path:
+# print(i)
+
+# the xformers lib enables memory-efficient attention, reducing memory use and speeding up training and inference
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except ImportError:
+    XFORMERS_IS_AVAILBLE = False
+
+# from timm.models.layers.helpers import to_2tuple
+# from timm.models.layers.trace_utils import _assert
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+#################################################################################
+# Attention Layers from TIMM #
+#################################################################################
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+ self.attention_mode = attention_mode
+ self.use_lora = use_lora
+
+ if self.use_lora:
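+            # NOTE: this path assumes a LoRA implementation exposing `MergedLinear` (e.g. loralib) is imported as `lora`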
+ self.qkv = lora.MergedLinear(dim, dim * 3, r=500, enable_lora=[True, False, True])
+ else:
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+        if self.attention_mode == 'xformers':  # may produce NaN loss when used together with amp
+            x = xformers.ops.memory_efficient_attention(q, k, v).reshape(B, N, C)
+
+        elif self.attention_mode == 'flash':
+            # may produce NaN loss when used together with amp
+            # optionally use the context manager to ensure one of the fused kernels is run
+            with torch.backends.cuda.sdp_kernel(enable_math=False):
+                x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C)  # requires PyTorch >= 2.0
+
+ elif self.attention_mode == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+
+ else:
+            raise NotImplementedError(f"attention mode '{self.attention_mode}' is not supported")
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+            # print(drop_ids)
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+        # print('******labels******', labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+#################################################################################
+# Core DiT Model #
+#################################################################################
+
+class DiTBlock(nn.Module):
+ """
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
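+        # adaLN-Zero: the conditioning vector c produces shift/scale/gate for both sub-layers;
+        # the gates are zero-initialized (see initialize_weights), so every block starts as the identity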
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT.
+ """
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class DiT(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_frames=16,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ class_guided=False,
+ use_lora=False,
+ attention_mode='math',
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.class_guided = class_guided
+ self.num_frames = num_frames
+
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ if self.class_guided:
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+
+ num_patches = self.x_embedder.num_patches
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+ self.time_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False)
+
+ if use_lora:
+ self.blocks = nn.ModuleList([
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode, use_lora=False if num % 2 ==0 else True) for num in range(depth)
+ ])
+ else:
+ self.blocks = nn.ModuleList([
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth)
+ ])
+
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ time_embed = get_1d_sincos_time_embed(self.time_embed.shape[-1], self.time_embed.shape[-2])
+ self.time_embed.data.copy_(torch.from_numpy(time_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ if self.class_guided:
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
+ return imgs
+
+ # @torch.cuda.amp.autocast()
+ # @torch.compile
+ def forward(self, x, t, y=None):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+ # print('label: {}'.format(y))
+ batches, frames, channels, high, weight = x.shape # for example, 3, 16, 3, 32, 32
+        # after this rearrange, every consecutive group of `frames` entries along the batch axis belongs to the same video
+ x = rearrange(x, 'b f c h w -> (b f) c h w')
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ t = self.t_embedder(t) # (N, D)
+        # the repeat must assign the same timestep embedding to all frames of a video
+ timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.time_embed.shape[1]) # 48, 1152; c=num_frames
+ timestep_time = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1]) # 768, 1152; c=num tokens
+
+ if self.class_guided:
+ y = self.y_embedder(y, self.training)
+ y_spatial = repeat(y, 'n d -> (n c) d', c=self.time_embed.shape[1]) # 48, 1152; c=num_frames
+ y_time = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1]) # 768, 1152; c=num tokens
+
+ # if self.class_guided:
+ # y = self.y_embedder(y, self.training) # (N, D)
+ # c = timestep_spatial + y
+ # else:
+ # c = timestep_spatial
+
+ # for block in self.blocks:
+ # x = block(x, c) # (N, T, D)
+
+ for i in range(0, len(self.blocks), 2):
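+            # blocks are used in pairs: the even-indexed block attends over spatial tokens within each frame,
+            # the odd-indexed block attends over the frame axis at each spatial location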
+ # print('The {}-th run'.format(i))
+ spatial_block, time_block = self.blocks[i:i+2]
+ # print(spatial_block)
+ # print(time_block)
+ # print(x.shape)
+
+ if self.class_guided:
+ c = timestep_spatial + y_spatial
+ else:
+ c = timestep_spatial
+ x = spatial_block(x, c)
+ # print(c.shape)
+
+            x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)  # t is the number of tokens per frame; 768, 16, 1152
+ # Add Time Embedding
+ if i == 0:
+ x = x + self.time_embed # 768, 16, 1152
+
+ if self.class_guided:
+ c = timestep_time + y_time
+ else:
+ # timestep_time = repeat(t, 'n d -> (n c) d', c=x.shape[0] // batches) # 768, 1152
+ # print(timestep_time.shape)
+ c = timestep_time
+
+ x = time_block(x, c)
+ # print(x.shape)
+ x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)
+
+ # x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)
+ if self.class_guided:
+ c = timestep_spatial + y_spatial
+ else:
+ c = timestep_spatial
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+ x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
+ # print(x.shape)
+ return x
+
+ def forward_motion(self, motions, t, base_frame, y=None):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+ # print('label: {}'.format(y))
+ batches, frames, channels, high, weight = motions.shape # for example, 3, 16, 3, 32, 32
+        # after this rearrange, every consecutive group of `frames` entries along the batch axis belongs to the same video
+ motions = rearrange(motions, 'b f c h w -> (b f) c h w')
+        x = self.x_embedder(motions) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
+ t = self.t_embedder(t) # (N, D)
+        # the repeat must assign the same timestep embedding to all frames of a video
+ timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.time_embed.shape[1]) # 48, 1152; c=num_frames
+ timestep_time = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1]) # 768, 1152; c=num tokens
+
+ if self.class_guided:
+ y = self.y_embedder(y, self.training)
+ y_spatial = repeat(y, 'n d -> (n c) d', c=self.time_embed.shape[1]) # 48, 1152; c=num_frames
+ y_time = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1]) # 768, 1152; c=num tokens
+
+ # if self.class_guided:
+ # y = self.y_embedder(y, self.training) # (N, D)
+ # c = timestep_spatial + y
+ # else:
+ # c = timestep_spatial
+
+ # for block in self.blocks:
+ # x = block(x, c) # (N, T, D)
+
+ for i in range(0, len(self.blocks), 2):
+ # print('The {}-th run'.format(i))
+ spatial_block, time_block = self.blocks[i:i+2]
+ # print(spatial_block)
+ # print(time_block)
+ # print(x.shape)
+
+ if self.class_guided:
+ c = timestep_spatial + y_spatial
+ else:
+ c = timestep_spatial
+ x = spatial_block(x, c)
+ # print(c.shape)
+
+            x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)  # t is the number of tokens per frame; 768, 16, 1152
+ # Add Time Embedding
+ if i == 0:
+ x = x + self.time_embed # 768, 16, 1152
+
+ if self.class_guided:
+ c = timestep_time + y_time
+ else:
+ # timestep_time = repeat(t, 'n d -> (n c) d', c=x.shape[0] // batches) # 768, 1152
+ # print(timestep_time.shape)
+ c = timestep_time
+
+ x = time_block(x, c)
+ # print(x.shape)
+ x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)
+
+ # x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)
+ if self.class_guided:
+ c = timestep_spatial + y_spatial
+ else:
+ c = timestep_spatial
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+ x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
+ # print(x.shape)
+ return x
+
+
+ def forward_with_cfg(self, x, t, y, cfg_scale):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+def get_1d_sincos_time_embed(embed_dim, length):
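+    # sinusoidal embedding of the frame indices 0..length-1, used as the frozen temporal position embedding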
+ pos = torch.arange(0, length).unsqueeze(1)
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+#################################################################################
+# DiT Configs #
+#################################################################################
+
+def DiT_XL_2(**kwargs):
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
+
+def DiT_XL_4(**kwargs):
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
+
+def DiT_XL_8(**kwargs):
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
+
+def DiT_L_2(**kwargs):
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
+
+def DiT_L_4(**kwargs):
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
+
+def DiT_L_8(**kwargs):
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
+
+def DiT_B_2(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
+
+def DiT_B_4(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
+
+def DiT_B_8(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+def DiT_S_2(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+def DiT_S_4(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+def DiT_S_8(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+
+DiT_models = {
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
+}
+
+if __name__ == '__main__':
+
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ img = torch.randn(3, 16, 4, 32, 32).to(device)
+ t = torch.tensor([1, 2, 3]).to(device)
+ y = torch.tensor([1, 2, 3]).to(device)
+ network = DiT_XL_2().to(device)
+ y_embeder = LabelEmbedder(num_classes=100, hidden_size=768, dropout_prob=0.5).to(device)
+ # lora.mark_only_lora_as_trainable(network)
+ out = y_embeder(y, True)
+ # out = network(img, t, y)
+ print(out.shape)
\ No newline at end of file
diff --git a/models/resnet.py b/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e
--- /dev/null
+++ b/models/resnet.py
@@ -0,0 +1,212 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+
+class InflatedConv3d(nn.Conv2d):
+ def forward(self, x):
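+        # run the 2D convolution frame by frame by folding the frame axis into the batch axis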
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
+
+class Upsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
+
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ raise NotImplementedError
+
+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
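+            # scale_factor=[1.0, 2.0, 2.0] doubles only height and width, leaving the frame axis unchanged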
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class Downsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ raise NotImplementedError
+
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ raise NotImplementedError
+
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+class Mish(torch.nn.Module):
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
\ No newline at end of file
diff --git a/models/unet.py b/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..941b5156cce9775606c15388026a11443f1d9dfb
--- /dev/null
+++ b/models/unet.py
@@ -0,0 +1,721 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import math
+import json
+import torch
+import einops
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+
+try:
+ from diffusers.models.modeling_utils import ModelMixin
+except:
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
+
+try:
+ from .unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+ )
+ from .resnet import InflatedConv3d
+except:
+ from unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+ )
+ from resnet import InflatedConv3d
+
+from rotary_embedding_torch import RotaryEmbedding
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+class RelativePositionBias(nn.Module):
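+    # T5-style relative position bias over frame indices; forward(n, device) returns an
+    # additive attention bias of shape (heads, n, n) for the temporal attention layers.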
+ def __init__(
+ self,
+ heads=8,
+ num_buckets=32,
+ max_distance=128,
+ ):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
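+        # T5 bucketing: the sign of the offset selects the first or second half of the
+        # buckets, small offsets get their own exact bucket, and larger offsets are
+        # binned logarithmically up to max_distance.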
+ ret = 0
+ n = -relative_position
+
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, n, device):
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
+ rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+ values = self.relative_attention_bias(rp_bucket)
+ return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
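+    # Video variant of the diffusers UNet2DConditionModel: 2D convolutions are inflated to
+    # run frame by frame (InflatedConv3d) and the down/mid/up blocks carry temporal attention,
+    # optionally with rotary or relative-position embeddings along the frame axis.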
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None, # 64
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+
+ # print(use_first_frame)
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # print(only_cross_attention)
+ # print(type(only_cross_attention))
+ # exit()
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+ # print(only_cross_attention)
+ # exit()
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+ # print(attention_head_dim)
+ # exit()
+
+ rotary_emb = RotaryEmbedding(32)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the videos
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ # relative time positional embeddings
+ self.use_relative_position = use_relative_position
+ if self.use_relative_position:
+ self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_slicable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_slicable_dims(module)
+
+ num_slicable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_slicable_layers * [1]
+
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor = None,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ use_image_num: int = 0,
+ return_dict: bool = True,
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ Args:
+            sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet.UNet3DConditionOutput`] instead of a plain tuple.
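+            use_image_num (`int`, *optional*, defaults to 0):
+                Number of extra single-image frames appended after the video frames; the text
+                conditioning is then expected to carry one prompt per appended image (see the
+                shapes used in the `__main__` example at the bottom of this file).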
+
+ Returns:
+            [`~models.unet.UNet3DConditionOutput`] or `tuple`:
+            [`~models.unet.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+ """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ # print(emb.shape) # torch.Size([3, 1280])
+ # print(class_emb.shape) # torch.Size([3, 1280])
+ emb = emb + class_emb
+
+ if self.use_relative_position:
+ frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device)
+ else:
+ frame_rel_pos_bias = None
+
+ # pre-process
+ sample = self.conv_in(sample)
+
+ # down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ use_image_num=use_image_num,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # mid
+ sample = self.mid_block(
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num,
+ )
+
+ # up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ use_image_num=use_image_num,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+ # print(sample.shape)
+
+ if not return_dict:
+ return (sample,)
+ sample = UNet3DConditionOutput(sample=sample)
+ return sample
+
+ def forward_with_cfg(self,
+ x,
+ t,
+ encoder_hidden_states = None,
+ class_labels: Optional[torch.Tensor] = None,
+ cfg_scale=4.0,
+ use_fp16=False):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ if use_fp16:
+ combined = combined.to(dtype=torch.float16)
+ model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample
+        # Classifier-free guidance is applied to the first four (latent) channels here;
+        # the commented line below is the original DiT variant that guides only three channels.
+        eps, rest = model_out[:, :4], model_out[:, 4:]
+        # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+
+ # the content of the config file
+ # {
+ # "_class_name": "UNet2DConditionModel",
+ # "_diffusers_version": "0.2.2",
+ # "act_fn": "silu",
+ # "attention_head_dim": 8,
+ # "block_out_channels": [
+ # 320,
+ # 640,
+ # 1280,
+ # 1280
+ # ],
+ # "center_input_sample": false,
+ # "cross_attention_dim": 768,
+ # "down_block_types": [
+ # "CrossAttnDownBlock2D",
+ # "CrossAttnDownBlock2D",
+ # "CrossAttnDownBlock2D",
+ # "DownBlock2D"
+ # ],
+ # "downsample_padding": 1,
+ # "flip_sin_to_cos": true,
+ # "freq_shift": 0,
+ # "in_channels": 4,
+ # "layers_per_block": 2,
+ # "mid_block_scale_factor": 1,
+ # "norm_eps": 1e-05,
+ # "norm_num_groups": 32,
+ # "out_channels": 4,
+ # "sample_size": 64,
+ # "up_block_types": [
+ # "UpBlock2D",
+ # "CrossAttnUpBlock2D",
+ # "CrossAttnUpBlock2D",
+ # "CrossAttnUpBlock2D"
+ # ]
+ # }
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ # config["use_first_frame"] = True
+
+ config["use_first_frame"] = False
+ if use_concat:
+ config["in_channels"] = 9
+ # config["use_relative_position"] = True
+
+ # # tmp
+ # config["class_embed_type"] = "timestep"
+ # config["num_class_embeds"] = 100
+
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
+
+ # {'_class_name': 'UNet3DConditionModel',
+ # '_diffusers_version': '0.2.2',
+ # 'act_fn': 'silu',
+ # 'attention_head_dim': 8,
+ # 'block_out_channels': [320, 640, 1280, 1280],
+ # 'center_input_sample': False,
+ # 'cross_attention_dim': 768,
+ # 'down_block_types':
+ # ['CrossAttnDownBlock3D',
+ # 'CrossAttnDownBlock3D',
+ # 'CrossAttnDownBlock3D',
+ # 'DownBlock3D'],
+ # 'downsample_padding': 1,
+ # 'flip_sin_to_cos': True,
+ # 'freq_shift': 0,
+ # 'in_channels': 4,
+ # 'layers_per_block': 2,
+ # 'mid_block_scale_factor': 1,
+ # 'norm_eps': 1e-05,
+ # 'norm_num_groups': 32,
+ # 'out_channels': 4,
+ # 'sample_size': 64,
+ # 'up_block_types':
+ # ['UpBlock3D',
+ # 'CrossAttnUpBlock3D',
+ # 'CrossAttnUpBlock3D',
+ # 'CrossAttnUpBlock3D']}
+
+ model = cls.from_config(config)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ if use_concat:
+ new_state_dict = {}
+ conv_in_weight = state_dict["conv_in.weight"]
+ new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
+
+            # zip stops at the shorter list, so only the 4 original latent channels are
+            # copied; the 5 extra conditioning channels of the 9-channel conv_in stay zero
+            for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
+                new_conv_weight[:, j] = conv_in_weight[:, i]
+ new_state_dict["conv_in.weight"] = new_conv_weight
+ new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
+ for k, v in model.state_dict().items():
+ # print(k)
+ if '_temp.' in k:
+ new_state_dict.update({k: v})
+                if 'attn_fcross' in k: # copy params of attn1 to attn_fcross
+ k = k.replace('attn_fcross', 'attn1')
+ state_dict.update({k: state_dict[k]})
+ if 'norm_fcross' in k:
+ k = k.replace('norm_fcross', 'norm1')
+ state_dict.update({k: state_dict[k]})
+
+ if 'conv_in' in k:
+ continue
+ else:
+ new_state_dict[k] = v
+ # # tmp
+ # if 'class_embedding' in k:
+ # state_dict.update({k: v})
+ # breakpoint()
+ model.load_state_dict(new_state_dict)
+ else:
+ for k, v in model.state_dict().items():
+ # print(k)
+ if '_temp' in k:
+ state_dict.update({k: v})
+                if 'attn_fcross' in k: # copy params of attn1 to attn_fcross
+ k = k.replace('attn_fcross', 'attn1')
+ state_dict.update({k: state_dict[k]})
+ if 'norm_fcross' in k:
+ k = k.replace('norm_fcross', 'norm1')
+ state_dict.update({k: state_dict[k]})
+
+ model.load_state_dict(state_dict)
+
+ return model
+
+if __name__ == '__main__':
+ import torch
+ # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base/" # p cluster
+ pretrained_model_path = "/mnt/petrelfs/share_data/zhanglingjun/stable-diffusion-v1-4/" # p cluster
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
+ # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+ unet.enable_xformers_memory_efficient_attention()
+ unet.enable_gradient_checkpointing()
+
+ unet.train()
+
+ use_image_num = 5
+ noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device)
+ bsz = noisy_latents.shape[0]
+ timesteps = torch.randint(0, 1000, (bsz,)).to(device)
+ timesteps = timesteps.long()
+ encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device)
+ # class_labels = torch.randn((bsz, )).to(device)
+
+
+ model_pred = unet(sample=noisy_latents, timestep=timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ class_labels=None,
+ use_image_num=use_image_num).sample
+ print(model_pred.shape)
\ No newline at end of file
diff --git a/models/unet_blocks.py b/models/unet_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..849c10539c7039840c93631c5201069119d3c306
--- /dev/null
+++ b/models/unet_blocks.py
@@ -0,0 +1,648 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import torch
+from torch import nn
+
+try:
+ from .attention import Transformer3DModel
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
+except:
+ from attention import Transformer3DModel
+ from resnet import Downsample3D, ResnetBlock3D, Upsample3D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False,
+):
+ # print(down_block_type)
+ # print(use_first_frame)
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ # print(use_first_frame)
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
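+                # torch.utils.checkpoint passes only positional tensors to the wrapped
+                # callable, so the wrappers below capture keyword arguments such as
+                # return_dict and use_image_num in a closure before checkpointing.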
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num)
+ else:
+ return module(*inputs, use_image_num=use_image_num)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ use_image_num=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num)
+ else:
+ return module(*inputs, use_image_num=use_image_num)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,215 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+
+import numpy as np
+import torch.nn as nn
+
+from einops import repeat
+
+
+#################################################################################
+# Unet Utils #
+#################################################################################
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
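+
+# Usage sketch (hypothetical block `blk` inside a Module.forward):
+#     h = checkpoint(blk, (h,), blk.parameters(), self.use_checkpoint)
+# which trades memory for compute by recomputing blk's activations during backward.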
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
+ return embedding
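+
+# e.g. timestep_embedding(torch.tensor([0, 250]), 128) yields a (2, 128) tensor of
+# [cos | sin] features with geometrically spaced frequencies from 1 down to ~1/max_period.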
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+# class HybridConditioner(nn.Module):
+
+# def __init__(self, c_concat_config, c_crossattn_config):
+# super().__init__()
+# self.concat_conditioner = instantiate_from_config(c_concat_config)
+# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+# def forward(self, c_concat, c_crossattn):
+# c_concat = self.concat_conditioner(c_concat)
+# c_crossattn = self.crossattn_conditioner(c_crossattn)
+# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += torch.DoubleTensor([matmul_ops])
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
\ No newline at end of file
diff --git a/models/uvit.py b/models/uvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fefd458d5f556ddeb7f3ed74c8ac15870c6c860
--- /dev/null
+++ b/models/uvit.py
@@ -0,0 +1,310 @@
+import torch
+import torch.nn as nn
+import math
+import timm
+from timm.models.layers import trunc_normal_
+from timm.models.vision_transformer import PatchEmbed, Mlp
+# assert timm.__version__ == "0.3.2" # version checks
+import einops
+import torch.utils.checkpoint
+
+# the xformers lib allows less memory, faster training and inference
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True  # flag checked in Attention.forward below
+except ImportError:
+    XFORMERS_IS_AVAILBLE = False
+    # print('xformers disabled')
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
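+# Shape note (illustrative): timestep_embedding(torch.tensor([1, 2, 3]), 1152) returns a
+# [3, 1152] tensor of sinusoidal features (cosine half followed by sine half).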
+
+def patchify(imgs, patch_size):
+ x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
+ return x
+
+
+def unpatchify(x, channels=3):
+ patch_size = int((x.shape[2] // channels) ** 0.5)
+ h = w = int(x.shape[1] ** .5)
+ assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
+ x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
+ return x
+
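+# Note (illustrative): patchify and unpatchify are inverses for square inputs, e.g. a
+# [B, 4, 32, 32] latent with patch_size=2 becomes [B, 256, 16] and is recovered exactly.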
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, L, C = x.shape
+
+ qkv = self.qkv(x)
+ if XFORMERS_IS_AVAILBLE: # the xformers lib allows less memory, faster training and inference
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
+ x = xformers.ops.memory_efficient_attention(q, k, v)
+ x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
+ else:
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
+ self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, x, skip=None):
+ if self.use_checkpoint:
+ return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
+ else:
+ return self._forward(x, skip)
+
+ def _forward(self, x, skip=None):
+ if self.skip_linear is not None:
+ # print('x shape', x.shape)
+ # print('skip shape', skip.shape)
+ # exit()
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class UViT(nn.Module):
+ def __init__(self, input_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
+ qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1,
+ use_checkpoint=False, conv=True, skip=True, num_frames=16, class_guided=False, use_lora=False):
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_classes = num_classes
+ self.in_chans = in_chans
+
+ self.patch_embed = PatchEmbed(
+ img_size=input_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.time_embed = nn.Sequential(
+ nn.Linear(embed_dim, 4 * embed_dim),
+ nn.SiLU(),
+ nn.Linear(4 * embed_dim, embed_dim),
+ ) if mlp_time_embed else nn.Identity()
+
+ if self.num_classes > 0:
+ self.label_emb = nn.Embedding(self.num_classes, embed_dim)
+ self.extras = 2
+ else:
+ self.extras = 1
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
+ self.frame_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
+
+ self.in_blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ norm_layer=norm_layer, use_checkpoint=use_checkpoint)
+ for _ in range(depth // 2)])
+
+ self.mid_block = Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ norm_layer=norm_layer, use_checkpoint=use_checkpoint)
+
+ self.out_blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
+ for _ in range(depth // 2)])
+
+ self.norm = norm_layer(embed_dim)
+ self.patch_dim = patch_size ** 2 * in_chans
+ self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
+ self.final_layer = nn.Conv2d(self.in_chans, self.in_chans * 2, 3, padding=1) if conv else nn.Identity()
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.frame_embed, std=.02)
+ self.apply(self._init_weights)
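+ # Token layout (for reference): [optional class token, time token, patch tokens].
+ # pos_embed spans the extras plus the patch tokens of one frame; frame_embed is added
+ # along the temporal axis inside forward().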
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed'}
+
+ def forward_(self, x, timesteps, y=None):
+ x = self.patch_embed(x) # 48, 256, 1152
+ # print(x.shape)
+ B, L, D = x.shape
+
+ time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) # 3, 1152
+ # print(time_token.shape)
+ time_token = time_token.unsqueeze(dim=1) # 3, 1, 1152
+ x = torch.cat((time_token, x), dim=1)
+
+ if y is not None:
+ label_emb = self.label_emb(y)
+ label_emb = label_emb.unsqueeze(dim=1)
+ x = torch.cat((label_emb, x), dim=1)
+ x = x + self.pos_embed
+
+ skips = []
+ for blk in self.in_blocks:
+ x = blk(x)
+ skips.append(x)
+
+ x = self.mid_block(x)
+
+ for blk in self.out_blocks:
+ x = blk(x, skips.pop())
+
+ x = self.norm(x)
+ x = self.decoder_pred(x)
+ assert x.size(1) == self.extras + L
+ x = x[:, self.extras:, :]
+ x = unpatchify(x, self.in_chans)
+ x = self.final_layer(x)
+ return x
+
+ def forward(self, x, timesteps, y=None):
+ # print(x.shape)
+ batch, frame, _, _, _ = x.shape
+ # after this rearrange, every consecutive group of f entries belongs to the same video
+ x = einops.rearrange(x, 'b f c h w -> (b f) c h w') # 3 16 4 256 256
+ x = self.patch_embed(x) # 48, 256, 1152
+ B, L, D = x.shape
+
+ time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) # 3, 1152
+ # the repeat must ensure that every f frames share the same timestep embedding
+ time_token_spatial = einops.repeat(time_token, 'n d -> (n c) d', c=frame) # 48, 1152
+ time_token_spatial = time_token_spatial.unsqueeze(dim=1) # 48, 1, 1152
+ x = torch.cat((time_token_spatial, x), dim=1) # 48, 257, 1152
+
+ if y is not None:
+ label_emb = self.label_emb(y)
+ label_emb = label_emb.unsqueeze(dim=1)
+ x = torch.cat((label_emb, x), dim=1)
+ x = x + self.pos_embed
+
+ skips = []
+ for i in range(0, len(self.in_blocks), 2):
+ # print('The {}-th run'.format(i))
+ spatial_block, time_block = self.in_blocks[i:i+2]
+ x = spatial_block(x)
+
+ # add time embeddings and conduct attention as frame.
+ x = einops.rearrange(x, '(b f) t d -> (b t) f d', b=batch) # t is the number of tokens per frame; 771, 16, 1152; 771 = 3 * 257
+ skips.append(x)
+ # print(x.shape)
+
+ if i == 0:
+ x = x + self.frame_embed # 771, 16, 1152
+
+ x = time_block(x)
+
+ x = einops.rearrange(x, '(b t) f d -> (b f) t d', b=batch) # 48, 257, 1152
+ skips.append(x)
+
+ x = self.mid_block(x)
+
+ for i in range(0, len(self.out_blocks), 2):
+ # print('The {}-th run'.format(i))
+ spatial_block, time_block = self.out_blocks[i:i+2]
+ x = spatial_block(x, skips.pop())
+
+ # add time embeddings and conduct attention as frame.
+ x = einops.rearrange(x, '(b f) t d -> (b t) f d', b=batch) # t is the number of tokens per frame; 771, 16, 1152; 771 = 3 * 257
+
+ x = time_block(x, skips.pop())
+
+ x = einops.rearrange(x, '(b t) f d -> (b f) t d', b=batch) # 48, 256, 1152
+
+
+ x = self.norm(x)
+ x = self.decoder_pred(x)
+ assert x.size(1) == self.extras + L
+ x = x[:, self.extras:, :]
+ x = unpatchify(x, self.in_chans)
+ x = self.final_layer(x)
+ x = einops.rearrange(x, '(b f) c h w -> b f c h w', b=batch)
+ # print(x.shape)
+ return x
+
+def UViT_XL_2(**kwargs):
+    return UViT(patch_size=2, in_chans=4, embed_dim=1152, depth=28,
+                num_heads=16, mlp_ratio=4, qkv_bias=False, mlp_time_embed=True,
+                use_checkpoint=True, conv=False, **kwargs)
+
+def UViT_L_2(**kwargs):
+ return UViT(patch_size=2, in_chans=4, embed_dim=1024, depth=20,
+ num_heads=16, mlp_ratio=4, qkv_bias=False, mlp_time_embed=False,
+ use_checkpoint=True, **kwargs)
+
+# No variants smaller than L are defined here; in the original UViT, models below L use img_size 64.
+
+UViT_models = {
+ 'UViT-XL/2': UViT_XL_2, 'UViT-L/2': UViT_L_2
+}
+
+
+if __name__ == '__main__':
+    nnet = UViT_XL_2().cuda()
+
+    imgs = torch.randn(3, 16, 4, 32, 32).cuda()
+    timesteps = torch.tensor([1, 2, 3]).cuda()
+
+    outputs = nnet(imgs, timesteps)
+    print(outputs.shape)
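+    # For reference: with UViT_XL_2 (patch_size=2, in_chans=4, conv=False) this should
+    # print torch.Size([3, 16, 4, 32, 32]), i.e. [B, F, C, H, W].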
+
diff --git a/results/i2v/frame, high quality,2.png.mp4 b/results/i2v/frame, high quality,2.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..aa64d00b2c563a1101cb9399c7b61950c7fb4b40
Binary files /dev/null and b/results/i2v/frame, high quality,2.png.mp4 differ
diff --git a/results/i2v/frame, high quality,23.png.mp4 b/results/i2v/frame, high quality,23.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..f84a82bc27ea0d45d7b74f84da7544d442af7753
Binary files /dev/null and b/results/i2v/frame, high quality,23.png.mp4 differ
diff --git a/results/i2v/frame, high quality,35.png.mp4 b/results/i2v/frame, high quality,35.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e25123fdfa1f252424da49556052278797677c92
Binary files /dev/null and b/results/i2v/frame, high quality,35.png.mp4 differ
diff --git "a/results/i2v/frame, high quality,6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/frame, high quality,6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4"
new file mode 100644
index 0000000000000000000000000000000000000000..d32b26d3a4e47bc0ceb79c31e4c4f26688123f84
Binary files /dev/null and "b/results/i2v/frame, high quality,6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ
diff --git a/results/i2v/moving, high quality2.png.mp4 b/results/i2v/moving, high quality2.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3da108d3189b011821afa69a398e412b20e770be
Binary files /dev/null and b/results/i2v/moving, high quality2.png.mp4 differ
diff --git a/results/i2v/moving, high quality23.png.mp4 b/results/i2v/moving, high quality23.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0bf5b5e807d549c367a5c0ffdd90498c756a9c35
Binary files /dev/null and b/results/i2v/moving, high quality23.png.mp4 differ
diff --git a/results/i2v/moving, high quality35.png.mp4 b/results/i2v/moving, high quality35.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e2fa1dce476bc330c159d2546c9b78d0125b238b
Binary files /dev/null and b/results/i2v/moving, high quality35.png.mp4 differ
diff --git "a/results/i2v/moving, high quality6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/moving, high quality6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4"
new file mode 100644
index 0000000000000000000000000000000000000000..98f4fc364cdfefabb8ba0b9f061fec024f83463d
Binary files /dev/null and "b/results/i2v/moving, high quality6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ
diff --git a/results/i2v/tiled up, high quali23.png.mp4 b/results/i2v/tiled up, high quali23.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..27ade24afc71b51a1128438465c913a1684754d0
Binary files /dev/null and b/results/i2v/tiled up, high quali23.png.mp4 differ
diff --git "a/results/i2v/tiled up, high quali6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/tiled up, high quali6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4"
new file mode 100644
index 0000000000000000000000000000000000000000..108582182bbb9e6acfacb9a6978d9e89dda8abca
Binary files /dev/null and "b/results/i2v/tiled up, high quali6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ
diff --git a/results/i2v/tilt up, high qualit2.png.mp4 b/results/i2v/tilt up, high qualit2.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6c4874d02d8a1b5597b8ff120fb69252de448357
Binary files /dev/null and b/results/i2v/tilt up, high qualit2.png.mp4 differ
diff --git a/results/i2v/tilt up, high qualit23.png.mp4 b/results/i2v/tilt up, high qualit23.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3421e17df09ff125821d049532a6da4992d01b4e
Binary files /dev/null and b/results/i2v/tilt up, high qualit23.png.mp4 differ
diff --git a/results/i2v/tilt up, high qualit35.png.mp4 b/results/i2v/tilt up, high qualit35.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b623ffa6779e98f142c44cb0f8fe66f06cd97982
Binary files /dev/null and b/results/i2v/tilt up, high qualit35.png.mp4 differ
diff --git "a/results/i2v/tilt up, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/tilt up, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4"
new file mode 100644
index 0000000000000000000000000000000000000000..7110df9f8ed8f373792ff578053e98bf8be89121
Binary files /dev/null and "b/results/i2v/tilt up, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ
diff --git a/results/i2v/tilt up, high quality, stable , slow motion..mp4 b/results/i2v/tilt up, high quality, stable , slow motion..mp4
new file mode 100644
index 0000000000000000000000000000000000000000..bb8029275cc447067ba8d126592b463699d088ec
Binary files /dev/null and b/results/i2v/tilt up, high quality, stable , slow motion..mp4 differ
diff --git a/results/i2v/zoom in, high qualit2.png.mp4 b/results/i2v/zoom in, high qualit2.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a2102ade2bb7d4d6dafe46378f4768ec04d42bc6
Binary files /dev/null and b/results/i2v/zoom in, high qualit2.png.mp4 differ
diff --git a/results/i2v/zoom in, high qualit23.png.mp4 b/results/i2v/zoom in, high qualit23.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..397cbeaded68fc0aeeb3ef7b8ba8b7cf3d56af13
Binary files /dev/null and b/results/i2v/zoom in, high qualit23.png.mp4 differ
diff --git a/results/i2v/zoom in, high qualit35.png.mp4 b/results/i2v/zoom in, high qualit35.png.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4a968875013b81899b1b4a83d1cce0871505bb86
Binary files /dev/null and b/results/i2v/zoom in, high qualit35.png.mp4 differ
diff --git "a/results/i2v/zoom in, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" "b/results/i2v/zoom in, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4"
new file mode 100644
index 0000000000000000000000000000000000000000..7d6e4e5fabfdebc4aa1b996cb9b197c4ab8a3630
Binary files /dev/null and "b/results/i2v/zoom in, high qualit6.1.\346\234\272\345\231\250\344\272\272\350\276\205\345\212\251\344\272\272\347\261\273\344\277\256\350\267\257\344\270\255\346\231\257.png.mp4" differ
diff --git a/results/transition/smooth_transition.mp4 b/results/transition/smooth_transition.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5232357d98a9183c2e681958bf6d45301b2c90da
Binary files /dev/null and b/results/transition/smooth_transition.mp4 differ
diff --git a/sample_scripts/with_mask_sample.py b/sample_scripts/with_mask_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd50246da1c74242834534330e1324834ec0a0c7
--- /dev/null
+++ b/sample_scripts/with_mask_sample.py
@@ -0,0 +1,281 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Sample new images from a pre-trained DiT.
+"""
+import os
+import sys
+import math
+try:
+    import utils
+
+    from diffusion import create_diffusion
+    from download import find_model
+except ImportError:
+    # fall back to the repository root when the script is run from sample_scripts/
+    sys.path.append(os.path.split(sys.path[0])[0])
+
+    import utils
+
+    from diffusion import create_diffusion
+    from download import find_model
+
+import torch
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+import argparse
+import torchvision
+
+from einops import rearrange
+from models import get_models
+from torchvision.utils import save_image
+from diffusers.models import AutoencoderKL
+from models.clip import TextEmbedder
+from omegaconf import OmegaConf
+from PIL import Image
+import numpy as np
+from torchvision import transforms
+sys.path.append("..")
+from datasets import video_transforms
+from utils import mask_generation_before
+from natsort import natsorted
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# def get_input(args):
+def get_input(path,args):
+ input_path = path
+ # input_path = args.input_path
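+ # Returns (video_frames, n): video_frames is a normalized [F, C, H, W] tensor in which
+ # frames that are not provided are zero-filled placeholders, and n is the number of
+ # reserved frames (always 0 in the current implementation).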
+ transform_video = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ video_transforms.ResizeVideo((args.image_h, args.image_w)),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval)
+ if input_path is not None:
+ print(f'loading video from {input_path}')
+ if os.path.isdir(input_path):
+ file_list = os.listdir(input_path)
+ video_frames = []
+ if args.mask_type.startswith('onelast'):
+ num = int(args.mask_type.split('onelast')[-1])
+ # get first and last frame
+ first_frame_path = os.path.join(input_path, natsorted(file_list)[0])
+ last_frame_path = os.path.join(input_path, natsorted(file_list)[-1])
+ first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ for i in range(num):
+ video_frames.append(first_frame)
+ # add zeros to frames
+ num_zeros = args.num_frames-2*num
+ for i in range(num_zeros):
+ zeros = torch.zeros_like(first_frame)
+ video_frames.append(zeros)
+ for i in range(num):
+ video_frames.append(last_frame)
+ n = 0
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
+ video_frames = transform_video(video_frames)
+ else:
+ for file in file_list:
+ if file.endswith('jpg') or file.endswith('png'):
+ image = torch.as_tensor(np.array(Image.open(os.path.join(input_path,file)), dtype=np.uint8, copy=True)).unsqueeze(0)
+ video_frames.append(image)
+ else:
+ continue
+ n = 0
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
+ video_frames = transform_video(video_frames)
+ return video_frames, n
+ elif os.path.isfile(input_path):
+ _, full_file_name = os.path.split(input_path)
+ file_name, extention = os.path.splitext(full_file_name)
+ if extention == '.jpg' or extention == '.png':
+ # raise TypeError('a single image is not supported yet!!')
+ print("reading video from a image")
+ video_frames = []
+ num = int(args.mask_type.split('first')[-1])
+ first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0)
+ for i in range(num):
+ video_frames.append(first_frame)
+ num_zeros = args.num_frames-num
+ for i in range(num_zeros):
+ zeros = torch.zeros_like(first_frame)
+ video_frames.append(zeros)
+ n = 0
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
+ video_frames = transform_video(video_frames)
+ return video_frames, n
+ else:
+ raise TypeError(f'{extention} is not supported !!')
+ else:
+ raise ValueError('Please check your path input!!')
+ else:
+ # raise ValueError('Need to give a video or some images')
+ print('no input video given, falling back to text-to-video')
+ video_frames = torch.zeros(args.num_frames, 3, args.image_h, args.image_w, dtype=torch.uint8)
+ args.mask_type = 'all'
+ video_frames = transform_video(video_frames)
+ n = 0
+ return video_frames, n
+
+def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,):
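+    # Rough flow (for reference): encode the masked conditioning frames with the VAE,
+    # optionally duplicate the inputs for classifier-free guidance, run DDIM/DDPM sampling
+    # conditioned on (mask, masked latents), then decode the sampled latents back to RGB.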
+ b,f,c,h,w=video_input.shape
+ latent_h = args.image_size[0] // 8
+ latent_w = args.image_size[1] // 8
+
+ # prepare inputs
+ if args.use_fp16:
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
+ masked_video = masked_video.to(dtype=torch.float16)
+ mask = mask.to(dtype=torch.float16)
+ else:
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
+
+
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
+
+ # classifier_free_guidance
+ if args.do_classifier_free_guidance:
+ masked_video = torch.cat([masked_video] * 2)
+ mask = torch.cat([mask] * 2)
+ z = torch.cat([z] * 2)
+ prompt_all = [prompt] + [args.negative_prompt]
+
+ else:
+ masked_video = masked_video
+ mask = mask
+ z = z
+ prompt_all = [prompt]
+
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
+ model_kwargs = dict(encoder_hidden_states=text_prompt,
+ class_labels=None,
+ cfg_scale=args.cfg_scale,
+ use_fp16=args.use_fp16,) # tav unet
+
+ # Sample images:
+ if args.sample_method == 'ddim':
+ samples = diffusion.ddim_sample_loop(
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
+ mask=mask, x_start=masked_video, use_concat=args.use_mask
+ )
+ elif args.sample_method == 'ddpm':
+ samples = diffusion.p_sample_loop(
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
+ mask=mask, x_start=masked_video, use_concat=args.use_mask
+ )
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
+ if args.use_fp16:
+ samples = samples.to(dtype=torch.float16)
+
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
+ return video_clip
+
+def main(args):
+ # Setup PyTorch:
+ if args.seed:
+ torch.manual_seed(args.seed)
+ torch.set_grad_enabled(False)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ # device = "cpu"
+
+ if args.ckpt is None:
+ raise ValueError("Please specify a checkpoint path using --ckpt ")
+
+ # Load model:
+ latent_h = args.image_size[0] // 8
+ latent_w = args.image_size[1] // 8
+ args.image_h = args.image_size[0]
+ args.image_w = args.image_size[1]
+ args.latent_h = latent_h
+ args.latent_w = latent_w
+ print('loading model')
+ model = get_models(args).to(device)
+
+ if args.use_compile:
+ model = torch.compile(model)
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ model.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # load model
+ ckpt_path = args.ckpt
+ state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
+ model.load_state_dict(state_dict)
+ print('loading succeeded')
+
+ model.eval() # important!
+ pretrained_model_path = args.pretrained_model_path
+ diffusion = create_diffusion(str(args.num_sampling_steps))
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+ text_encoder = TextEmbedder(pretrained_model_path).to(device)
+ if args.use_fp16:
+ print('Warning: using half precision for inference!')
+ vae.to(dtype=torch.float16)
+ model.to(dtype=torch.float16)
+ text_encoder.to(dtype=torch.float16)
+
+ # Labels to condition the model with (feel free to change):
+ prompt = args.text_prompt
+ if prompt == []:
+ prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ')
+ else:
+ prompt = prompt[0]
+ prompt_base = prompt.replace(' ','_')
+ prompt = prompt + args.additional_prompt
+
+
+
+ if not os.path.exists(os.path.join(args.save_img_path)):
+ os.makedirs(os.path.join(args.save_img_path))
+ for file in os.listdir(args.img_path):
+ video_input, reserve_frames = get_input(os.path.join(args.img_path,file),args)
+ video_input = video_input.to(device).unsqueeze(0)
+ mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
+ masked_video = video_input * (mask == 0)
+ prompt = "tilt up, high quality, stable "
+ prompt = prompt + args.additional_prompt
+ video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
+ video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ torchvision.io.write_video(os.path.join(args.save_img_path, prompt[0:20]+file+ '.mp4'), video_, fps=8)
+ # video_input, researve_frames = get_input(args) # f,c,h,w
+ # video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w
+ # mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w
+ # # TODO: change the first3 to last3
+ # masked_video = video_input * (mask == 0)
+
+ # video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
+ # video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ # torchvision.io.write_video(os.path.join(args.save_img_path, prompt_base+ '.mp4'), video_, fps=8)
+ print(f'save in {args.save_img_path}')
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./configs/sample_mask.yaml")
+ parser.add_argument("--run-time", type=int, default=0)
+ args = parser.parse_args()
+ omega_conf = OmegaConf.load(args.config)
+ omega_conf.run_time = args.run_time
+ main(omega_conf)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c116b19946e3e53dbcd66e49b4d4aba98920187a
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,374 @@
+import os
+import math
+import torch
+import logging
+import subprocess
+import numpy as np
+import torch.distributed as dist
+
+# from torch._six import inf
+from torch import inf
+from PIL import Image
+from typing import Union, Iterable
+from collections import OrderedDict
+from torch.utils.tensorboard import SummaryWriter
+_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
+
+#################################################################################
+# Training Helper Functions #
+#################################################################################
+def fetch_files_by_numbers(start_number, count, file_list):
+ file_numbers = range(start_number, start_number + count)
+ found_files = []
+ for file_number in file_numbers:
+ file_number_padded = str(file_number).zfill(2)
+ for file_name in file_list:
+ if file_name.endswith(file_number_padded + '.csv'):
+ found_files.append(file_name)
+ break # Stop searching once a file is found for the current number
+ return found_files
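+ # Illustrative example: fetch_files_by_numbers(3, 2, ['a_03.csv', 'b_04.csv', 'c.txt'])
+ # returns ['a_03.csv', 'b_04.csv']; file numbers are zero-padded to two digits.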
+
+#################################################################################
+# Training Clip Gradients #
+#################################################################################
+
+def get_grad_norm(
+        parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor:
+    r"""
+    Adapted from torch.nn.utils.clip_grad_norm_.
+
+    Computes the total gradient norm of an iterable of parameters without clipping.
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor whose gradients will be measured
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grads = [p.grad for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(grads) == 0:
+ return torch.tensor(0.)
+ device = grads[0].device
+ if norm_type == inf:
+ norms = [g.detach().abs().max().to(device) for g in grads]
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+ return total_norm
+
+def clip_grad_norm_(
+ parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
+ error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor:
+ r"""
+ Copy from torch.nn.utils.clip_grad_norm_
+
+ Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were
+ concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+ error_if_nonfinite (bool): if True, an error is thrown if the total
+ norm of the gradients from :attr:`parameters` is ``nan``,
+ ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grads = [p.grad for p in parameters if p.grad is not None]
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ if len(grads) == 0:
+ return torch.tensor(0.)
+ device = grads[0].device
+ if norm_type == inf:
+ norms = [g.detach().abs().max().to(device) for g in grads]
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+ # print(total_norm)
+
+ if clip_grad:
+ if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+ raise RuntimeError(
+ f'The total norm of order {norm_type} for gradients from '
+ '`parameters` is non-finite, so it cannot be clipped. To disable '
+ 'this error and scale the gradients by the non-finite norm anyway, '
+ 'set `error_if_nonfinite=False`')
+ clip_coef = max_norm / (total_norm + 1e-6)
+ # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+ # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+ # when the gradients do not reside in CPU memory.
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+ for g in grads:
+ g.detach().mul_(clip_coef_clamped.to(g.device))
+ # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+ # print(gradient_cliped)
+ return total_norm
+
+def separation_content_motion(video_clip):
+    """
+    Separate content and motion in a given video.
+
+    Args:
+        video_clip: a video clip of shape [B, F, C, H, W]
+
+    Return:
+        base_frame: [B, 1, C, H, W], the first frame of each clip
+        motions: [B, F-1, C, H, W], per-frame differences w.r.t. the base frame
+    """
+    total_frames = video_clip.shape[1]
+    base_frame = video_clip[:, :1]
+    motions = [video_clip[:, i:i + 1] - base_frame for i in range(1, total_frames)]
+    motions = torch.cat(motions, dim=1)
+    return base_frame, motions
+
+def get_experiment_dir(root_dir, args):
+ if args.use_compile:
+ root_dir += '-Compile' # speedup by torch compile
+ if args.fixed_spatial:
+ root_dir += '-FixedSpa'
+ if args.enable_xformers_memory_efficient_attention:
+ root_dir += '-Xfor'
+ if args.gradient_checkpointing:
+ root_dir += '-Gc'
+ if args.mixed_precision:
+ root_dir += '-Amp'
+ if args.image_size == 512:
+ root_dir += '-512'
+ return root_dir
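+ # Illustrative example: with use_compile and mixed_precision set and the other flags
+ # off, 'results/exp' becomes 'results/exp-Compile-Amp'.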
+
+#################################################################################
+# Training Logger #
+#################################################################################
+
+def create_logger(logging_dir):
+ """
+ Create a logger that writes to a log file and stdout.
+ """
+ if dist.get_rank() == 0: # real logger
+ logging.basicConfig(
+ level=logging.INFO,
+ # format='[\033[34m%(asctime)s\033[0m] %(message)s',
+ format='[%(asctime)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
+ )
+ logger = logging.getLogger(__name__)
+
+ else: # dummy logger (does nothing)
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
+ return logger
+
+def create_accelerate_logger(logging_dir, is_main_process=False):
+ """
+ Create a logger that writes to a log file and stdout.
+ """
+ if is_main_process: # real logger
+ logging.basicConfig(
+ level=logging.INFO,
+ # format='[\033[34m%(asctime)s\033[0m] %(message)s',
+ format='[%(asctime)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
+ )
+ logger = logging.getLogger(__name__)
+ else: # dummy logger (does nothing)
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
+ return logger
+
+
+def create_tensorboard(tensorboard_dir):
+ """
+ Create a tensorboard that saves losses.
+ """
+ if dist.get_rank() == 0: # real tensorboard
+ # tensorboard
+ writer = SummaryWriter(tensorboard_dir)
+
+ return writer
+
+def write_tensorboard(writer, *args):
+ '''
+ write the loss information to a tensorboard file.
+ Only for pytorch DDP mode.
+ '''
+ if dist.get_rank() == 0: # real tensorboard
+ writer.add_scalar(args[0], args[1], args[2])
+
+#################################################################################
+# EMA Update/ DDP Training Utils #
+#################################################################################
+
+@torch.no_grad()
+def update_ema(ema_model, model, decay=0.9999):
+ """
+ Step the EMA model towards the current model.
+ """
+ ema_params = OrderedDict(ema_model.named_parameters())
+ model_params = OrderedDict(model.named_parameters())
+
+ for name, param in model_params.items():
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+ if param.requires_grad:
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
+
+def requires_grad(model, flag=True):
+ """
+ Set requires_grad flag for all parameters in a model.
+ """
+ for p in model.parameters():
+ p.requires_grad = flag
+
+def cleanup():
+ """
+ End DDP training.
+ """
+ dist.destroy_process_group()
+
+
+def setup_distributed(backend="nccl", port=None):
+    """Initialize the distributed training environment.
+
+    Supports both SLURM and torch.distributed.launch;
+    see torch.distributed.init_process_group() for more details.
+    """
+ num_gpus = torch.cuda.device_count()
+
+ if "SLURM_JOB_ID" in os.environ:
+ rank = int(os.environ["SLURM_PROCID"])
+ world_size = int(os.environ["SLURM_NTASKS"])
+ node_list = os.environ["SLURM_NODELIST"]
+ addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
+ # specify master port
+ if port is not None:
+ os.environ["MASTER_PORT"] = str(port)
+ elif "MASTER_PORT" not in os.environ:
+ # os.environ["MASTER_PORT"] = "29566"
+ os.environ["MASTER_PORT"] = str(29566 + num_gpus)
+ if "MASTER_ADDR" not in os.environ:
+ os.environ["MASTER_ADDR"] = addr
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["LOCAL_RANK"] = str(rank % num_gpus)
+ os.environ["RANK"] = str(rank)
+ else:
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+
+ # torch.cuda.set_device(rank % num_gpus)
+
+ dist.init_process_group(
+ backend=backend,
+ world_size=world_size,
+ rank=rank,
+ )
+
+#################################################################################
+# Testing Utils #
+#################################################################################
+
+def save_video_grid(video, nrow=None):
+ b, t, h, w, c = video.shape
+
+ if nrow is None:
+ nrow = math.ceil(math.sqrt(b))
+ ncol = math.ceil(b / nrow)
+ padding = 1
+ video_grid = torch.zeros((t, (padding + h) * nrow + padding,
+ (padding + w) * ncol + padding, c), dtype=torch.uint8)
+
+ print(video_grid.shape)
+ for i in range(b):
+ r = i // ncol
+ c = i % ncol
+ start_r = (padding + h) * r
+ start_c = (padding + w) * c
+ video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
+
+ return video_grid
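+ # The result is a [T, H_grid, W_grid, C] uint8 tensor that tiles the batch into an
+ # nrow x ncol grid with 1-pixel padding between cells.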
+
+def save_videos_grid_tav(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
+ from einops import rearrange
+ import imageio
+ import torchvision
+
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(x)
+
+ # os.makedirs(os.path.dirname(path), exist_ok=True)
+ imageio.mimsave(path, outputs, fps=fps)
+
+
+#################################################################################
+# MMCV Utils #
+#################################################################################
+
+
+def collect_env():
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from mmcv.utils import collect_env as collect_base_env
+ from mmcv.utils import get_git_hash
+ """Collect the information of the running environments."""
+
+ env_info = collect_base_env()
+ env_info['MMClassification'] = get_git_hash()[:7]
+
+ for name, val in env_info.items():
+ print(f'{name}: {val}')
+
+ print(torch.cuda.get_arch_list())
+ print(torch.version.cuda)
+
+
+#################################################################################
+# Long video generation Utils #
+#################################################################################
+
+def mask_generation_before(mask_type, shape, dtype, device, dropout_prob=0.0, use_image_num=0):
+ b, f, c, h, w = shape
+ if mask_type.startswith('first'):
+ num = int(mask_type.split('first')[-1])
+ mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device),
+ torch.ones(1, f-num, 1, 1, 1, dtype=dtype, device=device)], dim=1)
+ mask = mask_f.expand(b, -1, c, h, w)
+ elif mask_type.startswith('all'):
+ mask = torch.ones(b,f,c,h,w,dtype=dtype,device=device)
+ elif mask_type.startswith('onelast'):
+ num = int(mask_type.split('onelast')[-1])
+ mask_one = torch.zeros(1,1,1,1,1, dtype=dtype, device=device)
+ mask_mid = torch.ones(1,f-2*num,1,1,1,dtype=dtype, device=device)
+ mask_last = torch.zeros_like(mask_one)
+ mask = torch.cat([mask_one]*num + [mask_mid] + [mask_last]*num, dim=1)
+ mask = mask.expand(b, -1, c, h, w)
+ else:
+ raise ValueError(f"Invalid mask type: {mask_type}")
+ return mask
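+
+# Usage note (for reference): a mask value of 1 marks frames to be generated and 0 marks
+# conditioning frames; callers keep the conditioning content via
+# masked_video = video_input * (mask == 0).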