camenduru commited on
Commit
2f760fe
1 Parent(s): 001612d

thanks to Vchitect ❤

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +201 -0
  2. README.md +98 -0
  3. base/__pycache__/download.cpython-311.pyc +0 -0
  4. base/app.py +181 -0
  5. base/app.sh +1 -0
  6. base/configs/sample.yaml +28 -0
  7. base/download.py +18 -0
  8. base/gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4 +0 -0
  9. base/gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4 +0 -0
  10. base/gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4 +0 -0
  11. base/gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4 +0 -0
  12. base/gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4 +0 -0
  13. base/gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4 +0 -0
  14. base/gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4 +0 -0
  15. base/gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4 +0 -0
  16. base/gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4 +0 -0
  17. base/gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4 +0 -0
  18. base/gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4 +0 -0
  19. base/gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4 +0 -0
  20. base/gradio_cached_examples/14/log.csv +13 -0
  21. base/huggingface-t2v/.DS_Store +0 -0
  22. base/huggingface-t2v/__init__.py +0 -0
  23. base/huggingface-t2v/requirements.txt +0 -0
  24. base/models/__init__.py +33 -0
  25. base/models/__pycache__/__init__.cpython-311.pyc +0 -0
  26. base/models/__pycache__/attention.cpython-311.pyc +0 -0
  27. base/models/__pycache__/resnet.cpython-311.pyc +0 -0
  28. base/models/__pycache__/unet.cpython-311.pyc +0 -0
  29. base/models/__pycache__/unet_blocks.cpython-311.pyc +0 -0
  30. base/models/attention.py +707 -0
  31. base/models/clip.py +120 -0
  32. base/models/resnet.py +212 -0
  33. base/models/temporal_attention.py +388 -0
  34. base/models/transformer_3d.py +367 -0
  35. base/models/unet.py +617 -0
  36. base/models/unet_blocks.py +648 -0
  37. base/models/utils.py +215 -0
  38. base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc +0 -0
  39. base/pipelines/pipeline_videogen.py +677 -0
  40. base/pipelines/sample.py +88 -0
  41. base/pipelines/sample.sh +2 -0
  42. base/text_to_video/__init__.py +45 -0
  43. base/text_to_video/__pycache__/__init__.cpython-311.pyc +0 -0
  44. base/try.py +5 -0
  45. environment.yml +27 -0
  46. interpolation/configs/sample.yaml +36 -0
  47. interpolation/datasets/__init__.py +1 -0
  48. interpolation/datasets/video_transforms.py +109 -0
  49. interpolation/diffusion/__init__.py +47 -0
  50. interpolation/diffusion/diffusion_utils.py +88 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ---
3
+ title: LaVie
4
+ emoji: 😊
5
+ colorFrom: pink
6
+ colorTo: pink
7
+ sdk: gradio
8
+ sdk_version: 4.3.0
9
+ app_file: base/app.py
10
+ pinned: false
11
+ ---
12
+
13
+ # LaVie: High-Quality Video Generation with Cascaded Latent Diffusion Models
14
+
15
+ This repository is the official PyTorch implementation of [LaVie](https://arxiv.org/abs/2309.15103).
16
+
17
+ **LaVie** is a Text-to-Video (T2V) generation framework, and main part of video generation system [Vchitect](http://vchitect.intern-ai.org.cn/).
18
+
19
+ [![arXiv](https://img.shields.io/badge/arXiv-2307.04725-b31b1b.svg)](https://arxiv.org/abs/2309.15103)
20
+ [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://vchitect.github.io/LaVie-project/)
21
+ <!--
22
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)]()
23
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)]()
24
+ -->
25
+
26
+ <img src="lavie.gif" width="800">
27
+
28
+ ## Installation
29
+ ```
30
+ conda env create -f environment.yml
31
+ conda activate lavie
32
+ ```
33
+
34
+ ## Download Pre-Trained models
35
+ Download [pre-trained models](https://huggingface.co/YaohuiW/LaVie/tree/main), [stable diffusion 1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main), [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/tree/main) to `./pretrained_models`. You should be able to see the following:
36
+ ```
37
+ ├── pretrained_models
38
+ │ ├── lavie_base.pt
39
+ │ ├── lavie_interpolation.pt
40
+ │ ├── lavie_vsr.pt
41
+ │ ├── stable-diffusion-v1-4
42
+ │ │ ├── ...
43
+ └── └── stable-diffusion-x4-upscaler
44
+ ├── ...
45
+ ```
46
+
47
+ ## Inference
48
+ The inference contains **Base T2V**, **Video Interpolation** and **Video Super-Resolution** three steps. We provide several options to generate videos:
49
+ * **Step1**: 320 x 512 resolution, 16 frames
50
+ * **Step1+Step2**: 320 x 512 resolution, 61 frames
51
+ * **Step1+Step3**: 1280 x 2048 resolution, 16 frames
52
+ * **Step1+Step2+Step3**: 1280 x 2048 resolution, 61 frames
53
+
54
+ Feel free to try different options:)
55
+
56
+
57
+ ### Step1. Base T2V
58
+ Run following command to generate videos from base T2V model.
59
+ ```
60
+ cd base
61
+ python pipelines/sample.py --config configs/sample.yaml
62
+ ```
63
+ Edit `text_prompt` in `configs/sample.yaml` to change prompt, results will be saved under `./res/base`.
64
+
65
+ ### Step2 (optional). Video Interpolation
66
+ Run following command to conduct video interpolation.
67
+ ```
68
+ cd interpolation
69
+ python sample.py --config configs/sample.yaml
70
+ ```
71
+ The default input video path is `./res/base`, results will be saved under `./res/interpolation`. In `configs/sample.yaml`, you could modify default `input_folder` with `YOUR_INPUT_FOLDER` in `configs/sample.yaml`. Input videos should be named as `prompt1.mp4`, `prompt2.mp4`, ... and put under `YOUR_INPUT_FOLDER`. Launching the code will process all the input videos in `input_folder`.
72
+
73
+
74
+ ### Step3 (optional). Video Super-Resolution
75
+ Run following command to conduct video super-resolution.
76
+ ```
77
+ cd vsr
78
+ python sample.py --config configs/sample.yaml
79
+ ```
80
+ The default input video path is `./res/base` and results will be saved under `./res/vsr`. You could modify default `input_path` with `YOUR_INPUT_FOLDER` in `configs/sample.yaml`. Smiliar to Step2, input videos should be named as `prompt1.mp4`, `prompt2.mp4`, ... and put under `YOUR_INPUT_FOLDER`. Launching the code will process all the input videos in `input_folder`.
81
+
82
+
83
+ ## BibTex
84
+ ```bibtex
85
+ @article{wang2023lavie,
86
+ title={LAVIE: High-Quality Video Generation with Cascaded Latent Diffusion Models},
87
+ author={Wang, Yaohui and Chen, Xinyuan and Ma, Xin and Zhou, Shangchen and Huang, Ziqi and Wang, Yi and Yang, Ceyuan and He, Yinan and Yu, Jiashuo and Yang, Peiqing and others},
88
+ journal={arXiv preprint arXiv:2309.15103},
89
+ year={2023}
90
+ }
91
+ ```
92
+
93
+ ## Acknowledgements
94
+ The code is buit upon [diffusers](https://github.com/huggingface/diffusers) and [Stable Diffusion](https://github.com/CompVis/stable-diffusion), we thank all the contributors for open-sourcing.
95
+
96
+
97
+ ## License
98
+ The code is licensed under Apache-2.0, model weights are fully open for academic research and also allow **free** commercial usage. To apply for a commercial license, please fill in the [application form]().
base/__pycache__/download.cpython-311.pyc ADDED
Binary file (815 Bytes). View file
 
base/app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from text_to_video import model_t2v_fun,setup_seed
3
+ from omegaconf import OmegaConf
4
+ import torch
5
+ import imageio
6
+ import os
7
+ import cv2
8
+ import pandas as pd
9
+ import torchvision
10
+ import random
11
+ from huggingface_hub import snapshot_download
12
+
13
+ config_path = "./base/configs/sample.yaml"
14
+ args = OmegaConf.load("./base/configs/sample.yaml")
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ ### download models
17
+ # snapshot_download('Vchitect/LaVie',cache_dir='./pretrained_models')
18
+ # snapshot_download('CompVis/stable-diffusion-v1-4',cache_dir='./pretrained_models')
19
+
20
+ # ------- get model ---------------
21
+ model_t2V = model_t2v_fun(args)
22
+ model_t2V.to(device)
23
+ if device == "cuda":
24
+ model_t2V.enable_xformers_memory_efficient_attention()
25
+
26
+ # model_t2V.enable_xformers_memory_efficient_attention()
27
+ css = """
28
+ h1 {
29
+ text-align: center;
30
+ }
31
+ #component-0 {
32
+ max-width: 730px;
33
+ margin: auto;
34
+ }
35
+ """
36
+
37
+ def infer(prompt, seed_inp, ddim_steps,cfg):
38
+ if seed_inp!=-1:
39
+ setup_seed(seed_inp)
40
+ else:
41
+ seed_inp = random.choice(range(10000000))
42
+ setup_seed(seed_inp)
43
+ videos = model_t2V(prompt, video_length=16, height = 320, width= 512, num_inference_steps=ddim_steps, guidance_scale=cfg).video
44
+ print(videos[0].shape)
45
+ if not os.path.exists(args.output_folder):
46
+ os.mkdir(args.output_folder)
47
+ torchvision.io.write_video(args.output_folder + prompt[0:30].replace(' ', '_') + '-'+str(seed_inp)+'-'+str(ddim_steps)+'-'+str(cfg)+ '-.mp4', videos[0], fps=8)
48
+ # imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8)
49
+ # video = cv2.VideoCapture(args.output_folder + prompt.replace(' ', '_') + '.mp4')
50
+ # video = imageio.get_reader(args.output_folder + prompt.replace(' ', '_') + '.mp4', 'ffmpeg')
51
+
52
+
53
+ # video = model_t2V(prompt, seed_inp, ddim_steps)
54
+
55
+ return args.output_folder + prompt[0:30].replace(' ', '_') + '-'+str(seed_inp)+'-'+str(ddim_steps)+'-'+str(cfg)+ '-.mp4'
56
+
57
+ print(1)
58
+
59
+ # def clean():
60
+ # return gr.Image.update(value=None, visible=False), gr.Video.update(value=None)
61
+ def clean():
62
+ return gr.Video.update(value=None)
63
+
64
+ title = """
65
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
66
+ <div
67
+ style="
68
+ display: inline-flex;
69
+ align-items: center;
70
+ gap: 0.8rem;
71
+ font-size: 1.75rem;
72
+ "
73
+ >
74
+ <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
75
+ Intern·Vchitect (Text-to-Video)
76
+ </h1>
77
+ </div>
78
+ <p style="margin-bottom: 10px; font-size: 94%">
79
+ Apply Intern·Vchitect to generate a video
80
+ </p>
81
+ </div>
82
+ """
83
+
84
+ # print(1)
85
+ with gr.Blocks(css='style.css') as demo:
86
+ gr.Markdown("<font color=red size=10><center>LaVie: Text-to-Video generation</center></font>")
87
+ with gr.Column():
88
+ with gr.Row(elem_id="col-container"):
89
+ # inputs = [prompt, seed_inp, ddim_steps]
90
+ # outputs = [video_out]
91
+ with gr.Column():
92
+
93
+ prompt = gr.Textbox(value="a corgi walking in the park at sunrise, oil painting style", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2)
94
+
95
+ ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1)
96
+ seed_inp = gr.Slider(value=-1,label="seed (for random generation, use -1)",show_label=True,minimum=-1,maximum=2147483647)
97
+ cfg = gr.Number(label="guidance_scale",value=7.5)
98
+ # seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=400, elem_id="seed-in")
99
+
100
+ # with gr.Row():
101
+ # # control_task = gr.Dropdown(label="Task", choices=["Text-2-video", "Image-2-video"], value="Text-2-video", multiselect=False, elem_id="controltask-in")
102
+ # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
103
+ # seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=123456, elem_id="seed-in")
104
+
105
+ # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
106
+ # ex = gr.Examples(
107
+ # examples = [['a corgi walking in the park at sunrise, oil painting style',400,50,7],
108
+ # ['a cut teddy bear reading a book in the park, oil painting style, high quality',700,50,7],
109
+ # ['an epic tornado attacking above a glowing city at night, the tornado is made of smoke, highly detailed',230,50,7],
110
+ # ['a jar filled with fire, 4K video, 3D rendered, well-rendered',400,50,7],
111
+ # ['a teddy bear walking in the park, oil painting style, high quality',400,50,7],
112
+ # ['a teddy bear walking on the street, 2k, high quality',100,50,7],
113
+ # ['a panda taking a selfie, 2k, high quality',400,50,7],
114
+ # ['a polar bear playing drum kit in NYC Times Square, 4k, high resolution',400,50,7],
115
+ # ['jungle river at sunset, ultra quality',400,50,7],
116
+ # ['a shark swimming in clear Carribean ocean, 2k, high quality',400,50,7],
117
+ # ['A steam train moving on a mountainside by Vincent van Gogh',230,50,7],
118
+ # ['a confused grizzly bear in calculus class',1000,50,7]],
119
+ # fn = infer,
120
+ # inputs=[prompt, seed_inp, ddim_steps,cfg],
121
+ # # outputs=[video_out],
122
+ # cache_examples=False,
123
+ # examples_per_page = 6
124
+ # )
125
+ # ex.dataset.headers = [""]
126
+
127
+ with gr.Column():
128
+ submit_btn = gr.Button("Generate video")
129
+ clean_btn = gr.Button("Clean video")
130
+ # submit_btn = gr.Button("Generate video", size='sm')
131
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
132
+ video_out = gr.Video(label="Video result", elem_id="video-output")
133
+ # with gr.Row():
134
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
135
+ # submit_btn = gr.Button("Generate video", size='sm')
136
+
137
+
138
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
139
+ inputs = [prompt, seed_inp, ddim_steps,cfg]
140
+ outputs = [video_out]
141
+ # gr.Examples(
142
+ # value = [['An astronaut riding a horse',123,50],
143
+ # ['a panda eating bamboo on a rock',123,50],
144
+ # ['Spiderman is surfing',123,50]],
145
+ # label = "example of sampling",
146
+ # show_label = True,
147
+ # headers = ['prompt','seed','steps'],
148
+ # datatype = ['str','number','number'],
149
+ # row_count=4,
150
+ # col_count=(3,"fixed")
151
+ # )
152
+ ex = gr.Examples(
153
+ examples = [['a corgi walking in the park at sunrise, oil painting style',400,50,7],
154
+ ['a cut teddy bear reading a book in the park, oil painting style, high quality',700,50,7],
155
+ ['an epic tornado attacking above a glowing city at night, the tornado is made of smoke, highly detailed',230,50,7],
156
+ ['a jar filled with fire, 4K video, 3D rendered, well-rendered',400,50,7],
157
+ ['a teddy bear walking in the park, oil painting style, high quality',400,50,7],
158
+ ['a teddy bear walking on the street, 2k, high quality',100,50,7],
159
+ ['a panda taking a selfie, 2k, high quality',400,50,7],
160
+ ['a polar bear playing drum kit in NYC Times Square, 4k, high resolution',400,50,7],
161
+ ['jungle river at sunset, ultra quality',400,50,7],
162
+ ['a shark swimming in clear Carribean ocean, 2k, high quality',400,50,7],
163
+ ['A steam train moving on a mountainside by Vincent van Gogh',230,50,7],
164
+ ['a confused grizzly bear in calculus class',1000,50,7]],
165
+ fn = infer,
166
+ inputs=[prompt, seed_inp, ddim_steps,cfg],
167
+ outputs=[video_out],
168
+ cache_examples=False,
169
+ )
170
+ ex.dataset.headers = [""]
171
+
172
+ # control_task.change(change_task_options, inputs=[control_task], outputs=[canny_opt, hough_opt, normal_opt], queue=False)
173
+ # submit_btn.click(clean, inputs=[], outputs=[video_out], queue=False)
174
+ clean_btn.click(clean, inputs=[], outputs=[video_out], queue=False)
175
+ submit_btn.click(infer, inputs, outputs)
176
+ # share_button.click(None, [], [], _js=share_js)
177
+
178
+ print(2)
179
+ demo.queue(max_size=12).launch()
180
+
181
+
base/app.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ srun -p aigc-video --gres=gpu:1 -n1 -N1 python app.py
base/configs/sample.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path:
2
+ output_folder: "outputs"
3
+ pretrained_path: "pretrained_models"
4
+
5
+ # model config:
6
+ model: UNet
7
+ video_length: 16
8
+ image_size: [320, 512]
9
+
10
+ # beta schedule
11
+ beta_start: 0.0001
12
+ beta_end: 0.02
13
+ beta_schedule: "linear"
14
+
15
+ # model speedup
16
+ use_compile: False
17
+ use_fp16: True
18
+
19
+ # sample config:
20
+ seed: 3
21
+ run_time: 0
22
+ guidance_scale: 7.0
23
+ sample_method: 'ddpm'
24
+ num_sampling_steps: 250
25
+ text_prompt: [
26
+ 'a teddy bear walking on the street, high quality, 2k',
27
+
28
+ ]
base/download.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # All rights reserved.
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import os
8
+
9
+
10
+ def find_model(model_name):
11
+ """
12
+ Finds a pre-trained model, downloading it if necessary. Alternatively, loads a model from a local path.
13
+ """
14
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
15
+ if "ema" in checkpoint: # supports checkpoints from train.py
16
+ print('Ema existing!')
17
+ checkpoint = checkpoint["ema"]
18
+ return checkpoint
base/gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4 ADDED
Binary file (332 kB). View file
 
base/gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4 ADDED
Binary file (288 kB). View file
 
base/gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4 ADDED
Binary file (316 kB). View file
 
base/gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4 ADDED
Binary file (258 kB). View file
 
base/gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4 ADDED
Binary file (352 kB). View file
 
base/gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4 ADDED
Binary file (247 kB). View file
 
base/gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4 ADDED
Binary file (267 kB). View file
 
base/gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4 ADDED
Binary file (307 kB). View file
 
base/gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4 ADDED
Binary file (338 kB). View file
 
base/gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4 ADDED
Binary file (261 kB). View file
 
base/gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4 ADDED
Binary file (372 kB). View file
 
base/gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4 ADDED
Binary file (381 kB). View file
 
base/gradio_cached_examples/14/log.csv ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Video result,flag,username,timestamp
2
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_corgi_walking_in_the_park_at-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:13.139609
3
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:23.543257
4
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviean_epic_tornado_attacking_abov-230-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:33.942899
5
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_jar_filled_with_fire,_4K_vid-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:44.348969
6
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:54.765554
7
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:05.255612
8
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_panda_taking_a_selfie,_2k,_h-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:15.694357
9
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:26.121546
10
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviejungle_river_at_sunset,_ultra_-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:36.540682
11
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:46.992686
12
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:57.458758
13
+ "{""video"":{""path"":""gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:03:07.878403
base/huggingface-t2v/.DS_Store ADDED
Binary file (6.15 kB). View file
 
base/huggingface-t2v/__init__.py ADDED
File without changes
base/huggingface-t2v/requirements.txt ADDED
File without changes
base/models/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.split(sys.path[0])[0])
4
+
5
+ from .unet import UNet3DConditionModel
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+
8
+ def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ def fn(step):
11
+ if warmup_steps > 0:
12
+ return min(step / warmup_steps, 1)
13
+ else:
14
+ return 1
15
+ return LambdaLR(optimizer, fn)
16
+
17
+
18
+ def get_lr_scheduler(optimizer, name, **kwargs):
19
+ if name == 'warmup':
20
+ return customized_lr_scheduler(optimizer, **kwargs)
21
+ elif name == 'cosine':
22
+ from torch.optim.lr_scheduler import CosineAnnealingLR
23
+ return CosineAnnealingLR(optimizer, **kwargs)
24
+ else:
25
+ raise NotImplementedError(name)
26
+
27
+ def get_models(args, sd_path):
28
+
29
+ if 'UNet' in args.model:
30
+ return UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet")
31
+ else:
32
+ raise '{} Model Not Supported!'.format(args.model)
33
+
base/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.9 kB). View file
 
base/models/__pycache__/attention.cpython-311.pyc ADDED
Binary file (33.7 kB). View file
 
base/models/__pycache__/resnet.cpython-311.pyc ADDED
Binary file (9.76 kB). View file
 
base/models/__pycache__/unet.cpython-311.pyc ADDED
Binary file (27.3 kB). View file
 
base/models/__pycache__/unet_blocks.cpython-311.pyc ADDED
Binary file (20.3 kB). View file
 
base/models/attention.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import math
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.utils import BaseOutput
16
+ from diffusers.utils.import_utils import is_xformers_available
17
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
18
+ from rotary_embedding_torch import RotaryEmbedding
19
+ from typing import Callable, Optional
20
+ from einops import rearrange, repeat
21
+
22
+ try:
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ except:
25
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
26
+
27
+
28
+ @dataclass
29
+ class Transformer3DModelOutput(BaseOutput):
30
+ sample: torch.FloatTensor
31
+
32
+
33
+ if is_xformers_available():
34
+ import xformers
35
+ import xformers.ops
36
+ else:
37
+ xformers = None
38
+
39
+ def exists(x):
40
+ return x is not None
41
+
42
+
43
+ class CrossAttention(nn.Module):
44
+ r"""
45
+ copy from diffuser 0.11.1
46
+ A cross attention layer.
47
+ Parameters:
48
+ query_dim (`int`): The number of channels in the query.
49
+ cross_attention_dim (`int`, *optional*):
50
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
51
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
52
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
53
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
54
+ bias (`bool`, *optional*, defaults to False):
55
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ query_dim: int,
61
+ cross_attention_dim: Optional[int] = None,
62
+ heads: int = 8,
63
+ dim_head: int = 64,
64
+ dropout: float = 0.0,
65
+ bias=False,
66
+ upcast_attention: bool = False,
67
+ upcast_softmax: bool = False,
68
+ added_kv_proj_dim: Optional[int] = None,
69
+ norm_num_groups: Optional[int] = None,
70
+ use_relative_position: bool = False,
71
+ ):
72
+ super().__init__()
73
+ inner_dim = dim_head * heads
74
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
75
+ self.upcast_attention = upcast_attention
76
+ self.upcast_softmax = upcast_softmax
77
+
78
+ self.scale = dim_head**-0.5
79
+
80
+ self.heads = heads
81
+ self.dim_head = dim_head
82
+ # for slice_size > 0 the attention score computation
83
+ # is split across the batch axis to save memory
84
+ # You can set slice_size with `set_attention_slice`
85
+ self.sliceable_head_dim = heads
86
+ self._slice_size = None
87
+ self._use_memory_efficient_attention_xformers = False
88
+ self.added_kv_proj_dim = added_kv_proj_dim
89
+
90
+ if norm_num_groups is not None:
91
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
92
+ else:
93
+ self.group_norm = None
94
+
95
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
96
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
97
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
98
+
99
+ if self.added_kv_proj_dim is not None:
100
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
101
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
102
+
103
+ self.to_out = nn.ModuleList([])
104
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
105
+ self.to_out.append(nn.Dropout(dropout))
106
+
107
+ self.use_relative_position = use_relative_position
108
+ if self.use_relative_position:
109
+ self.rotary_emb = RotaryEmbedding(min(32, dim_head))
110
+
111
+
112
+ def reshape_heads_to_batch_dim(self, tensor):
113
+ batch_size, seq_len, dim = tensor.shape
114
+ head_size = self.heads
115
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
116
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
117
+ return tensor
118
+
119
+ def reshape_batch_dim_to_heads(self, tensor):
120
+ batch_size, seq_len, dim = tensor.shape
121
+ head_size = self.heads
122
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
123
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
124
+ return tensor
125
+
126
+ def reshape_for_scores(self, tensor):
127
+ # split heads and dims
128
+ # tensor should be [b (h w)] f (d nd)
129
+ batch_size, seq_len, dim = tensor.shape
130
+ head_size = self.heads
131
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
132
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
133
+ return tensor
134
+
135
+ def same_batch_dim_to_heads(self, tensor):
136
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
137
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
138
+ return tensor
139
+
140
+ def set_attention_slice(self, slice_size):
141
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
142
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
143
+
144
+ self._slice_size = slice_size
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
147
+ batch_size, sequence_length, _ = hidden_states.shape
148
+
149
+ encoder_hidden_states = encoder_hidden_states
150
+
151
+ if self.group_norm is not None:
152
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
153
+
154
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
155
+
156
+ # print('before reshpape query shape', query.shape)
157
+ dim = query.shape[-1]
158
+ if not self.use_relative_position:
159
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
160
+ # print('after reshape query shape', query.shape)
161
+
162
+ if self.added_kv_proj_dim is not None:
163
+ key = self.to_k(hidden_states)
164
+ value = self.to_v(hidden_states)
165
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
166
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
167
+
168
+ key = self.reshape_heads_to_batch_dim(key)
169
+ value = self.reshape_heads_to_batch_dim(value)
170
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
171
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
172
+
173
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
174
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
175
+ else:
176
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
177
+ key = self.to_k(encoder_hidden_states)
178
+ value = self.to_v(encoder_hidden_states)
179
+
180
+ if not self.use_relative_position:
181
+ key = self.reshape_heads_to_batch_dim(key)
182
+ value = self.reshape_heads_to_batch_dim(value)
183
+
184
+ if attention_mask is not None:
185
+ if attention_mask.shape[-1] != query.shape[1]:
186
+ target_length = query.shape[1]
187
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
188
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
189
+
190
+ # attention, what we cannot get enough of
191
+ if self._use_memory_efficient_attention_xformers:
192
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
193
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
194
+ hidden_states = hidden_states.to(query.dtype)
195
+ else:
196
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
197
+ hidden_states = self._attention(query, key, value, attention_mask)
198
+ else:
199
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
200
+
201
+ # linear proj
202
+ hidden_states = self.to_out[0](hidden_states)
203
+
204
+ # dropout
205
+ hidden_states = self.to_out[1](hidden_states)
206
+ return hidden_states
207
+
208
+
209
+ def _attention(self, query, key, value, attention_mask=None):
210
+ if self.upcast_attention:
211
+ query = query.float()
212
+ key = key.float()
213
+
214
+ attention_scores = torch.baddbmm(
215
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
216
+ query,
217
+ key.transpose(-1, -2),
218
+ beta=0,
219
+ alpha=self.scale,
220
+ )
221
+
222
+ if attention_mask is not None:
223
+ attention_scores = attention_scores + attention_mask
224
+
225
+ if self.upcast_softmax:
226
+ attention_scores = attention_scores.float()
227
+
228
+ attention_probs = attention_scores.softmax(dim=-1)
229
+
230
+ # cast back to the original dtype
231
+ attention_probs = attention_probs.to(value.dtype)
232
+
233
+ # compute attention output
234
+ hidden_states = torch.bmm(attention_probs, value)
235
+
236
+ # reshape hidden_states
237
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
238
+
239
+ return hidden_states
240
+
241
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
242
+ batch_size_attention = query.shape[0]
243
+ hidden_states = torch.zeros(
244
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
245
+ )
246
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
247
+ for i in range(hidden_states.shape[0] // slice_size):
248
+ start_idx = i * slice_size
249
+ end_idx = (i + 1) * slice_size
250
+
251
+ query_slice = query[start_idx:end_idx]
252
+ key_slice = key[start_idx:end_idx]
253
+
254
+ if self.upcast_attention:
255
+ query_slice = query_slice.float()
256
+ key_slice = key_slice.float()
257
+
258
+ attn_slice = torch.baddbmm(
259
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
260
+ query_slice,
261
+ key_slice.transpose(-1, -2),
262
+ beta=0,
263
+ alpha=self.scale,
264
+ )
265
+
266
+ if attention_mask is not None:
267
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
268
+
269
+ if self.upcast_softmax:
270
+ attn_slice = attn_slice.float()
271
+
272
+ attn_slice = attn_slice.softmax(dim=-1)
273
+
274
+ # cast back to the original dtype
275
+ attn_slice = attn_slice.to(value.dtype)
276
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
277
+
278
+ hidden_states[start_idx:end_idx] = attn_slice
279
+
280
+ # reshape hidden_states
281
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
282
+ return hidden_states
283
+
284
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
285
+ # TODO attention_mask
286
+ query = query.contiguous()
287
+ key = key.contiguous()
288
+ value = value.contiguous()
289
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
290
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
291
+ return hidden_states
292
+
293
+
294
+ class Transformer3DModel(ModelMixin, ConfigMixin):
295
+ @register_to_config
296
+ def __init__(
297
+ self,
298
+ num_attention_heads: int = 16,
299
+ attention_head_dim: int = 88,
300
+ in_channels: Optional[int] = None,
301
+ num_layers: int = 1,
302
+ dropout: float = 0.0,
303
+ norm_num_groups: int = 32,
304
+ cross_attention_dim: Optional[int] = None,
305
+ attention_bias: bool = False,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ use_linear_projection: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ use_first_frame: bool = False,
312
+ use_relative_position: bool = False,
313
+ rotary_emb: bool = None,
314
+ ):
315
+ super().__init__()
316
+ self.use_linear_projection = use_linear_projection
317
+ self.num_attention_heads = num_attention_heads
318
+ self.attention_head_dim = attention_head_dim
319
+ inner_dim = num_attention_heads * attention_head_dim
320
+
321
+ # Define input layers
322
+ self.in_channels = in_channels
323
+
324
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
325
+ if use_linear_projection:
326
+ self.proj_in = nn.Linear(in_channels, inner_dim)
327
+ else:
328
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
329
+
330
+ # Define transformers blocks
331
+ self.transformer_blocks = nn.ModuleList(
332
+ [
333
+ BasicTransformerBlock(
334
+ inner_dim,
335
+ num_attention_heads,
336
+ attention_head_dim,
337
+ dropout=dropout,
338
+ cross_attention_dim=cross_attention_dim,
339
+ activation_fn=activation_fn,
340
+ num_embeds_ada_norm=num_embeds_ada_norm,
341
+ attention_bias=attention_bias,
342
+ only_cross_attention=only_cross_attention,
343
+ upcast_attention=upcast_attention,
344
+ use_first_frame=use_first_frame,
345
+ use_relative_position=use_relative_position,
346
+ rotary_emb=rotary_emb,
347
+ )
348
+ for d in range(num_layers)
349
+ ]
350
+ )
351
+
352
+ # 4. Define output layers
353
+ if use_linear_projection:
354
+ self.proj_out = nn.Linear(in_channels, inner_dim)
355
+ else:
356
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
357
+
358
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True):
359
+ # Input
360
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
361
+
362
+ video_length = hidden_states.shape[2]
363
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
364
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
365
+
366
+ batch, channel, height, weight = hidden_states.shape
367
+ residual = hidden_states
368
+
369
+ hidden_states = self.norm(hidden_states)
370
+ if not self.use_linear_projection:
371
+ hidden_states = self.proj_in(hidden_states)
372
+ inner_dim = hidden_states.shape[1]
373
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
374
+ else:
375
+ inner_dim = hidden_states.shape[1]
376
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
377
+ hidden_states = self.proj_in(hidden_states)
378
+
379
+ # Blocks
380
+ for block in self.transformer_blocks:
381
+ hidden_states = block(
382
+ hidden_states,
383
+ encoder_hidden_states=encoder_hidden_states,
384
+ timestep=timestep,
385
+ video_length=video_length,
386
+ use_image_num=use_image_num,
387
+ )
388
+
389
+ # Output
390
+ if not self.use_linear_projection:
391
+ hidden_states = (
392
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
393
+ )
394
+ hidden_states = self.proj_out(hidden_states)
395
+ else:
396
+ hidden_states = self.proj_out(hidden_states)
397
+ hidden_states = (
398
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
399
+ )
400
+
401
+ output = hidden_states + residual
402
+
403
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
404
+ if not return_dict:
405
+ return (output,)
406
+
407
+ return Transformer3DModelOutput(sample=output)
408
+
409
+
410
+ class BasicTransformerBlock(nn.Module):
411
+ def __init__(
412
+ self,
413
+ dim: int,
414
+ num_attention_heads: int,
415
+ attention_head_dim: int,
416
+ dropout=0.0,
417
+ cross_attention_dim: Optional[int] = None,
418
+ activation_fn: str = "geglu",
419
+ num_embeds_ada_norm: Optional[int] = None,
420
+ attention_bias: bool = False,
421
+ only_cross_attention: bool = False,
422
+ upcast_attention: bool = False,
423
+ use_first_frame: bool = False,
424
+ use_relative_position: bool = False,
425
+ rotary_emb: bool = False,
426
+ ):
427
+ super().__init__()
428
+ self.only_cross_attention = only_cross_attention
429
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
430
+ self.use_first_frame = use_first_frame
431
+
432
+ # Spatial-Attn
433
+ self.attn1 = CrossAttention(
434
+ query_dim=dim,
435
+ heads=num_attention_heads,
436
+ dim_head=attention_head_dim,
437
+ dropout=dropout,
438
+ bias=attention_bias,
439
+ cross_attention_dim=None,
440
+ upcast_attention=upcast_attention,
441
+ )
442
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
443
+
444
+ # Text Cross-Attn
445
+ if cross_attention_dim is not None:
446
+ self.attn2 = CrossAttention(
447
+ query_dim=dim,
448
+ cross_attention_dim=cross_attention_dim,
449
+ heads=num_attention_heads,
450
+ dim_head=attention_head_dim,
451
+ dropout=dropout,
452
+ bias=attention_bias,
453
+ upcast_attention=upcast_attention,
454
+ )
455
+ else:
456
+ self.attn2 = None
457
+
458
+ if cross_attention_dim is not None:
459
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
460
+ else:
461
+ self.norm2 = None
462
+
463
+ # Temp
464
+ self.attn_temp = TemporalAttention(
465
+ query_dim=dim,
466
+ heads=num_attention_heads,
467
+ dim_head=attention_head_dim,
468
+ dropout=dropout,
469
+ bias=attention_bias,
470
+ cross_attention_dim=None,
471
+ upcast_attention=upcast_attention,
472
+ rotary_emb=rotary_emb,
473
+ )
474
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
475
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
476
+
477
+
478
+ # Feed-forward
479
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
480
+ self.norm3 = nn.LayerNorm(dim)
481
+
482
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
483
+
484
+ if not is_xformers_available():
485
+ print("Here is how to install it")
486
+ raise ModuleNotFoundError(
487
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
488
+ " xformers",
489
+ name="xformers",
490
+ )
491
+ elif not torch.cuda.is_available():
492
+ raise ValueError(
493
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
494
+ " available for GPU "
495
+ )
496
+ else:
497
+ try:
498
+ # Make sure we can run the memory efficient attention
499
+ _ = xformers.ops.memory_efficient_attention(
500
+ torch.randn((1, 2, 40), device="cuda"),
501
+ torch.randn((1, 2, 40), device="cuda"),
502
+ torch.randn((1, 2, 40), device="cuda"),
503
+ )
504
+ except Exception as e:
505
+ raise e
506
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
507
+ if self.attn2 is not None:
508
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
509
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
510
+
511
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None):
512
+ # SparseCausal-Attention
513
+ norm_hidden_states = (
514
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
515
+ )
516
+
517
+ if self.only_cross_attention:
518
+ hidden_states = (
519
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
520
+ )
521
+ else:
522
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states
523
+
524
+ if self.attn2 is not None:
525
+ # Cross-Attention
526
+ norm_hidden_states = (
527
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
528
+ )
529
+ hidden_states = (
530
+ self.attn2(
531
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
532
+ )
533
+ + hidden_states
534
+ )
535
+
536
+ # Temporal Attention
537
+ if self.training:
538
+ d = hidden_states.shape[1]
539
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
540
+ hidden_states_video = hidden_states[:, :video_length, :]
541
+ hidden_states_image = hidden_states[:, video_length:, :]
542
+ norm_hidden_states_video = (
543
+ self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video)
544
+ )
545
+ hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video
546
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
547
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
548
+ else:
549
+ d = hidden_states.shape[1]
550
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
551
+ norm_hidden_states = (
552
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
553
+ )
554
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
555
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
556
+
557
+ # Feed-forward
558
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
559
+
560
+ return hidden_states
561
+
562
+ class TemporalAttention(CrossAttention):
563
+ def __init__(self,
564
+ query_dim: int,
565
+ cross_attention_dim: Optional[int] = None,
566
+ heads: int = 8,
567
+ dim_head: int = 64,
568
+ dropout: float = 0.0,
569
+ bias=False,
570
+ upcast_attention: bool = False,
571
+ upcast_softmax: bool = False,
572
+ added_kv_proj_dim: Optional[int] = None,
573
+ norm_num_groups: Optional[int] = None,
574
+ rotary_emb=None):
575
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
576
+ # relative time positional embeddings
577
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
578
+ self.rotary_emb = rotary_emb
579
+
580
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
581
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
582
+ batch_size, sequence_length, _ = hidden_states.shape
583
+
584
+ encoder_hidden_states = encoder_hidden_states
585
+
586
+ if self.group_norm is not None:
587
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
588
+
589
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
590
+ dim = query.shape[-1]
591
+
592
+ if self.added_kv_proj_dim is not None:
593
+ key = self.to_k(hidden_states)
594
+ value = self.to_v(hidden_states)
595
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
596
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
597
+
598
+ key = self.reshape_heads_to_batch_dim(key)
599
+ value = self.reshape_heads_to_batch_dim(value)
600
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
601
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
602
+
603
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
604
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
605
+ else:
606
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
607
+ key = self.to_k(encoder_hidden_states)
608
+ value = self.to_v(encoder_hidden_states)
609
+
610
+ if attention_mask is not None:
611
+ if attention_mask.shape[-1] != query.shape[1]:
612
+ target_length = query.shape[1]
613
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
614
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
615
+
616
+ # attention, what we cannot get enough of
617
+ if self._use_memory_efficient_attention_xformers:
618
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
619
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
620
+ hidden_states = hidden_states.to(query.dtype)
621
+ else:
622
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
623
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
624
+ else:
625
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
626
+
627
+ # linear proj
628
+ hidden_states = self.to_out[0](hidden_states)
629
+
630
+ # dropout
631
+ hidden_states = self.to_out[1](hidden_states)
632
+ return hidden_states
633
+
634
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
635
+ if self.upcast_attention:
636
+ query = query.float()
637
+ key = key.float()
638
+
639
+ # reshape for adding time positional bais
640
+ query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
641
+ key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
642
+ value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
643
+
644
+ if exists(self.rotary_emb):
645
+ query = self.rotary_emb.rotate_queries_or_keys(query)
646
+ key = self.rotary_emb.rotate_queries_or_keys(key)
647
+
648
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
649
+
650
+ attention_scores = attention_scores + time_rel_pos_bias
651
+
652
+ if attention_mask is not None:
653
+ # add attention mask
654
+ attention_scores = attention_scores + attention_mask
655
+
656
+ attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
657
+
658
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
659
+ # print(attention_probs[0][0])
660
+
661
+ # cast back to the original dtype
662
+ attention_probs = attention_probs.to(value.dtype)
663
+
664
+ # compute attention output
665
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
666
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
667
+ return hidden_states
668
+
669
+ class RelativePositionBias(nn.Module):
670
+ def __init__(
671
+ self,
672
+ heads=8,
673
+ num_buckets=32,
674
+ max_distance=128,
675
+ ):
676
+ super().__init__()
677
+ self.num_buckets = num_buckets
678
+ self.max_distance = max_distance
679
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
680
+
681
+ @staticmethod
682
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
683
+ ret = 0
684
+ n = -relative_position
685
+
686
+ num_buckets //= 2
687
+ ret += (n < 0).long() * num_buckets
688
+ n = torch.abs(n)
689
+
690
+ max_exact = num_buckets // 2
691
+ is_small = n < max_exact
692
+
693
+ val_if_large = max_exact + (
694
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
695
+ ).long()
696
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
697
+
698
+ ret += torch.where(is_small, n, val_if_large)
699
+ return ret
700
+
701
+ def forward(self, n, device):
702
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
703
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
704
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
705
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
706
+ values = self.relative_attention_bias(rp_bucket)
707
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
base/models/clip.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import torch.nn as nn
3
+ from transformers import CLIPTokenizer, CLIPTextModel
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ """
9
+ Will encounter following warning:
10
+ - This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
11
+ or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
12
+ - This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
13
+ that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
14
+
15
+ https://github.com/CompVis/stable-diffusion/issues/97
16
+ according to this issue, this warning is safe.
17
+
18
+ This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
19
+ You can safely ignore the warning, it is not an error.
20
+
21
+ This clip usage is from U-ViT and same with Stable Diffusion.
22
+ """
23
+
24
+ class AbstractEncoder(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def encode(self, *args, **kwargs):
29
+ raise NotImplementedError
30
+
31
+
32
+ class FrozenCLIPEmbedder(AbstractEncoder):
33
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
34
+ # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
35
+ def __init__(self, path, device="cuda", max_length=77):
36
+ super().__init__()
37
+ self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
38
+ self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
39
+ self.device = device
40
+ self.max_length = max_length
41
+ self.freeze()
42
+
43
+ def freeze(self):
44
+ self.transformer = self.transformer.eval()
45
+ for param in self.parameters():
46
+ param.requires_grad = False
47
+
48
+ def forward(self, text):
49
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
50
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
51
+ tokens = batch_encoding["input_ids"].to(self.device)
52
+ outputs = self.transformer(input_ids=tokens)
53
+
54
+ z = outputs.last_hidden_state
55
+ return z
56
+
57
+ def encode(self, text):
58
+ return self(text)
59
+
60
+
61
+ class TextEmbedder(nn.Module):
62
+ """
63
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
64
+ """
65
+ def __init__(self, path, dropout_prob=0.1):
66
+ super().__init__()
67
+ self.text_encodder = FrozenCLIPEmbedder(path=path)
68
+ self.dropout_prob = dropout_prob
69
+
70
+ def token_drop(self, text_prompts, force_drop_ids=None):
71
+ """
72
+ Drops text to enable classifier-free guidance.
73
+ """
74
+ if force_drop_ids is None:
75
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
76
+ else:
77
+ # TODO
78
+ drop_ids = force_drop_ids == 1
79
+ labels = list(numpy.where(drop_ids, "", text_prompts))
80
+ # print(labels)
81
+ return labels
82
+
83
+ def forward(self, text_prompts, train, force_drop_ids=None):
84
+ use_dropout = self.dropout_prob > 0
85
+ if (train and use_dropout) or (force_drop_ids is not None):
86
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
87
+ embeddings = self.text_encodder(text_prompts)
88
+ return embeddings
89
+
90
+
91
+ if __name__ == '__main__':
92
+
93
+ r"""
94
+ Returns:
95
+
96
+ Examples from CLIPTextModel:
97
+
98
+ ```python
99
+ >>> from transformers import AutoTokenizer, CLIPTextModel
100
+
101
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
102
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
103
+
104
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
105
+
106
+ >>> outputs = model(**inputs)
107
+ >>> last_hidden_state = outputs.last_hidden_state
108
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
109
+ ```"""
110
+
111
+ import torch
112
+
113
+ device = "cuda" if torch.cuda.is_available() else "cpu"
114
+
115
+ text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
116
+ dropout_prob=0.00001).to(device)
117
+
118
+ text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
119
+ output = text_encoder(text_prompts=text_prompt, train=False)
120
+ print(output.shape)
base/models/resnet.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+
12
+
13
+ class InflatedConv3d(nn.Conv2d):
14
+ def forward(self, x):
15
+ video_length = x.shape[2]
16
+
17
+ x = rearrange(x, "b c f h w -> (b f) c h w")
18
+ x = super().forward(x)
19
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
20
+
21
+ return x
22
+
23
+
24
+ class Upsample3D(nn.Module):
25
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
26
+ super().__init__()
27
+ self.channels = channels
28
+ self.out_channels = out_channels or channels
29
+ self.use_conv = use_conv
30
+ self.use_conv_transpose = use_conv_transpose
31
+ self.name = name
32
+
33
+ conv = None
34
+ if use_conv_transpose:
35
+ raise NotImplementedError
36
+ elif use_conv:
37
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
38
+
39
+ if name == "conv":
40
+ self.conv = conv
41
+ else:
42
+ self.Conv2d_0 = conv
43
+
44
+ def forward(self, hidden_states, output_size=None):
45
+ assert hidden_states.shape[1] == self.channels
46
+
47
+ if self.use_conv_transpose:
48
+ raise NotImplementedError
49
+
50
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
51
+ dtype = hidden_states.dtype
52
+ if dtype == torch.bfloat16:
53
+ hidden_states = hidden_states.to(torch.float32)
54
+
55
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
56
+ if hidden_states.shape[0] >= 64:
57
+ hidden_states = hidden_states.contiguous()
58
+
59
+ # if `output_size` is passed we force the interpolation output
60
+ # size and do not make use of `scale_factor=2`
61
+ if output_size is None:
62
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
63
+ else:
64
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
65
+
66
+ # If the input is bfloat16, we cast back to bfloat16
67
+ if dtype == torch.bfloat16:
68
+ hidden_states = hidden_states.to(dtype)
69
+
70
+ if self.use_conv:
71
+ if self.name == "conv":
72
+ hidden_states = self.conv(hidden_states)
73
+ else:
74
+ hidden_states = self.Conv2d_0(hidden_states)
75
+
76
+ return hidden_states
77
+
78
+
79
+ class Downsample3D(nn.Module):
80
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
81
+ super().__init__()
82
+ self.channels = channels
83
+ self.out_channels = out_channels or channels
84
+ self.use_conv = use_conv
85
+ self.padding = padding
86
+ stride = 2
87
+ self.name = name
88
+
89
+ if use_conv:
90
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
91
+ else:
92
+ raise NotImplementedError
93
+
94
+ if name == "conv":
95
+ self.Conv2d_0 = conv
96
+ self.conv = conv
97
+ elif name == "Conv2d_0":
98
+ self.conv = conv
99
+ else:
100
+ self.conv = conv
101
+
102
+ def forward(self, hidden_states):
103
+ assert hidden_states.shape[1] == self.channels
104
+ if self.use_conv and self.padding == 0:
105
+ raise NotImplementedError
106
+
107
+ assert hidden_states.shape[1] == self.channels
108
+ hidden_states = self.conv(hidden_states)
109
+
110
+ return hidden_states
111
+
112
+
113
+ class ResnetBlock3D(nn.Module):
114
+ def __init__(
115
+ self,
116
+ *,
117
+ in_channels,
118
+ out_channels=None,
119
+ conv_shortcut=False,
120
+ dropout=0.0,
121
+ temb_channels=512,
122
+ groups=32,
123
+ groups_out=None,
124
+ pre_norm=True,
125
+ eps=1e-6,
126
+ non_linearity="swish",
127
+ time_embedding_norm="default",
128
+ output_scale_factor=1.0,
129
+ use_in_shortcut=None,
130
+ ):
131
+ super().__init__()
132
+ self.pre_norm = pre_norm
133
+ self.pre_norm = True
134
+ self.in_channels = in_channels
135
+ out_channels = in_channels if out_channels is None else out_channels
136
+ self.out_channels = out_channels
137
+ self.use_conv_shortcut = conv_shortcut
138
+ self.time_embedding_norm = time_embedding_norm
139
+ self.output_scale_factor = output_scale_factor
140
+
141
+ if groups_out is None:
142
+ groups_out = groups
143
+
144
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
+
146
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
147
+
148
+ if temb_channels is not None:
149
+ if self.time_embedding_norm == "default":
150
+ time_emb_proj_out_channels = out_channels
151
+ elif self.time_embedding_norm == "scale_shift":
152
+ time_emb_proj_out_channels = out_channels * 2
153
+ else:
154
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
155
+
156
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
157
+ else:
158
+ self.time_emb_proj = None
159
+
160
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161
+ self.dropout = torch.nn.Dropout(dropout)
162
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
163
+
164
+ if non_linearity == "swish":
165
+ self.nonlinearity = lambda x: F.silu(x)
166
+ elif non_linearity == "mish":
167
+ self.nonlinearity = Mish()
168
+ elif non_linearity == "silu":
169
+ self.nonlinearity = nn.SiLU()
170
+
171
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
172
+
173
+ self.conv_shortcut = None
174
+ if self.use_in_shortcut:
175
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
176
+
177
+ def forward(self, input_tensor, temb):
178
+ hidden_states = input_tensor
179
+
180
+ hidden_states = self.norm1(hidden_states)
181
+ hidden_states = self.nonlinearity(hidden_states)
182
+
183
+ hidden_states = self.conv1(hidden_states)
184
+
185
+ if temb is not None:
186
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
187
+
188
+ if temb is not None and self.time_embedding_norm == "default":
189
+ hidden_states = hidden_states + temb
190
+
191
+ hidden_states = self.norm2(hidden_states)
192
+
193
+ if temb is not None and self.time_embedding_norm == "scale_shift":
194
+ scale, shift = torch.chunk(temb, 2, dim=1)
195
+ hidden_states = hidden_states * (1 + scale) + shift
196
+
197
+ hidden_states = self.nonlinearity(hidden_states)
198
+
199
+ hidden_states = self.dropout(hidden_states)
200
+ hidden_states = self.conv2(hidden_states)
201
+
202
+ if self.conv_shortcut is not None:
203
+ input_tensor = self.conv_shortcut(input_tensor)
204
+
205
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
206
+
207
+ return output_tensor
208
+
209
+
210
+ class Mish(torch.nn.Module):
211
+ def forward(self, hidden_states):
212
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
base/models/temporal_attention.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from typing import Optional
4
+ from rotary_embedding_torch import RotaryEmbedding
5
+ from dataclasses import dataclass
6
+ from diffusers.utils import BaseOutput
7
+ from diffusers.utils.import_utils import is_xformers_available
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+ import math
11
+
12
+ @dataclass
13
+ class Transformer3DModelOutput(BaseOutput):
14
+ sample: torch.FloatTensor
15
+
16
+
17
+ if is_xformers_available():
18
+ import xformers
19
+ import xformers.ops
20
+ else:
21
+ xformers = None
22
+
23
+ def exists(x):
24
+ return x is not None
25
+
26
+ class CrossAttention(nn.Module):
27
+ r"""
28
+ copy from diffuser 0.11.1
29
+ A cross attention layer.
30
+ Parameters:
31
+ query_dim (`int`): The number of channels in the query.
32
+ cross_attention_dim (`int`, *optional*):
33
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
34
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
35
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
36
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
37
+ bias (`bool`, *optional*, defaults to False):
38
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ query_dim: int,
44
+ cross_attention_dim: Optional[int] = None,
45
+ heads: int = 8,
46
+ dim_head: int = 64,
47
+ dropout: float = 0.0,
48
+ bias=False,
49
+ upcast_attention: bool = False,
50
+ upcast_softmax: bool = False,
51
+ added_kv_proj_dim: Optional[int] = None,
52
+ norm_num_groups: Optional[int] = None,
53
+ use_relative_position: bool = False,
54
+ ):
55
+ super().__init__()
56
+ # print('num head', heads)
57
+ inner_dim = dim_head * heads
58
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
59
+ self.upcast_attention = upcast_attention
60
+ self.upcast_softmax = upcast_softmax
61
+
62
+ self.scale = dim_head**-0.5
63
+
64
+ self.heads = heads
65
+ self.dim_head = dim_head
66
+ # for slice_size > 0 the attention score computation
67
+ # is split across the batch axis to save memory
68
+ # You can set slice_size with `set_attention_slice`
69
+ self.sliceable_head_dim = heads
70
+ self._slice_size = None
71
+ self._use_memory_efficient_attention_xformers = False # No use xformers for temporal attention
72
+ self.added_kv_proj_dim = added_kv_proj_dim
73
+
74
+ if norm_num_groups is not None:
75
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
76
+ else:
77
+ self.group_norm = None
78
+
79
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
80
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
81
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
82
+
83
+ if self.added_kv_proj_dim is not None:
84
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
85
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
86
+
87
+ self.to_out = nn.ModuleList([])
88
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
89
+ self.to_out.append(nn.Dropout(dropout))
90
+
91
+ def reshape_heads_to_batch_dim(self, tensor):
92
+ batch_size, seq_len, dim = tensor.shape
93
+ head_size = self.heads
94
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
95
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
96
+ return tensor
97
+
98
+ def reshape_batch_dim_to_heads(self, tensor):
99
+ batch_size, seq_len, dim = tensor.shape
100
+ head_size = self.heads
101
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
102
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
103
+ return tensor
104
+
105
+ def reshape_for_scores(self, tensor):
106
+ # split heads and dims
107
+ # tensor should be [b (h w)] f (d nd)
108
+ batch_size, seq_len, dim = tensor.shape
109
+ head_size = self.heads
110
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
111
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
112
+ return tensor
113
+
114
+ def same_batch_dim_to_heads(self, tensor):
115
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
116
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
117
+ return tensor
118
+
119
+ def set_attention_slice(self, slice_size):
120
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
121
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
122
+
123
+ self._slice_size = slice_size
124
+
125
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
126
+ batch_size, sequence_length, _ = hidden_states.shape
127
+
128
+ encoder_hidden_states = encoder_hidden_states
129
+
130
+ if self.group_norm is not None:
131
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
134
+
135
+ # print('before reshpape query shape', query.shape)
136
+ dim = query.shape[-1]
137
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
138
+ # print('after reshape query shape', query.shape)
139
+
140
+ if self.added_kv_proj_dim is not None:
141
+ key = self.to_k(hidden_states)
142
+ value = self.to_v(hidden_states)
143
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
144
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
145
+
146
+ key = self.reshape_heads_to_batch_dim(key)
147
+ value = self.reshape_heads_to_batch_dim(value)
148
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
149
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
150
+
151
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
152
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
153
+ else:
154
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
155
+ key = self.to_k(encoder_hidden_states)
156
+ value = self.to_v(encoder_hidden_states)
157
+
158
+ key = self.reshape_heads_to_batch_dim(key)
159
+ value = self.reshape_heads_to_batch_dim(value)
160
+
161
+ if attention_mask is not None:
162
+ if attention_mask.shape[-1] != query.shape[1]:
163
+ target_length = query.shape[1]
164
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
165
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
166
+
167
+ hidden_states = self._attention(query, key, value, attention_mask)
168
+
169
+ # linear proj
170
+ hidden_states = self.to_out[0](hidden_states)
171
+
172
+ # dropout
173
+ hidden_states = self.to_out[1](hidden_states)
174
+ return hidden_states
175
+
176
+
177
+ def _attention(self, query, key, value, attention_mask=None):
178
+ if self.upcast_attention:
179
+ query = query.float()
180
+ key = key.float()
181
+
182
+ attention_scores = torch.baddbmm(
183
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
184
+ query,
185
+ key.transpose(-1, -2),
186
+ beta=0,
187
+ alpha=self.scale,
188
+ )
189
+
190
+ if attention_mask is not None:
191
+ attention_scores = attention_scores + attention_mask
192
+
193
+ if self.upcast_softmax:
194
+ attention_scores = attention_scores.float()
195
+
196
+ attention_probs = attention_scores.softmax(dim=-1)
197
+ attention_probs = attention_probs.to(value.dtype)
198
+ # compute attention output
199
+ hidden_states = torch.bmm(attention_probs, value)
200
+ # reshape hidden_states
201
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
202
+ return hidden_states
203
+
204
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
205
+ batch_size_attention = query.shape[0]
206
+ hidden_states = torch.zeros(
207
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
208
+ )
209
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
210
+ for i in range(hidden_states.shape[0] // slice_size):
211
+ start_idx = i * slice_size
212
+ end_idx = (i + 1) * slice_size
213
+
214
+ query_slice = query[start_idx:end_idx]
215
+ key_slice = key[start_idx:end_idx]
216
+
217
+ if self.upcast_attention:
218
+ query_slice = query_slice.float()
219
+ key_slice = key_slice.float()
220
+
221
+ attn_slice = torch.baddbmm(
222
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
223
+ query_slice,
224
+ key_slice.transpose(-1, -2),
225
+ beta=0,
226
+ alpha=self.scale,
227
+ )
228
+
229
+ if attention_mask is not None:
230
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
231
+
232
+ if self.upcast_softmax:
233
+ attn_slice = attn_slice.float()
234
+
235
+ attn_slice = attn_slice.softmax(dim=-1)
236
+
237
+ # cast back to the original dtype
238
+ attn_slice = attn_slice.to(value.dtype)
239
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
240
+
241
+ hidden_states[start_idx:end_idx] = attn_slice
242
+
243
+ # reshape hidden_states
244
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
245
+ return hidden_states
246
+
247
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
248
+ # TODO attention_mask
249
+ query = query.contiguous()
250
+ key = key.contiguous()
251
+ value = value.contiguous()
252
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
253
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
254
+ return hidden_states
255
+
256
+ class TemporalAttention(CrossAttention):
257
+ def __init__(self,
258
+ query_dim: int,
259
+ cross_attention_dim: Optional[int] = None,
260
+ heads: int = 8,
261
+ dim_head: int = 64,
262
+ dropout: float = 0.0,
263
+ bias=False,
264
+ upcast_attention: bool = False,
265
+ upcast_softmax: bool = False,
266
+ added_kv_proj_dim: Optional[int] = None,
267
+ norm_num_groups: Optional[int] = None,
268
+ rotary_emb=None):
269
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
270
+ # relative time positional embeddings
271
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
272
+ self.rotary_emb = rotary_emb
273
+
274
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
275
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
276
+ batch_size, sequence_length, _ = hidden_states.shape
277
+
278
+ encoder_hidden_states = encoder_hidden_states
279
+
280
+ if self.group_norm is not None:
281
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
282
+
283
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
284
+ dim = query.shape[-1]
285
+
286
+ if self.added_kv_proj_dim is not None:
287
+ key = self.to_k(hidden_states)
288
+ value = self.to_v(hidden_states)
289
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
290
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
291
+
292
+ key = self.reshape_heads_to_batch_dim(key)
293
+ value = self.reshape_heads_to_batch_dim(value)
294
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
295
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
296
+
297
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
298
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
299
+ else:
300
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
301
+ key = self.to_k(encoder_hidden_states)
302
+ value = self.to_v(encoder_hidden_states)
303
+
304
+ if attention_mask is not None:
305
+ if attention_mask.shape[-1] != query.shape[1]:
306
+ target_length = query.shape[1]
307
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
308
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
309
+
310
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
311
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
312
+ else:
313
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
314
+
315
+ # linear proj
316
+ hidden_states = self.to_out[0](hidden_states)
317
+
318
+ # dropout
319
+ hidden_states = self.to_out[1](hidden_states)
320
+ return hidden_states
321
+
322
+
323
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
324
+ if self.upcast_attention:
325
+ query = query.float()
326
+ key = key.float()
327
+
328
+ query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
329
+ key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
330
+ value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
331
+ if exists(self.rotary_emb):
332
+ query = self.rotary_emb.rotate_queries_or_keys(query)
333
+ key = self.rotary_emb.rotate_queries_or_keys(key)
334
+
335
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
336
+ attention_scores = attention_scores + time_rel_pos_bias
337
+
338
+ if attention_mask is not None:
339
+ # add attention mask
340
+ attention_scores = attention_scores + attention_mask
341
+
342
+ attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
343
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
344
+
345
+ attention_probs = attention_probs.to(value.dtype)
346
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
347
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
348
+ return hidden_states
349
+
350
+ class RelativePositionBias(nn.Module):
351
+ def __init__(
352
+ self,
353
+ heads=8,
354
+ num_buckets=32,
355
+ max_distance=128,
356
+ ):
357
+ super().__init__()
358
+ self.num_buckets = num_buckets
359
+ self.max_distance = max_distance
360
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
361
+
362
+ @staticmethod
363
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
364
+ ret = 0
365
+ n = -relative_position
366
+
367
+ num_buckets //= 2
368
+ ret += (n < 0).long() * num_buckets
369
+ n = torch.abs(n)
370
+
371
+ max_exact = num_buckets // 2
372
+ is_small = n < max_exact
373
+
374
+ val_if_large = max_exact + (
375
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
376
+ ).long()
377
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
378
+
379
+ ret += torch.where(is_small, n, val_if_large)
380
+ return ret
381
+
382
+ def forward(self, n, device):
383
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
384
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
385
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
386
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
387
+ values = self.relative_attention_bias(rp_bucket)
388
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
base/models/transformer_3d.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ from dataclasses import dataclass
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
22
+ from diffusers.utils import BaseOutput, deprecate
23
+ from diffusers.models.embeddings import PatchEmbed
24
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from einops import rearrange, repeat
27
+
28
+ try:
29
+ from attention import BasicTransformerBlock
30
+ except:
31
+ from .attention import BasicTransformerBlock
32
+
33
+ @dataclass
34
+ class Transformer3DModelOutput(BaseOutput):
35
+ """
36
+ The output of [`Transformer2DModel`].
37
+
38
+ Args:
39
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
40
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
41
+ distributions for the unnoised latent pixels.
42
+ """
43
+
44
+ sample: torch.FloatTensor
45
+
46
+
47
+ class Transformer3DModel(ModelMixin, ConfigMixin):
48
+ """
49
+ A 2D Transformer model for image-like data.
50
+
51
+ Parameters:
52
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
53
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
54
+ in_channels (`int`, *optional*):
55
+ The number of channels in the input and output (specify if the input is **continuous**).
56
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
57
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
58
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
59
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
60
+ This is fixed during training since it is used to learn a number of position embeddings.
61
+ num_vector_embeds (`int`, *optional*):
62
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
63
+ Includes the class for the masked latent pixel.
64
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
65
+ num_embeds_ada_norm ( `int`, *optional*):
66
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
67
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
68
+ added to the hidden states.
69
+
70
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
71
+ attention_bias (`bool`, *optional*):
72
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
73
+ """
74
+
75
+ @register_to_config
76
+ def __init__(
77
+ self,
78
+ num_attention_heads: int = 16,
79
+ attention_head_dim: int = 88,
80
+ in_channels: Optional[int] = None,
81
+ out_channels: Optional[int] = None,
82
+ num_layers: int = 1,
83
+ dropout: float = 0.0,
84
+ norm_num_groups: int = 32,
85
+ cross_attention_dim: Optional[int] = None,
86
+ attention_bias: bool = False,
87
+ sample_size: Optional[int] = None,
88
+ num_vector_embeds: Optional[int] = None,
89
+ patch_size: Optional[int] = None,
90
+ activation_fn: str = "geglu",
91
+ num_embeds_ada_norm: Optional[int] = None,
92
+ use_linear_projection: bool = False,
93
+ only_cross_attention: bool = False,
94
+ upcast_attention: bool = False,
95
+ norm_type: str = "layer_norm",
96
+ norm_elementwise_affine: bool = True,
97
+ rotary_emb=None,
98
+ ):
99
+ super().__init__()
100
+ self.use_linear_projection = use_linear_projection
101
+ self.num_attention_heads = num_attention_heads
102
+ self.attention_head_dim = attention_head_dim
103
+ inner_dim = num_attention_heads * attention_head_dim
104
+
105
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
106
+ # Define whether input is continuous or discrete depending on configuration
107
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
108
+ self.is_input_vectorized = num_vector_embeds is not None
109
+ self.is_input_patches = in_channels is not None and patch_size is not None
110
+
111
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
112
+ deprecation_message = (
113
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
114
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
115
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
116
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
117
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
118
+ )
119
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
120
+ norm_type = "ada_norm"
121
+
122
+ if self.is_input_continuous and self.is_input_vectorized:
123
+ raise ValueError(
124
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
125
+ " sure that either `in_channels` or `num_vector_embeds` is None."
126
+ )
127
+ elif self.is_input_vectorized and self.is_input_patches:
128
+ raise ValueError(
129
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
130
+ " sure that either `num_vector_embeds` or `num_patches` is None."
131
+ )
132
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
133
+ raise ValueError(
134
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
135
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
136
+ )
137
+
138
+ # 2. Define input layers
139
+ if self.is_input_continuous:
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
143
+ if use_linear_projection:
144
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
145
+ else:
146
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
147
+ elif self.is_input_vectorized:
148
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
149
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
150
+
151
+ self.height = sample_size
152
+ self.width = sample_size
153
+ self.num_vector_embeds = num_vector_embeds
154
+ self.num_latent_pixels = self.height * self.width
155
+
156
+ self.latent_image_embedding = ImagePositionalEmbeddings(
157
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
158
+ )
159
+ elif self.is_input_patches:
160
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
161
+
162
+ self.height = sample_size
163
+ self.width = sample_size
164
+
165
+ self.patch_size = patch_size
166
+ self.pos_embed = PatchEmbed(
167
+ height=sample_size,
168
+ width=sample_size,
169
+ patch_size=patch_size,
170
+ in_channels=in_channels,
171
+ embed_dim=inner_dim,
172
+ )
173
+
174
+ # 3. Define transformers blocks
175
+ self.transformer_blocks = nn.ModuleList(
176
+ [
177
+ BasicTransformerBlock(
178
+ inner_dim,
179
+ num_attention_heads,
180
+ attention_head_dim,
181
+ dropout=dropout,
182
+ cross_attention_dim=cross_attention_dim,
183
+ activation_fn=activation_fn,
184
+ num_embeds_ada_norm=num_embeds_ada_norm,
185
+ attention_bias=attention_bias,
186
+ only_cross_attention=only_cross_attention,
187
+ upcast_attention=upcast_attention,
188
+ norm_type=norm_type,
189
+ norm_elementwise_affine=norm_elementwise_affine,
190
+ rotary_emb=rotary_emb,
191
+ )
192
+ for d in range(num_layers)
193
+ ]
194
+ )
195
+
196
+ # 4. Define output layers
197
+ self.out_channels = in_channels if out_channels is None else out_channels
198
+ if self.is_input_continuous:
199
+ # TODO: should use out_channels for continuous projections
200
+ if use_linear_projection:
201
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
202
+ else:
203
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
204
+ elif self.is_input_vectorized:
205
+ self.norm_out = nn.LayerNorm(inner_dim)
206
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
207
+ elif self.is_input_patches:
208
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
209
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
210
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states: torch.Tensor,
215
+ encoder_hidden_states: Optional[torch.Tensor] = None,
216
+ timestep: Optional[torch.LongTensor] = None,
217
+ class_labels: Optional[torch.LongTensor] = None,
218
+ cross_attention_kwargs: Dict[str, Any] = None,
219
+ attention_mask: Optional[torch.Tensor] = None,
220
+ encoder_attention_mask: Optional[torch.Tensor] = None,
221
+ return_dict: bool = True,
222
+ use_image_num=None,
223
+ ):
224
+ """
225
+ The [`Transformer2DModel`] forward method.
226
+
227
+ Args:
228
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
229
+ Input `hidden_states`.
230
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
231
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
232
+ self-attention.
233
+ timestep ( `torch.LongTensor`, *optional*):
234
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
235
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
236
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
237
+ `AdaLayerZeroNorm`.
238
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
239
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
240
+
241
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
242
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
243
+
244
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
245
+ above. This bias will be added to the cross-attention scores.
246
+ return_dict (`bool`, *optional*, defaults to `True`):
247
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
248
+ tuple.
249
+
250
+ Returns:
251
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
252
+ `tuple` where the first element is the sample tensor.
253
+ """
254
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
255
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
256
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
257
+ # expects mask of shape:
258
+ # [batch, key_tokens]
259
+ # adds singleton query_tokens dimension:
260
+ # [batch, 1, key_tokens]
261
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
262
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
263
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
264
+ if attention_mask is not None and attention_mask.ndim == 2:
265
+ # assume that mask is expressed as:
266
+ # (1 = keep, 0 = discard)
267
+ # convert mask into a bias that can be added to attention scores:
268
+ # (keep = +0, discard = -10000.0)
269
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
270
+ attention_mask = attention_mask.unsqueeze(1)
271
+
272
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
273
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
274
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
275
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
276
+
277
+ # 1. Input
278
+ if self.is_input_continuous: # True
279
+
280
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
281
+ if self.training:
282
+ video_length = hidden_states.shape[2] - use_image_num
283
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
284
+ encoder_hidden_states_length = encoder_hidden_states.shape[1]
285
+ encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
286
+ encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
287
+ encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
288
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
289
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
290
+ else:
291
+ video_length = hidden_states.shape[2]
292
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
293
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
294
+
295
+ batch, _, height, width = hidden_states.shape
296
+ residual = hidden_states
297
+
298
+ hidden_states = self.norm(hidden_states)
299
+ if not self.use_linear_projection:
300
+ hidden_states = self.proj_in(hidden_states)
301
+ inner_dim = hidden_states.shape[1]
302
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
303
+ else:
304
+ inner_dim = hidden_states.shape[1]
305
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
306
+ hidden_states = self.proj_in(hidden_states)
307
+ elif self.is_input_vectorized:
308
+ hidden_states = self.latent_image_embedding(hidden_states)
309
+ elif self.is_input_patches:
310
+ hidden_states = self.pos_embed(hidden_states)
311
+
312
+ # 2. Blocks
313
+ for block in self.transformer_blocks:
314
+ hidden_states = block(
315
+ hidden_states,
316
+ attention_mask=attention_mask,
317
+ encoder_hidden_states=encoder_hidden_states,
318
+ encoder_attention_mask=encoder_attention_mask,
319
+ timestep=timestep,
320
+ cross_attention_kwargs=cross_attention_kwargs,
321
+ class_labels=class_labels,
322
+ video_length=video_length,
323
+ use_image_num=use_image_num,
324
+ )
325
+
326
+ # 3. Output
327
+ if self.is_input_continuous:
328
+ if not self.use_linear_projection:
329
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
330
+ hidden_states = self.proj_out(hidden_states)
331
+ else:
332
+ hidden_states = self.proj_out(hidden_states)
333
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
334
+
335
+ output = hidden_states + residual
336
+ elif self.is_input_vectorized:
337
+ hidden_states = self.norm_out(hidden_states)
338
+ logits = self.out(hidden_states)
339
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
340
+ logits = logits.permute(0, 2, 1)
341
+
342
+ # log(p(x_0))
343
+ output = F.log_softmax(logits.double(), dim=1).float()
344
+ elif self.is_input_patches:
345
+ # TODO: cleanup!
346
+ conditioning = self.transformer_blocks[0].norm1.emb(
347
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
348
+ )
349
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
350
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
351
+ hidden_states = self.proj_out_2(hidden_states)
352
+
353
+ # unpatchify
354
+ height = width = int(hidden_states.shape[1] ** 0.5)
355
+ hidden_states = hidden_states.reshape(
356
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
357
+ )
358
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
359
+ output = hidden_states.reshape(
360
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
361
+ )
362
+
363
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
364
+ if not return_dict:
365
+ return (output,)
366
+
367
+ return Transformer3DModelOutput(sample=output)
base/models/unet.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import sys
8
+ sys.path.append(os.path.split(sys.path[0])[0])
9
+
10
+ import math
11
+ import json
12
+ import torch
13
+ import einops
14
+ import torch.nn as nn
15
+ import torch.utils.checkpoint
16
+
17
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
18
+ from diffusers.utils import BaseOutput, logging
19
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
+
21
+ try:
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ except:
24
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
25
+
26
+ try:
27
+ from .unet_blocks import (
28
+ CrossAttnDownBlock3D,
29
+ CrossAttnUpBlock3D,
30
+ DownBlock3D,
31
+ UNetMidBlock3DCrossAttn,
32
+ UpBlock3D,
33
+ get_down_block,
34
+ get_up_block,
35
+ )
36
+ from .resnet import InflatedConv3d
37
+ except:
38
+ from unet_blocks import (
39
+ CrossAttnDownBlock3D,
40
+ CrossAttnUpBlock3D,
41
+ DownBlock3D,
42
+ UNetMidBlock3DCrossAttn,
43
+ UpBlock3D,
44
+ get_down_block,
45
+ get_up_block,
46
+ )
47
+ from resnet import InflatedConv3d
48
+
49
+ from rotary_embedding_torch import RotaryEmbedding
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ class RelativePositionBias(nn.Module):
54
+ def __init__(
55
+ self,
56
+ heads=8,
57
+ num_buckets=32,
58
+ max_distance=128,
59
+ ):
60
+ super().__init__()
61
+ self.num_buckets = num_buckets
62
+ self.max_distance = max_distance
63
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
64
+
65
+ @staticmethod
66
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
67
+ ret = 0
68
+ n = -relative_position
69
+
70
+ num_buckets //= 2
71
+ ret += (n < 0).long() * num_buckets
72
+ n = torch.abs(n)
73
+
74
+ max_exact = num_buckets // 2
75
+ is_small = n < max_exact
76
+
77
+ val_if_large = max_exact + (
78
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
79
+ ).long()
80
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
81
+
82
+ ret += torch.where(is_small, n, val_if_large)
83
+ return ret
84
+
85
+ def forward(self, n, device):
86
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
87
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
88
+ rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
89
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
90
+ values = self.relative_attention_bias(rp_bucket)
91
+ return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
92
+
93
+ @dataclass
94
+ class UNet3DConditionOutput(BaseOutput):
95
+ sample: torch.FloatTensor
96
+
97
+
98
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
99
+ _supports_gradient_checkpointing = True
100
+
101
+ @register_to_config
102
+ def __init__(
103
+ self,
104
+ sample_size: Optional[int] = None, # 64
105
+ in_channels: int = 4,
106
+ out_channels: int = 4,
107
+ center_input_sample: bool = False,
108
+ flip_sin_to_cos: bool = True,
109
+ freq_shift: int = 0,
110
+ down_block_types: Tuple[str] = (
111
+ "CrossAttnDownBlock3D",
112
+ "CrossAttnDownBlock3D",
113
+ "CrossAttnDownBlock3D",
114
+ "DownBlock3D",
115
+ ),
116
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
117
+ up_block_types: Tuple[str] = (
118
+ "UpBlock3D",
119
+ "CrossAttnUpBlock3D",
120
+ "CrossAttnUpBlock3D",
121
+ "CrossAttnUpBlock3D"
122
+ ),
123
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
124
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
125
+ layers_per_block: int = 2,
126
+ downsample_padding: int = 1,
127
+ mid_block_scale_factor: float = 1,
128
+ act_fn: str = "silu",
129
+ norm_num_groups: int = 32,
130
+ norm_eps: float = 1e-5,
131
+ cross_attention_dim: int = 1280,
132
+ attention_head_dim: Union[int, Tuple[int]] = 8,
133
+ dual_cross_attention: bool = False,
134
+ use_linear_projection: bool = False,
135
+ class_embed_type: Optional[str] = None,
136
+ num_class_embeds: Optional[int] = None,
137
+ upcast_attention: bool = False,
138
+ resnet_time_scale_shift: str = "default",
139
+ use_first_frame: bool = False,
140
+ use_relative_position: bool = False,
141
+ ):
142
+ super().__init__()
143
+
144
+ # print(use_first_frame)
145
+
146
+ self.sample_size = sample_size
147
+ time_embed_dim = block_out_channels[0] * 4
148
+
149
+ # input
150
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
151
+
152
+ # time
153
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
154
+ timestep_input_dim = block_out_channels[0]
155
+
156
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
157
+
158
+ # class embedding
159
+ if class_embed_type is None and num_class_embeds is not None:
160
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
161
+ elif class_embed_type == "timestep":
162
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
163
+ elif class_embed_type == "identity":
164
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
165
+ else:
166
+ self.class_embedding = None
167
+
168
+ self.down_blocks = nn.ModuleList([])
169
+ self.mid_block = None
170
+ self.up_blocks = nn.ModuleList([])
171
+
172
+ # print(only_cross_attention)
173
+ # print(type(only_cross_attention))
174
+ # exit()
175
+ if isinstance(only_cross_attention, bool):
176
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
177
+ # print(only_cross_attention)
178
+ # exit()
179
+
180
+ if isinstance(attention_head_dim, int):
181
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
182
+ # print(attention_head_dim)
183
+ # exit()
184
+
185
+ rotary_emb = RotaryEmbedding(32)
186
+
187
+ # down
188
+ output_channel = block_out_channels[0]
189
+ for i, down_block_type in enumerate(down_block_types):
190
+ input_channel = output_channel
191
+ output_channel = block_out_channels[i]
192
+ is_final_block = i == len(block_out_channels) - 1
193
+
194
+ down_block = get_down_block(
195
+ down_block_type,
196
+ num_layers=layers_per_block,
197
+ in_channels=input_channel,
198
+ out_channels=output_channel,
199
+ temb_channels=time_embed_dim,
200
+ add_downsample=not is_final_block,
201
+ resnet_eps=norm_eps,
202
+ resnet_act_fn=act_fn,
203
+ resnet_groups=norm_num_groups,
204
+ cross_attention_dim=cross_attention_dim,
205
+ attn_num_head_channels=attention_head_dim[i],
206
+ downsample_padding=downsample_padding,
207
+ dual_cross_attention=dual_cross_attention,
208
+ use_linear_projection=use_linear_projection,
209
+ only_cross_attention=only_cross_attention[i],
210
+ upcast_attention=upcast_attention,
211
+ resnet_time_scale_shift=resnet_time_scale_shift,
212
+ use_first_frame=use_first_frame,
213
+ use_relative_position=use_relative_position,
214
+ rotary_emb=rotary_emb,
215
+ )
216
+ self.down_blocks.append(down_block)
217
+
218
+ # mid
219
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
220
+ self.mid_block = UNetMidBlock3DCrossAttn(
221
+ in_channels=block_out_channels[-1],
222
+ temb_channels=time_embed_dim,
223
+ resnet_eps=norm_eps,
224
+ resnet_act_fn=act_fn,
225
+ output_scale_factor=mid_block_scale_factor,
226
+ resnet_time_scale_shift=resnet_time_scale_shift,
227
+ cross_attention_dim=cross_attention_dim,
228
+ attn_num_head_channels=attention_head_dim[-1],
229
+ resnet_groups=norm_num_groups,
230
+ dual_cross_attention=dual_cross_attention,
231
+ use_linear_projection=use_linear_projection,
232
+ upcast_attention=upcast_attention,
233
+ use_first_frame=use_first_frame,
234
+ use_relative_position=use_relative_position,
235
+ rotary_emb=rotary_emb,
236
+ )
237
+ else:
238
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
239
+
240
+ # count how many layers upsample the videos
241
+ self.num_upsamplers = 0
242
+
243
+ # up
244
+ reversed_block_out_channels = list(reversed(block_out_channels))
245
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
246
+ only_cross_attention = list(reversed(only_cross_attention))
247
+ output_channel = reversed_block_out_channels[0]
248
+ for i, up_block_type in enumerate(up_block_types):
249
+ is_final_block = i == len(block_out_channels) - 1
250
+
251
+ prev_output_channel = output_channel
252
+ output_channel = reversed_block_out_channels[i]
253
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
254
+
255
+ # add upsample block for all BUT final layer
256
+ if not is_final_block:
257
+ add_upsample = True
258
+ self.num_upsamplers += 1
259
+ else:
260
+ add_upsample = False
261
+
262
+ up_block = get_up_block(
263
+ up_block_type,
264
+ num_layers=layers_per_block + 1,
265
+ in_channels=input_channel,
266
+ out_channels=output_channel,
267
+ prev_output_channel=prev_output_channel,
268
+ temb_channels=time_embed_dim,
269
+ add_upsample=add_upsample,
270
+ resnet_eps=norm_eps,
271
+ resnet_act_fn=act_fn,
272
+ resnet_groups=norm_num_groups,
273
+ cross_attention_dim=cross_attention_dim,
274
+ attn_num_head_channels=reversed_attention_head_dim[i],
275
+ dual_cross_attention=dual_cross_attention,
276
+ use_linear_projection=use_linear_projection,
277
+ only_cross_attention=only_cross_attention[i],
278
+ upcast_attention=upcast_attention,
279
+ resnet_time_scale_shift=resnet_time_scale_shift,
280
+ use_first_frame=use_first_frame,
281
+ use_relative_position=use_relative_position,
282
+ rotary_emb=rotary_emb,
283
+ )
284
+ self.up_blocks.append(up_block)
285
+ prev_output_channel = output_channel
286
+
287
+ # out
288
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
289
+ self.conv_act = nn.SiLU()
290
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
291
+
292
+ # relative time positional embeddings
293
+ self.use_relative_position = use_relative_position
294
+ if self.use_relative_position:
295
+ self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet
296
+
297
+ def set_attention_slice(self, slice_size):
298
+ r"""
299
+ Enable sliced attention computation.
300
+
301
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
302
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
303
+
304
+ Args:
305
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
306
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
307
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
308
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
309
+ must be a multiple of `slice_size`.
310
+ """
311
+ sliceable_head_dims = []
312
+
313
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
314
+ if hasattr(module, "set_attention_slice"):
315
+ sliceable_head_dims.append(module.sliceable_head_dim)
316
+
317
+ for child in module.children():
318
+ fn_recursive_retrieve_slicable_dims(child)
319
+
320
+ # retrieve number of attention layers
321
+ for module in self.children():
322
+ fn_recursive_retrieve_slicable_dims(module)
323
+
324
+ num_slicable_layers = len(sliceable_head_dims)
325
+
326
+ if slice_size == "auto":
327
+ # half the attention head size is usually a good trade-off between
328
+ # speed and memory
329
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
330
+ elif slice_size == "max":
331
+ # make smallest slice possible
332
+ slice_size = num_slicable_layers * [1]
333
+
334
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
335
+
336
+ if len(slice_size) != len(sliceable_head_dims):
337
+ raise ValueError(
338
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
339
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
340
+ )
341
+
342
+ for i in range(len(slice_size)):
343
+ size = slice_size[i]
344
+ dim = sliceable_head_dims[i]
345
+ if size is not None and size > dim:
346
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
347
+
348
+ # Recursively walk through all the children.
349
+ # Any children which exposes the set_attention_slice method
350
+ # gets the message
351
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
352
+ if hasattr(module, "set_attention_slice"):
353
+ module.set_attention_slice(slice_size.pop())
354
+
355
+ for child in module.children():
356
+ fn_recursive_set_attention_slice(child, slice_size)
357
+
358
+ reversed_slice_size = list(reversed(slice_size))
359
+ for module in self.children():
360
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
361
+
362
+ def _set_gradient_checkpointing(self, module, value=False):
363
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
364
+ module.gradient_checkpointing = value
365
+
366
+ def forward(
367
+ self,
368
+ sample: torch.FloatTensor,
369
+ timestep: Union[torch.Tensor, float, int],
370
+ encoder_hidden_states: torch.Tensor = None,
371
+ class_labels: Optional[torch.Tensor] = None,
372
+ attention_mask: Optional[torch.Tensor] = None,
373
+ use_image_num: int = 0,
374
+ return_dict: bool = True,
375
+ ) -> Union[UNet3DConditionOutput, Tuple]:
376
+ r"""
377
+ Args:
378
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
379
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
380
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
381
+ return_dict (`bool`, *optional*, defaults to `True`):
382
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
383
+
384
+ Returns:
385
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
386
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
387
+ returning a tuple, the first element is the sample tensor.
388
+ """
389
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
390
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
391
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
392
+ # on the fly if necessary.
393
+ default_overall_up_factor = 2**self.num_upsamplers
394
+
395
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
396
+ forward_upsample_size = False
397
+ upsample_size = None
398
+
399
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
400
+ logger.info("Forward upsample size to force interpolation output size.")
401
+ forward_upsample_size = True
402
+
403
+ # prepare attention_mask
404
+ if attention_mask is not None:
405
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
406
+ attention_mask = attention_mask.unsqueeze(1)
407
+
408
+ # center input if necessary
409
+ if self.config.center_input_sample:
410
+ sample = 2 * sample - 1.0
411
+
412
+ # time
413
+ timesteps = timestep
414
+ if not torch.is_tensor(timesteps):
415
+ # This would be a good case for the `match` statement (Python 3.10+)
416
+ is_mps = sample.device.type == "mps"
417
+ if isinstance(timestep, float):
418
+ dtype = torch.float32 if is_mps else torch.float64
419
+ else:
420
+ dtype = torch.int32 if is_mps else torch.int64
421
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
422
+ elif len(timesteps.shape) == 0:
423
+ timesteps = timesteps[None].to(sample.device)
424
+
425
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
426
+ timesteps = timesteps.expand(sample.shape[0])
427
+
428
+ t_emb = self.time_proj(timesteps)
429
+
430
+ # timesteps does not contain any weights and will always return f32 tensors
431
+ # but time_embedding might actually be running in fp16. so we need to cast here.
432
+ # there might be better ways to encapsulate this.
433
+ t_emb = t_emb.to(dtype=self.dtype)
434
+ emb = self.time_embedding(t_emb)
435
+
436
+ if self.class_embedding is not None:
437
+ if class_labels is None:
438
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
439
+
440
+ if self.config.class_embed_type == "timestep":
441
+ class_labels = self.time_proj(class_labels)
442
+
443
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
444
+ # print(emb.shape) # torch.Size([3, 1280])
445
+ # print(class_emb.shape) # torch.Size([3, 1280])
446
+ emb = emb + class_emb
447
+
448
+ if self.use_relative_position:
449
+ frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device)
450
+ else:
451
+ frame_rel_pos_bias = None
452
+
453
+ # pre-process
454
+ sample = self.conv_in(sample)
455
+
456
+ # down
457
+ down_block_res_samples = (sample,)
458
+ for downsample_block in self.down_blocks:
459
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
460
+ sample, res_samples = downsample_block(
461
+ hidden_states=sample,
462
+ temb=emb,
463
+ encoder_hidden_states=encoder_hidden_states,
464
+ attention_mask=attention_mask,
465
+ use_image_num=use_image_num,
466
+ )
467
+ else:
468
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
469
+
470
+ down_block_res_samples += res_samples
471
+
472
+ # mid
473
+ sample = self.mid_block(
474
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num,
475
+ )
476
+
477
+ # up
478
+ for i, upsample_block in enumerate(self.up_blocks):
479
+ is_final_block = i == len(self.up_blocks) - 1
480
+
481
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
482
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
483
+
484
+ # if we have not reached the final block and need to forward the
485
+ # upsample size, we do it here
486
+ if not is_final_block and forward_upsample_size:
487
+ upsample_size = down_block_res_samples[-1].shape[2:]
488
+
489
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
490
+ sample = upsample_block(
491
+ hidden_states=sample,
492
+ temb=emb,
493
+ res_hidden_states_tuple=res_samples,
494
+ encoder_hidden_states=encoder_hidden_states,
495
+ upsample_size=upsample_size,
496
+ attention_mask=attention_mask,
497
+ use_image_num=use_image_num,
498
+ )
499
+ else:
500
+ sample = upsample_block(
501
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
502
+ )
503
+ # post-process
504
+ sample = self.conv_norm_out(sample)
505
+ sample = self.conv_act(sample)
506
+ sample = self.conv_out(sample)
507
+ # print(sample.shape)
508
+
509
+ if not return_dict:
510
+ return (sample,)
511
+ sample = UNet3DConditionOutput(sample=sample)
512
+ return sample
513
+
514
+ def forward_with_cfg(self,
515
+ x,
516
+ t,
517
+ encoder_hidden_states = None,
518
+ class_labels: Optional[torch.Tensor] = None,
519
+ cfg_scale=4.0,
520
+ use_fp16=False):
521
+ """
522
+ Forward, but also batches the unconditional forward pass for classifier-free guidance.
523
+ """
524
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
525
+ half = x[: len(x) // 2]
526
+ combined = torch.cat([half, half], dim=0)
527
+ if use_fp16:
528
+ combined = combined.to(dtype=torch.float16)
529
+ model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample
530
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
531
+ # three channels by default. The standard approach to cfg applies it to all channels.
532
+ # This can be done by uncommenting the following line and commenting-out the line following that.
533
+ eps, rest = model_out[:, :4], model_out[:, 4:]
534
+ # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w
535
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
536
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
537
+ eps = torch.cat([half_eps, half_eps], dim=0)
538
+ return torch.cat([eps, rest], dim=1)
539
+
540
+ @classmethod
541
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
542
+ if subfolder is not None:
543
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
544
+
545
+
546
+ config_file = os.path.join(pretrained_model_path, 'config.json')
547
+ if not os.path.isfile(config_file):
548
+ raise RuntimeError(f"{config_file} does not exist")
549
+ with open(config_file, "r") as f:
550
+ config = json.load(f)
551
+ config["_class_name"] = cls.__name__
552
+ config["down_block_types"] = [
553
+ "CrossAttnDownBlock3D",
554
+ "CrossAttnDownBlock3D",
555
+ "CrossAttnDownBlock3D",
556
+ "DownBlock3D"
557
+ ]
558
+ config["up_block_types"] = [
559
+ "UpBlock3D",
560
+ "CrossAttnUpBlock3D",
561
+ "CrossAttnUpBlock3D",
562
+ "CrossAttnUpBlock3D"
563
+ ]
564
+
565
+ config["use_first_frame"] = False
566
+
567
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
568
+
569
+
570
+ model = cls.from_config(config)
571
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
572
+ # if not os.path.isfile(model_file):
573
+ # raise RuntimeError(f"{model_file} does not exist")
574
+ # state_dict = torch.load(model_file, map_location="cpu")
575
+ # for k, v in model.state_dict().items():
576
+ # # print(k)
577
+ # if '_temp' in k:
578
+ # state_dict.update({k: v})
579
+ # if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
580
+ # k = k.replace('attn_fcross', 'attn1')
581
+ # state_dict.update({k: state_dict[k]})
582
+ # if 'norm_fcross' in k:
583
+ # k = k.replace('norm_fcross', 'norm1')
584
+ # state_dict.update({k: state_dict[k]})
585
+
586
+ # model.load_state_dict(state_dict)
587
+
588
+ return model
589
+
590
+ if __name__ == '__main__':
591
+ import torch
592
+ # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
593
+
594
+ device = "cuda" if torch.cuda.is_available() else "cpu"
595
+
596
+ pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-v1-4/" # p cluster
597
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
598
+ # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
599
+ unet.enable_xformers_memory_efficient_attention()
600
+ unet.enable_gradient_checkpointing()
601
+
602
+ unet.train()
603
+
604
+ use_image_num = 5
605
+ noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device)
606
+ bsz = noisy_latents.shape[0]
607
+ timesteps = torch.randint(0, 1000, (bsz,)).to(device)
608
+ timesteps = timesteps.long()
609
+ encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device)
610
+ # class_labels = torch.randn((bsz, )).to(device)
611
+
612
+
613
+ model_pred = unet(sample=noisy_latents, timestep=timesteps,
614
+ encoder_hidden_states=encoder_hidden_states,
615
+ class_labels=None,
616
+ use_image_num=use_image_num).sample
617
+ print(model_pred.shape)
base/models/unet_blocks.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ try:
10
+ from .attention import Transformer3DModel
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ except:
13
+ from attention import Transformer3DModel
14
+ from resnet import Downsample3D, ResnetBlock3D, Upsample3D
15
+
16
+
17
+ def get_down_block(
18
+ down_block_type,
19
+ num_layers,
20
+ in_channels,
21
+ out_channels,
22
+ temb_channels,
23
+ add_downsample,
24
+ resnet_eps,
25
+ resnet_act_fn,
26
+ attn_num_head_channels,
27
+ resnet_groups=None,
28
+ cross_attention_dim=None,
29
+ downsample_padding=None,
30
+ dual_cross_attention=False,
31
+ use_linear_projection=False,
32
+ only_cross_attention=False,
33
+ upcast_attention=False,
34
+ resnet_time_scale_shift="default",
35
+ use_first_frame=False,
36
+ use_relative_position=False,
37
+ rotary_emb=False,
38
+ ):
39
+ # print(down_block_type)
40
+ # print(use_first_frame)
41
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
42
+ if down_block_type == "DownBlock3D":
43
+ return DownBlock3D(
44
+ num_layers=num_layers,
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ temb_channels=temb_channels,
48
+ add_downsample=add_downsample,
49
+ resnet_eps=resnet_eps,
50
+ resnet_act_fn=resnet_act_fn,
51
+ resnet_groups=resnet_groups,
52
+ downsample_padding=downsample_padding,
53
+ resnet_time_scale_shift=resnet_time_scale_shift,
54
+ )
55
+ elif down_block_type == "CrossAttnDownBlock3D":
56
+ if cross_attention_dim is None:
57
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
58
+ return CrossAttnDownBlock3D(
59
+ num_layers=num_layers,
60
+ in_channels=in_channels,
61
+ out_channels=out_channels,
62
+ temb_channels=temb_channels,
63
+ add_downsample=add_downsample,
64
+ resnet_eps=resnet_eps,
65
+ resnet_act_fn=resnet_act_fn,
66
+ resnet_groups=resnet_groups,
67
+ downsample_padding=downsample_padding,
68
+ cross_attention_dim=cross_attention_dim,
69
+ attn_num_head_channels=attn_num_head_channels,
70
+ dual_cross_attention=dual_cross_attention,
71
+ use_linear_projection=use_linear_projection,
72
+ only_cross_attention=only_cross_attention,
73
+ upcast_attention=upcast_attention,
74
+ resnet_time_scale_shift=resnet_time_scale_shift,
75
+ use_first_frame=use_first_frame,
76
+ use_relative_position=use_relative_position,
77
+ rotary_emb=rotary_emb,
78
+ )
79
+ raise ValueError(f"{down_block_type} does not exist.")
80
+
81
+
82
+ def get_up_block(
83
+ up_block_type,
84
+ num_layers,
85
+ in_channels,
86
+ out_channels,
87
+ prev_output_channel,
88
+ temb_channels,
89
+ add_upsample,
90
+ resnet_eps,
91
+ resnet_act_fn,
92
+ attn_num_head_channels,
93
+ resnet_groups=None,
94
+ cross_attention_dim=None,
95
+ dual_cross_attention=False,
96
+ use_linear_projection=False,
97
+ only_cross_attention=False,
98
+ upcast_attention=False,
99
+ resnet_time_scale_shift="default",
100
+ use_first_frame=False,
101
+ use_relative_position=False,
102
+ rotary_emb=False,
103
+ ):
104
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
105
+ if up_block_type == "UpBlock3D":
106
+ return UpBlock3D(
107
+ num_layers=num_layers,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ prev_output_channel=prev_output_channel,
111
+ temb_channels=temb_channels,
112
+ add_upsample=add_upsample,
113
+ resnet_eps=resnet_eps,
114
+ resnet_act_fn=resnet_act_fn,
115
+ resnet_groups=resnet_groups,
116
+ resnet_time_scale_shift=resnet_time_scale_shift,
117
+ )
118
+ elif up_block_type == "CrossAttnUpBlock3D":
119
+ if cross_attention_dim is None:
120
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
121
+ return CrossAttnUpBlock3D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ prev_output_channel=prev_output_channel,
126
+ temb_channels=temb_channels,
127
+ add_upsample=add_upsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ cross_attention_dim=cross_attention_dim,
132
+ attn_num_head_channels=attn_num_head_channels,
133
+ dual_cross_attention=dual_cross_attention,
134
+ use_linear_projection=use_linear_projection,
135
+ only_cross_attention=only_cross_attention,
136
+ upcast_attention=upcast_attention,
137
+ resnet_time_scale_shift=resnet_time_scale_shift,
138
+ use_first_frame=use_first_frame,
139
+ use_relative_position=use_relative_position,
140
+ rotary_emb=rotary_emb,
141
+ )
142
+ raise ValueError(f"{up_block_type} does not exist.")
143
+
144
+
145
+ class UNetMidBlock3DCrossAttn(nn.Module):
146
+ def __init__(
147
+ self,
148
+ in_channels: int,
149
+ temb_channels: int,
150
+ dropout: float = 0.0,
151
+ num_layers: int = 1,
152
+ resnet_eps: float = 1e-6,
153
+ resnet_time_scale_shift: str = "default",
154
+ resnet_act_fn: str = "swish",
155
+ resnet_groups: int = 32,
156
+ resnet_pre_norm: bool = True,
157
+ attn_num_head_channels=1,
158
+ output_scale_factor=1.0,
159
+ cross_attention_dim=1280,
160
+ dual_cross_attention=False,
161
+ use_linear_projection=False,
162
+ upcast_attention=False,
163
+ use_first_frame=False,
164
+ use_relative_position=False,
165
+ rotary_emb=False,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.has_cross_attention = True
170
+ self.attn_num_head_channels = attn_num_head_channels
171
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
172
+
173
+ # there is always at least one resnet
174
+ resnets = [
175
+ ResnetBlock3D(
176
+ in_channels=in_channels,
177
+ out_channels=in_channels,
178
+ temb_channels=temb_channels,
179
+ eps=resnet_eps,
180
+ groups=resnet_groups,
181
+ dropout=dropout,
182
+ time_embedding_norm=resnet_time_scale_shift,
183
+ non_linearity=resnet_act_fn,
184
+ output_scale_factor=output_scale_factor,
185
+ pre_norm=resnet_pre_norm,
186
+ )
187
+ ]
188
+ attentions = []
189
+
190
+ for _ in range(num_layers):
191
+ if dual_cross_attention:
192
+ raise NotImplementedError
193
+ attentions.append(
194
+ Transformer3DModel(
195
+ attn_num_head_channels,
196
+ in_channels // attn_num_head_channels,
197
+ in_channels=in_channels,
198
+ num_layers=1,
199
+ cross_attention_dim=cross_attention_dim,
200
+ norm_num_groups=resnet_groups,
201
+ use_linear_projection=use_linear_projection,
202
+ upcast_attention=upcast_attention,
203
+ use_first_frame=use_first_frame,
204
+ use_relative_position=use_relative_position,
205
+ rotary_emb=rotary_emb,
206
+ )
207
+ )
208
+ resnets.append(
209
+ ResnetBlock3D(
210
+ in_channels=in_channels,
211
+ out_channels=in_channels,
212
+ temb_channels=temb_channels,
213
+ eps=resnet_eps,
214
+ groups=resnet_groups,
215
+ dropout=dropout,
216
+ time_embedding_norm=resnet_time_scale_shift,
217
+ non_linearity=resnet_act_fn,
218
+ output_scale_factor=output_scale_factor,
219
+ pre_norm=resnet_pre_norm,
220
+ )
221
+ )
222
+
223
+ self.attentions = nn.ModuleList(attentions)
224
+ self.resnets = nn.ModuleList(resnets)
225
+
226
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
227
+ hidden_states = self.resnets[0](hidden_states, temb)
228
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
229
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
230
+ hidden_states = resnet(hidden_states, temb)
231
+
232
+ return hidden_states
233
+
234
+
235
+ class CrossAttnDownBlock3D(nn.Module):
236
+ def __init__(
237
+ self,
238
+ in_channels: int,
239
+ out_channels: int,
240
+ temb_channels: int,
241
+ dropout: float = 0.0,
242
+ num_layers: int = 1,
243
+ resnet_eps: float = 1e-6,
244
+ resnet_time_scale_shift: str = "default",
245
+ resnet_act_fn: str = "swish",
246
+ resnet_groups: int = 32,
247
+ resnet_pre_norm: bool = True,
248
+ attn_num_head_channels=1,
249
+ cross_attention_dim=1280,
250
+ output_scale_factor=1.0,
251
+ downsample_padding=1,
252
+ add_downsample=True,
253
+ dual_cross_attention=False,
254
+ use_linear_projection=False,
255
+ only_cross_attention=False,
256
+ upcast_attention=False,
257
+ use_first_frame=False,
258
+ use_relative_position=False,
259
+ rotary_emb=False,
260
+ ):
261
+ super().__init__()
262
+ resnets = []
263
+ attentions = []
264
+
265
+ # print(use_first_frame)
266
+
267
+ self.has_cross_attention = True
268
+ self.attn_num_head_channels = attn_num_head_channels
269
+
270
+ for i in range(num_layers):
271
+ in_channels = in_channels if i == 0 else out_channels
272
+ resnets.append(
273
+ ResnetBlock3D(
274
+ in_channels=in_channels,
275
+ out_channels=out_channels,
276
+ temb_channels=temb_channels,
277
+ eps=resnet_eps,
278
+ groups=resnet_groups,
279
+ dropout=dropout,
280
+ time_embedding_norm=resnet_time_scale_shift,
281
+ non_linearity=resnet_act_fn,
282
+ output_scale_factor=output_scale_factor,
283
+ pre_norm=resnet_pre_norm,
284
+ )
285
+ )
286
+ if dual_cross_attention:
287
+ raise NotImplementedError
288
+ attentions.append(
289
+ Transformer3DModel(
290
+ attn_num_head_channels,
291
+ out_channels // attn_num_head_channels,
292
+ in_channels=out_channels,
293
+ num_layers=1,
294
+ cross_attention_dim=cross_attention_dim,
295
+ norm_num_groups=resnet_groups,
296
+ use_linear_projection=use_linear_projection,
297
+ only_cross_attention=only_cross_attention,
298
+ upcast_attention=upcast_attention,
299
+ use_first_frame=use_first_frame,
300
+ use_relative_position=use_relative_position,
301
+ rotary_emb=rotary_emb,
302
+ )
303
+ )
304
+ self.attentions = nn.ModuleList(attentions)
305
+ self.resnets = nn.ModuleList(resnets)
306
+
307
+ if add_downsample:
308
+ self.downsamplers = nn.ModuleList(
309
+ [
310
+ Downsample3D(
311
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
312
+ )
313
+ ]
314
+ )
315
+ else:
316
+ self.downsamplers = None
317
+
318
+ self.gradient_checkpointing = False
319
+
320
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
321
+ output_states = ()
322
+
323
+ for resnet, attn in zip(self.resnets, self.attentions):
324
+ if self.training and self.gradient_checkpointing:
325
+
326
+ def create_custom_forward(module, return_dict=None):
327
+ def custom_forward(*inputs):
328
+ if return_dict is not None:
329
+ return module(*inputs, return_dict=return_dict)
330
+ else:
331
+ return module(*inputs)
332
+
333
+ return custom_forward
334
+
335
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None):
336
+ def custom_forward(*inputs):
337
+ if return_dict is not None:
338
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num)
339
+ else:
340
+ return module(*inputs, use_image_num=use_image_num)
341
+
342
+ return custom_forward
343
+
344
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
345
+ hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num),
347
+ hidden_states,
348
+ encoder_hidden_states,
349
+ )[0]
350
+ else:
351
+ hidden_states = resnet(hidden_states, temb)
352
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
353
+
354
+ output_states += (hidden_states,)
355
+
356
+ if self.downsamplers is not None:
357
+ for downsampler in self.downsamplers:
358
+ hidden_states = downsampler(hidden_states)
359
+
360
+ output_states += (hidden_states,)
361
+
362
+ return hidden_states, output_states
363
+
364
+
365
+ class DownBlock3D(nn.Module):
366
+ def __init__(
367
+ self,
368
+ in_channels: int,
369
+ out_channels: int,
370
+ temb_channels: int,
371
+ dropout: float = 0.0,
372
+ num_layers: int = 1,
373
+ resnet_eps: float = 1e-6,
374
+ resnet_time_scale_shift: str = "default",
375
+ resnet_act_fn: str = "swish",
376
+ resnet_groups: int = 32,
377
+ resnet_pre_norm: bool = True,
378
+ output_scale_factor=1.0,
379
+ add_downsample=True,
380
+ downsample_padding=1,
381
+ ):
382
+ super().__init__()
383
+ resnets = []
384
+
385
+ for i in range(num_layers):
386
+ in_channels = in_channels if i == 0 else out_channels
387
+ resnets.append(
388
+ ResnetBlock3D(
389
+ in_channels=in_channels,
390
+ out_channels=out_channels,
391
+ temb_channels=temb_channels,
392
+ eps=resnet_eps,
393
+ groups=resnet_groups,
394
+ dropout=dropout,
395
+ time_embedding_norm=resnet_time_scale_shift,
396
+ non_linearity=resnet_act_fn,
397
+ output_scale_factor=output_scale_factor,
398
+ pre_norm=resnet_pre_norm,
399
+ )
400
+ )
401
+
402
+ self.resnets = nn.ModuleList(resnets)
403
+
404
+ if add_downsample:
405
+ self.downsamplers = nn.ModuleList(
406
+ [
407
+ Downsample3D(
408
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
409
+ )
410
+ ]
411
+ )
412
+ else:
413
+ self.downsamplers = None
414
+
415
+ self.gradient_checkpointing = False
416
+
417
+ def forward(self, hidden_states, temb=None):
418
+ output_states = ()
419
+
420
+ for resnet in self.resnets:
421
+ if self.training and self.gradient_checkpointing:
422
+
423
+ def create_custom_forward(module):
424
+ def custom_forward(*inputs):
425
+ return module(*inputs)
426
+
427
+ return custom_forward
428
+
429
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
430
+ else:
431
+ hidden_states = resnet(hidden_states, temb)
432
+
433
+ output_states += (hidden_states,)
434
+
435
+ if self.downsamplers is not None:
436
+ for downsampler in self.downsamplers:
437
+ hidden_states = downsampler(hidden_states)
438
+
439
+ output_states += (hidden_states,)
440
+
441
+ return hidden_states, output_states
442
+
443
+
444
+ class CrossAttnUpBlock3D(nn.Module):
445
+ def __init__(
446
+ self,
447
+ in_channels: int,
448
+ out_channels: int,
449
+ prev_output_channel: int,
450
+ temb_channels: int,
451
+ dropout: float = 0.0,
452
+ num_layers: int = 1,
453
+ resnet_eps: float = 1e-6,
454
+ resnet_time_scale_shift: str = "default",
455
+ resnet_act_fn: str = "swish",
456
+ resnet_groups: int = 32,
457
+ resnet_pre_norm: bool = True,
458
+ attn_num_head_channels=1,
459
+ cross_attention_dim=1280,
460
+ output_scale_factor=1.0,
461
+ add_upsample=True,
462
+ dual_cross_attention=False,
463
+ use_linear_projection=False,
464
+ only_cross_attention=False,
465
+ upcast_attention=False,
466
+ use_first_frame=False,
467
+ use_relative_position=False,
468
+ rotary_emb=False
469
+ ):
470
+ super().__init__()
471
+ resnets = []
472
+ attentions = []
473
+
474
+ self.has_cross_attention = True
475
+ self.attn_num_head_channels = attn_num_head_channels
476
+
477
+ for i in range(num_layers):
478
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
479
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
480
+
481
+ resnets.append(
482
+ ResnetBlock3D(
483
+ in_channels=resnet_in_channels + res_skip_channels,
484
+ out_channels=out_channels,
485
+ temb_channels=temb_channels,
486
+ eps=resnet_eps,
487
+ groups=resnet_groups,
488
+ dropout=dropout,
489
+ time_embedding_norm=resnet_time_scale_shift,
490
+ non_linearity=resnet_act_fn,
491
+ output_scale_factor=output_scale_factor,
492
+ pre_norm=resnet_pre_norm,
493
+ )
494
+ )
495
+ if dual_cross_attention:
496
+ raise NotImplementedError
497
+ attentions.append(
498
+ Transformer3DModel(
499
+ attn_num_head_channels,
500
+ out_channels // attn_num_head_channels,
501
+ in_channels=out_channels,
502
+ num_layers=1,
503
+ cross_attention_dim=cross_attention_dim,
504
+ norm_num_groups=resnet_groups,
505
+ use_linear_projection=use_linear_projection,
506
+ only_cross_attention=only_cross_attention,
507
+ upcast_attention=upcast_attention,
508
+ use_first_frame=use_first_frame,
509
+ use_relative_position=use_relative_position,
510
+ rotary_emb=rotary_emb,
511
+ )
512
+ )
513
+
514
+ self.attentions = nn.ModuleList(attentions)
515
+ self.resnets = nn.ModuleList(resnets)
516
+
517
+ if add_upsample:
518
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
519
+ else:
520
+ self.upsamplers = None
521
+
522
+ self.gradient_checkpointing = False
523
+
524
+ def forward(
525
+ self,
526
+ hidden_states,
527
+ res_hidden_states_tuple,
528
+ temb=None,
529
+ encoder_hidden_states=None,
530
+ upsample_size=None,
531
+ attention_mask=None,
532
+ use_image_num=None,
533
+ ):
534
+ for resnet, attn in zip(self.resnets, self.attentions):
535
+ # pop res hidden states
536
+ res_hidden_states = res_hidden_states_tuple[-1]
537
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
538
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
539
+
540
+ if self.training and self.gradient_checkpointing:
541
+
542
+ def create_custom_forward(module, return_dict=None):
543
+ def custom_forward(*inputs):
544
+ if return_dict is not None:
545
+ return module(*inputs, return_dict=return_dict)
546
+ else:
547
+ return module(*inputs)
548
+
549
+ return custom_forward
550
+
551
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None):
552
+ def custom_forward(*inputs):
553
+ if return_dict is not None:
554
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num)
555
+ else:
556
+ return module(*inputs, use_image_num=use_image_num)
557
+
558
+ return custom_forward
559
+
560
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
561
+ hidden_states = torch.utils.checkpoint.checkpoint(
562
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num),
563
+ hidden_states,
564
+ encoder_hidden_states,
565
+ )[0]
566
+ else:
567
+ hidden_states = resnet(hidden_states, temb)
568
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
569
+
570
+ if self.upsamplers is not None:
571
+ for upsampler in self.upsamplers:
572
+ hidden_states = upsampler(hidden_states, upsample_size)
573
+
574
+ return hidden_states
575
+
576
+
577
+ class UpBlock3D(nn.Module):
578
+ def __init__(
579
+ self,
580
+ in_channels: int,
581
+ prev_output_channel: int,
582
+ out_channels: int,
583
+ temb_channels: int,
584
+ dropout: float = 0.0,
585
+ num_layers: int = 1,
586
+ resnet_eps: float = 1e-6,
587
+ resnet_time_scale_shift: str = "default",
588
+ resnet_act_fn: str = "swish",
589
+ resnet_groups: int = 32,
590
+ resnet_pre_norm: bool = True,
591
+ output_scale_factor=1.0,
592
+ add_upsample=True,
593
+ ):
594
+ super().__init__()
595
+ resnets = []
596
+
597
+ for i in range(num_layers):
598
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
599
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
600
+
601
+ resnets.append(
602
+ ResnetBlock3D(
603
+ in_channels=resnet_in_channels + res_skip_channels,
604
+ out_channels=out_channels,
605
+ temb_channels=temb_channels,
606
+ eps=resnet_eps,
607
+ groups=resnet_groups,
608
+ dropout=dropout,
609
+ time_embedding_norm=resnet_time_scale_shift,
610
+ non_linearity=resnet_act_fn,
611
+ output_scale_factor=output_scale_factor,
612
+ pre_norm=resnet_pre_norm,
613
+ )
614
+ )
615
+
616
+ self.resnets = nn.ModuleList(resnets)
617
+
618
+ if add_upsample:
619
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
620
+ else:
621
+ self.upsamplers = None
622
+
623
+ self.gradient_checkpointing = False
624
+
625
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
626
+ for resnet in self.resnets:
627
+ # pop res hidden states
628
+ res_hidden_states = res_hidden_states_tuple[-1]
629
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
630
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
631
+
632
+ if self.training and self.gradient_checkpointing:
633
+
634
+ def create_custom_forward(module):
635
+ def custom_forward(*inputs):
636
+ return module(*inputs)
637
+
638
+ return custom_forward
639
+
640
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
641
+ else:
642
+ hidden_states = resnet(hidden_states, temb)
643
+
644
+ if self.upsamplers is not None:
645
+ for upsampler in self.upsamplers:
646
+ hidden_states = upsampler(hidden_states, upsample_size)
647
+
648
+ return hidden_states
base/models/utils.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+
15
+ import numpy as np
16
+ import torch.nn as nn
17
+
18
+ from einops import repeat
19
+
20
+
21
+ #################################################################################
22
+ # Unet Utils #
23
+ #################################################################################
24
+
25
+ def checkpoint(func, inputs, params, flag):
26
+ """
27
+ Evaluate a function without caching intermediate activations, allowing for
28
+ reduced memory at the expense of extra compute in the backward pass.
29
+ :param func: the function to evaluate.
30
+ :param inputs: the argument sequence to pass to `func`.
31
+ :param params: a sequence of parameters `func` depends on but does not
32
+ explicitly take as arguments.
33
+ :param flag: if False, disable gradient checkpointing.
34
+ """
35
+ if flag:
36
+ args = tuple(inputs) + tuple(params)
37
+ return CheckpointFunction.apply(func, len(inputs), *args)
38
+ else:
39
+ return func(*inputs)
40
+
41
+
42
+ class CheckpointFunction(torch.autograd.Function):
43
+ @staticmethod
44
+ def forward(ctx, run_function, length, *args):
45
+ ctx.run_function = run_function
46
+ ctx.input_tensors = list(args[:length])
47
+ ctx.input_params = list(args[length:])
48
+
49
+ with torch.no_grad():
50
+ output_tensors = ctx.run_function(*ctx.input_tensors)
51
+ return output_tensors
52
+
53
+ @staticmethod
54
+ def backward(ctx, *output_grads):
55
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
56
+ with torch.enable_grad():
57
+ # Fixes a bug where the first op in run_function modifies the
58
+ # Tensor storage in place, which is not allowed for detach()'d
59
+ # Tensors.
60
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
61
+ output_tensors = ctx.run_function(*shallow_copies)
62
+ input_grads = torch.autograd.grad(
63
+ output_tensors,
64
+ ctx.input_tensors + ctx.input_params,
65
+ output_grads,
66
+ allow_unused=True,
67
+ )
68
+ del ctx.input_tensors
69
+ del ctx.input_params
70
+ del output_tensors
71
+ return (None, None) + input_grads
72
+
73
+
74
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
75
+ """
76
+ Create sinusoidal timestep embeddings.
77
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
78
+ These may be fractional.
79
+ :param dim: the dimension of the output.
80
+ :param max_period: controls the minimum frequency of the embeddings.
81
+ :return: an [N x dim] Tensor of positional embeddings.
82
+ """
83
+ if not repeat_only:
84
+ half = dim // 2
85
+ freqs = torch.exp(
86
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
87
+ ).to(device=timesteps.device)
88
+ args = timesteps[:, None].float() * freqs[None]
89
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
90
+ if dim % 2:
91
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92
+ else:
93
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
94
+ return embedding
95
+
96
+
97
+ def zero_module(module):
98
+ """
99
+ Zero out the parameters of a module and return it.
100
+ """
101
+ for p in module.parameters():
102
+ p.detach().zero_()
103
+ return module
104
+
105
+
106
+ def scale_module(module, scale):
107
+ """
108
+ Scale the parameters of a module and return it.
109
+ """
110
+ for p in module.parameters():
111
+ p.detach().mul_(scale)
112
+ return module
113
+
114
+
115
+ def mean_flat(tensor):
116
+ """
117
+ Take the mean over all non-batch dimensions.
118
+ """
119
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
120
+
121
+
122
+ def normalization(channels):
123
+ """
124
+ Make a standard normalization layer.
125
+ :param channels: number of input channels.
126
+ :return: an nn.Module for normalization.
127
+ """
128
+ return GroupNorm32(32, channels)
129
+
130
+
131
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
132
+ class SiLU(nn.Module):
133
+ def forward(self, x):
134
+ return x * torch.sigmoid(x)
135
+
136
+
137
+ class GroupNorm32(nn.GroupNorm):
138
+ def forward(self, x):
139
+ return super().forward(x.float()).type(x.dtype)
140
+
141
+ def conv_nd(dims, *args, **kwargs):
142
+ """
143
+ Create a 1D, 2D, or 3D convolution module.
144
+ """
145
+ if dims == 1:
146
+ return nn.Conv1d(*args, **kwargs)
147
+ elif dims == 2:
148
+ return nn.Conv2d(*args, **kwargs)
149
+ elif dims == 3:
150
+ return nn.Conv3d(*args, **kwargs)
151
+ raise ValueError(f"unsupported dimensions: {dims}")
152
+
153
+
154
+ def linear(*args, **kwargs):
155
+ """
156
+ Create a linear module.
157
+ """
158
+ return nn.Linear(*args, **kwargs)
159
+
160
+
161
+ def avg_pool_nd(dims, *args, **kwargs):
162
+ """
163
+ Create a 1D, 2D, or 3D average pooling module.
164
+ """
165
+ if dims == 1:
166
+ return nn.AvgPool1d(*args, **kwargs)
167
+ elif dims == 2:
168
+ return nn.AvgPool2d(*args, **kwargs)
169
+ elif dims == 3:
170
+ return nn.AvgPool3d(*args, **kwargs)
171
+ raise ValueError(f"unsupported dimensions: {dims}")
172
+
173
+
174
+ # class HybridConditioner(nn.Module):
175
+
176
+ # def __init__(self, c_concat_config, c_crossattn_config):
177
+ # super().__init__()
178
+ # self.concat_conditioner = instantiate_from_config(c_concat_config)
179
+ # self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
180
+
181
+ # def forward(self, c_concat, c_crossattn):
182
+ # c_concat = self.concat_conditioner(c_concat)
183
+ # c_crossattn = self.crossattn_conditioner(c_crossattn)
184
+ # return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
185
+
186
+
187
+ def noise_like(shape, device, repeat=False):
188
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
189
+ noise = lambda: torch.randn(shape, device=device)
190
+ return repeat_noise() if repeat else noise()
191
+
192
+ def count_flops_attn(model, _x, y):
193
+ """
194
+ A counter for the `thop` package to count the operations in an
195
+ attention operation.
196
+ Meant to be used like:
197
+ macs, params = thop.profile(
198
+ model,
199
+ inputs=(inputs, timestamps),
200
+ custom_ops={QKVAttention: QKVAttention.count_flops},
201
+ )
202
+ """
203
+ b, c, *spatial = y[0].shape
204
+ num_spatial = int(np.prod(spatial))
205
+ # We perform two matmuls with the same number of ops.
206
+ # The first computes the weight matrix, the second computes
207
+ # the combination of the value vectors.
208
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
209
+ model.total_ops += torch.DoubleTensor([matmul_ops])
210
+
211
+ def count_params(model, verbose=False):
212
+ total_params = sum(p.numel() for p in model.parameters())
213
+ if verbose:
214
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
215
+ return total_params
base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc ADDED
Binary file (34.9 kB). View file
 
base/pipelines/pipeline_videogen.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ import inspect
15
+ from typing import Any, Callable, Dict, List, Optional, Union
16
+ import einops
17
+ import torch
18
+ from packaging import version
19
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
20
+
21
+ from diffusers.configuration_utils import FrozenDict
22
+ from diffusers.models import AutoencoderKL
23
+ from diffusers.schedulers import KarrasDiffusionSchedulers
24
+ from diffusers.utils import (
25
+ deprecate,
26
+ is_accelerate_available,
27
+ is_accelerate_version,
28
+ logging,
29
+ #randn_tensor,
30
+ replace_example_docstring,
31
+ BaseOutput,
32
+ )
33
+
34
+ try:
35
+ from diffusers.utils import randn_tensor
36
+ except:
37
+ from diffusers.utils.torch_utils import randn_tensor
38
+
39
+
40
+ from diffusers.pipeline_utils import DiffusionPipeline
41
+ from dataclasses import dataclass
42
+
43
+ import os, sys
44
+ sys.path.append(os.path.split(sys.path[0])[0])
45
+ from models.unet import UNet3DConditionModel
46
+
47
+ import numpy as np
48
+
49
+ @dataclass
50
+ class StableDiffusionPipelineOutput(BaseOutput):
51
+ video: torch.Tensor
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+ EXAMPLE_DOC_STRING = """
56
+ Examples:
57
+ ```py
58
+ >>> import torch
59
+ >>> from diffusers import StableDiffusionPipeline
60
+
61
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
62
+ >>> pipe = pipe.to("cuda")
63
+
64
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
65
+ >>> image = pipe(prompt).images[0]
66
+ ```
67
+ """
68
+
69
+
70
+ class VideoGenPipeline(DiffusionPipeline):
71
+ r"""
72
+ Pipeline for text-to-image generation using Stable Diffusion.
73
+
74
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
75
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
76
+
77
+ Args:
78
+ vae ([`AutoencoderKL`]):
79
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
80
+ text_encoder ([`CLIPTextModel`]):
81
+ Frozen text-encoder. Stable Diffusion uses the text portion of
82
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
83
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
84
+ tokenizer (`CLIPTokenizer`):
85
+ Tokenizer of class
86
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
87
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
88
+ scheduler ([`SchedulerMixin`]):
89
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
90
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
91
+ safety_checker ([`StableDiffusionSafetyChecker`]):
92
+ Classification module that estimates whether generated images could be considered offensive or harmful.
93
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
94
+ feature_extractor ([`CLIPFeatureExtractor`]):
95
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
96
+ """
97
+ _optional_components = ["safety_checker", "feature_extractor"]
98
+
99
+ def __init__(
100
+ self,
101
+ vae: AutoencoderKL,
102
+ text_encoder: CLIPTextModel,
103
+ tokenizer: CLIPTokenizer,
104
+ unet: UNet3DConditionModel,
105
+ scheduler: KarrasDiffusionSchedulers,
106
+ ):
107
+ super().__init__()
108
+
109
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
110
+ deprecation_message = (
111
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
112
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
113
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
114
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
115
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
116
+ " file"
117
+ )
118
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
119
+ new_config = dict(scheduler.config)
120
+ new_config["steps_offset"] = 1
121
+ scheduler._internal_dict = FrozenDict(new_config)
122
+
123
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
124
+ deprecation_message = (
125
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
126
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
127
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
128
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
129
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
130
+ )
131
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
132
+ new_config = dict(scheduler.config)
133
+ new_config["clip_sample"] = False
134
+ scheduler._internal_dict = FrozenDict(new_config)
135
+
136
+
137
+
138
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
139
+ version.parse(unet.config._diffusers_version).base_version
140
+ ) < version.parse("0.9.0.dev0")
141
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
142
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
143
+ deprecation_message = (
144
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
145
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
146
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
147
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
148
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
149
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
150
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
151
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
152
+ " the `unet/config.json` file"
153
+ )
154
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
155
+ new_config = dict(unet.config)
156
+ new_config["sample_size"] = 64
157
+ unet._internal_dict = FrozenDict(new_config)
158
+
159
+ self.register_modules(
160
+ vae=vae,
161
+ text_encoder=text_encoder,
162
+ tokenizer=tokenizer,
163
+ unet=unet,
164
+ scheduler=scheduler,
165
+ )
166
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
167
+ # self.register_to_config(requires_safety_checker=requires_safety_checker)
168
+
169
+ def enable_vae_slicing(self):
170
+ r"""
171
+ Enable sliced VAE decoding.
172
+
173
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
174
+ steps. This is useful to save some memory and allow larger batch sizes.
175
+ """
176
+ self.vae.enable_slicing()
177
+
178
+ def disable_vae_slicing(self):
179
+ r"""
180
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
181
+ computing decoding in one step.
182
+ """
183
+ self.vae.disable_slicing()
184
+
185
+ def enable_vae_tiling(self):
186
+ r"""
187
+ Enable tiled VAE decoding.
188
+
189
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
190
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
191
+ """
192
+ self.vae.enable_tiling()
193
+
194
+ def disable_vae_tiling(self):
195
+ r"""
196
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
197
+ computing decoding in one step.
198
+ """
199
+ self.vae.disable_tiling()
200
+
201
+ def enable_sequential_cpu_offload(self, gpu_id=0):
202
+ r"""
203
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
204
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
205
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
206
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
207
+ `enable_model_cpu_offload`, but performance is lower.
208
+ """
209
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
210
+ from accelerate import cpu_offload
211
+ else:
212
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
213
+
214
+ device = torch.device(f"cuda:{gpu_id}")
215
+
216
+ if self.device.type != "cpu":
217
+ self.to("cpu", silence_dtype_warnings=True)
218
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
219
+
220
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
221
+ cpu_offload(cpu_offloaded_model, device)
222
+
223
+ # if self.safety_checker is not None:
224
+ # cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
225
+
226
+ def enable_model_cpu_offload(self, gpu_id=0):
227
+ r"""
228
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
229
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
230
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
231
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
232
+ """
233
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
234
+ from accelerate import cpu_offload_with_hook
235
+ else:
236
+ raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
237
+
238
+ device = torch.device(f"cuda:{gpu_id}")
239
+
240
+ if self.device.type != "cpu":
241
+ self.to("cpu", silence_dtype_warnings=True)
242
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
243
+
244
+ hook = None
245
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
246
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
247
+
248
+ self.final_offload_hook = hook
249
+
250
+ @property
251
+ def _execution_device(self):
252
+ r"""
253
+ Returns the device on which the pipeline's models will be executed. After calling
254
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
255
+ hooks.
256
+ """
257
+ if not hasattr(self.unet, "_hf_hook"):
258
+ return self.device
259
+ for module in self.unet.modules():
260
+ if (
261
+ hasattr(module, "_hf_hook")
262
+ and hasattr(module._hf_hook, "execution_device")
263
+ and module._hf_hook.execution_device is not None
264
+ ):
265
+ return torch.device(module._hf_hook.execution_device)
266
+ return self.device
267
+
268
+ def _encode_prompt(
269
+ self,
270
+ prompt,
271
+ device,
272
+ num_images_per_prompt,
273
+ do_classifier_free_guidance,
274
+ negative_prompt=None,
275
+ prompt_embeds: Optional[torch.FloatTensor] = None,
276
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
277
+ ):
278
+ r"""
279
+ Encodes the prompt into text encoder hidden states.
280
+
281
+ Args:
282
+ prompt (`str` or `List[str]`, *optional*):
283
+ prompt to be encoded
284
+ device: (`torch.device`):
285
+ torch device
286
+ num_images_per_prompt (`int`):
287
+ number of images that should be generated per prompt
288
+ do_classifier_free_guidance (`bool`):
289
+ whether to use classifier free guidance or not
290
+ negative_prompt (`str` or `List[str]`, *optional*):
291
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
292
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
293
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
294
+ prompt_embeds (`torch.FloatTensor`, *optional*):
295
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
296
+ provided, text embeddings will be generated from `prompt` input argument.
297
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
298
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
299
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
300
+ argument.
301
+ """
302
+ if prompt is not None and isinstance(prompt, str):
303
+ batch_size = 1
304
+ elif prompt is not None and isinstance(prompt, list):
305
+ batch_size = len(prompt)
306
+ else:
307
+ batch_size = prompt_embeds.shape[0]
308
+
309
+ if prompt_embeds is None:
310
+ text_inputs = self.tokenizer(
311
+ prompt,
312
+ padding="max_length",
313
+ max_length=self.tokenizer.model_max_length,
314
+ truncation=True,
315
+ return_tensors="pt",
316
+ )
317
+ text_input_ids = text_inputs.input_ids
318
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
319
+
320
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
321
+ text_input_ids, untruncated_ids
322
+ ):
323
+ removed_text = self.tokenizer.batch_decode(
324
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
325
+ )
326
+ logger.warning(
327
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
328
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
329
+ )
330
+
331
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
332
+ attention_mask = text_inputs.attention_mask.to(device)
333
+ else:
334
+ attention_mask = None
335
+
336
+ prompt_embeds = self.text_encoder(
337
+ text_input_ids.to(device),
338
+ attention_mask=attention_mask,
339
+ )
340
+ prompt_embeds = prompt_embeds[0]
341
+
342
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
343
+
344
+ bs_embed, seq_len, _ = prompt_embeds.shape
345
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
346
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
347
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
348
+
349
+ # get unconditional embeddings for classifier free guidance
350
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
351
+ uncond_tokens: List[str]
352
+ if negative_prompt is None:
353
+ uncond_tokens = [""] * batch_size
354
+ elif type(prompt) is not type(negative_prompt):
355
+ raise TypeError(
356
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
357
+ f" {type(prompt)}."
358
+ )
359
+ elif isinstance(negative_prompt, str):
360
+ uncond_tokens = [negative_prompt]
361
+ elif batch_size != len(negative_prompt):
362
+ raise ValueError(
363
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
364
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
365
+ " the batch size of `prompt`."
366
+ )
367
+ else:
368
+ uncond_tokens = negative_prompt
369
+
370
+ max_length = prompt_embeds.shape[1]
371
+ uncond_input = self.tokenizer(
372
+ uncond_tokens,
373
+ padding="max_length",
374
+ max_length=max_length,
375
+ truncation=True,
376
+ return_tensors="pt",
377
+ )
378
+
379
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
380
+ attention_mask = uncond_input.attention_mask.to(device)
381
+ else:
382
+ attention_mask = None
383
+
384
+ negative_prompt_embeds = self.text_encoder(
385
+ uncond_input.input_ids.to(device),
386
+ attention_mask=attention_mask,
387
+ )
388
+ negative_prompt_embeds = negative_prompt_embeds[0]
389
+
390
+ if do_classifier_free_guidance:
391
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
392
+ seq_len = negative_prompt_embeds.shape[1]
393
+
394
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
395
+
396
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
397
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
398
+
399
+ # For classifier free guidance, we need to do two forward passes.
400
+ # Here we concatenate the unconditional and text embeddings into a single batch
401
+ # to avoid doing two forward passes
402
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
403
+
404
+ return prompt_embeds
405
+
406
+ def decode_latents(self, latents):
407
+ video_length = latents.shape[2]
408
+ latents = 1 / 0.18215 * latents
409
+ latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
410
+ video = self.vae.decode(latents).sample
411
+ video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
412
+ video = ((video / 2 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous()
413
+ return video
414
+
415
+ def prepare_extra_step_kwargs(self, generator, eta):
416
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
417
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
418
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
419
+ # and should be between [0, 1]
420
+
421
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
422
+ extra_step_kwargs = {}
423
+ if accepts_eta:
424
+ extra_step_kwargs["eta"] = eta
425
+
426
+ # check if the scheduler accepts generator
427
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
428
+ if accepts_generator:
429
+ extra_step_kwargs["generator"] = generator
430
+ return extra_step_kwargs
431
+
432
+ def check_inputs(
433
+ self,
434
+ prompt,
435
+ height,
436
+ width,
437
+ callback_steps,
438
+ negative_prompt=None,
439
+ prompt_embeds=None,
440
+ negative_prompt_embeds=None,
441
+ ):
442
+ if height % 8 != 0 or width % 8 != 0:
443
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
444
+
445
+ if (callback_steps is None) or (
446
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
447
+ ):
448
+ raise ValueError(
449
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
450
+ f" {type(callback_steps)}."
451
+ )
452
+
453
+ if prompt is not None and prompt_embeds is not None:
454
+ raise ValueError(
455
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
456
+ " only forward one of the two."
457
+ )
458
+ elif prompt is None and prompt_embeds is None:
459
+ raise ValueError(
460
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
461
+ )
462
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
463
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
464
+
465
+ if negative_prompt is not None and negative_prompt_embeds is not None:
466
+ raise ValueError(
467
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
468
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
469
+ )
470
+
471
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
472
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
473
+ raise ValueError(
474
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
475
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
476
+ f" {negative_prompt_embeds.shape}."
477
+ )
478
+
479
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
480
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
481
+ if isinstance(generator, list) and len(generator) != batch_size:
482
+ raise ValueError(
483
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
484
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
485
+ )
486
+
487
+ if latents is None:
488
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
489
+ else:
490
+ latents = latents.to(device)
491
+
492
+ # scale the initial noise by the standard deviation required by the scheduler
493
+ latents = latents * self.scheduler.init_noise_sigma
494
+ return latents
495
+
496
+ @torch.no_grad()
497
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
498
+ def __call__(
499
+ self,
500
+ prompt: Union[str, List[str]] = None,
501
+ height: Optional[int] = None,
502
+ width: Optional[int] = None,
503
+ video_length: int = 16,
504
+ num_inference_steps: int = 50,
505
+ guidance_scale: float = 7.5,
506
+ negative_prompt: Optional[Union[str, List[str]]] = None,
507
+ num_images_per_prompt: Optional[int] = 1,
508
+ eta: float = 0.0,
509
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
510
+ latents: Optional[torch.FloatTensor] = None,
511
+ prompt_embeds: Optional[torch.FloatTensor] = None,
512
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
513
+ output_type: Optional[str] = "pil",
514
+ return_dict: bool = True,
515
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
516
+ callback_steps: int = 1,
517
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
518
+ ):
519
+ r"""
520
+ Function invoked when calling the pipeline for generation.
521
+
522
+ Args:
523
+ prompt (`str` or `List[str]`, *optional*):
524
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
525
+ instead.
526
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
527
+ The height in pixels of the generated image.
528
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
529
+ The width in pixels of the generated image.
530
+ num_inference_steps (`int`, *optional*, defaults to 50):
531
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
532
+ expense of slower inference.
533
+ guidance_scale (`float`, *optional*, defaults to 7.5):
534
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
535
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
536
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
537
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
538
+ usually at the expense of lower image quality.
539
+ negative_prompt (`str` or `List[str]`, *optional*):
540
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
541
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
542
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
543
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
544
+ The number of images to generate per prompt.
545
+ eta (`float`, *optional*, defaults to 0.0):
546
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
547
+ [`schedulers.DDIMScheduler`], will be ignored for others.
548
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
549
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
550
+ to make generation deterministic.
551
+ latents (`torch.FloatTensor`, *optional*):
552
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
553
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
554
+ tensor will ge generated by sampling using the supplied random `generator`.
555
+ prompt_embeds (`torch.FloatTensor`, *optional*):
556
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
557
+ provided, text embeddings will be generated from `prompt` input argument.
558
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
559
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
560
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
561
+ argument.
562
+ output_type (`str`, *optional*, defaults to `"pil"`):
563
+ The output format of the generate image. Choose between
564
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
565
+ return_dict (`bool`, *optional*, defaults to `True`):
566
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
567
+ plain tuple.
568
+ callback (`Callable`, *optional*):
569
+ A function that will be called every `callback_steps` steps during inference. The function will be
570
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
571
+ callback_steps (`int`, *optional*, defaults to 1):
572
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
573
+ called at every step.
574
+ cross_attention_kwargs (`dict`, *optional*):
575
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
576
+ `self.processor` in
577
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
578
+
579
+ Examples:
580
+
581
+ Returns:
582
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
583
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
584
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
585
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
586
+ (nsfw) content, according to the `safety_checker`.
587
+ """
588
+ # 0. Default height and width to unet
589
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
590
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
591
+
592
+ # 1. Check inputs. Raise error if not correct
593
+ self.check_inputs(
594
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
595
+ )
596
+
597
+ # 2. Define call parameters
598
+ if prompt is not None and isinstance(prompt, str):
599
+ batch_size = 1
600
+ elif prompt is not None and isinstance(prompt, list):
601
+ batch_size = len(prompt)
602
+ else:
603
+ batch_size = prompt_embeds.shape[0]
604
+
605
+ device = self._execution_device
606
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
607
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
608
+ # corresponds to doing no classifier free guidance.
609
+ do_classifier_free_guidance = guidance_scale > 1.0
610
+
611
+ # 3. Encode input prompt
612
+ prompt_embeds = self._encode_prompt(
613
+ prompt,
614
+ device,
615
+ num_images_per_prompt,
616
+ do_classifier_free_guidance,
617
+ negative_prompt,
618
+ prompt_embeds=prompt_embeds,
619
+ negative_prompt_embeds=negative_prompt_embeds,
620
+ )
621
+
622
+ # 4. Prepare timesteps
623
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
624
+ timesteps = self.scheduler.timesteps
625
+
626
+ # 5. Prepare latent variables
627
+ num_channels_latents = self.unet.config.in_channels
628
+ latents = self.prepare_latents(
629
+ batch_size * num_images_per_prompt,
630
+ num_channels_latents,
631
+ video_length,
632
+ height,
633
+ width,
634
+ prompt_embeds.dtype,
635
+ device,
636
+ generator,
637
+ latents,
638
+ )
639
+
640
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
641
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
642
+
643
+ # 7. Denoising loop
644
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
645
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
646
+ for i, t in enumerate(timesteps):
647
+ # expand the latents if we are doing classifier free guidance
648
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
649
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
650
+
651
+ # predict the noise residual
652
+ noise_pred = self.unet(
653
+ latent_model_input,
654
+ t,
655
+ encoder_hidden_states=prompt_embeds,
656
+ # cross_attention_kwargs=cross_attention_kwargs,
657
+ ).sample
658
+
659
+ # perform guidance
660
+ if do_classifier_free_guidance:
661
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
662
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
663
+
664
+ # compute the previous noisy sample x_t -> x_t-1
665
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
666
+
667
+ # call the callback, if provided
668
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
669
+ progress_bar.update()
670
+ if callback is not None and i % callback_steps == 0:
671
+ callback(i, t, latents)
672
+
673
+
674
+ # 8. Post-processing
675
+ video = self.decode_latents(latents)
676
+
677
+ return StableDiffusionPipelineOutput(video=video)
base/pipelines/sample.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import torchvision
5
+
6
+ from pipeline_videogen import VideoGenPipeline
7
+
8
+ from download import find_model
9
+ from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
10
+ from diffusers.models import AutoencoderKL
11
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
12
+ from omegaconf import OmegaConf
13
+
14
+ import os, sys
15
+ sys.path.append(os.path.split(sys.path[0])[0])
16
+ from models import get_models
17
+ import imageio
18
+
19
+ def main(args):
20
+ #torch.manual_seed(args.seed)
21
+ torch.set_grad_enabled(False)
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+
24
+ sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
25
+ unet = get_models(args, sd_path).to(device, dtype=torch.float16)
26
+ state_dict = find_model(args.pretrained_path + "/lavie_base.pt")
27
+ unet.load_state_dict(state_dict)
28
+
29
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
30
+ tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
31
+ text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
32
+
33
+ # set eval mode
34
+ unet.eval()
35
+ vae.eval()
36
+ text_encoder_one.eval()
37
+
38
+ if args.sample_method == 'ddim':
39
+ scheduler = DDIMScheduler.from_pretrained(sd_path,
40
+ subfolder="scheduler",
41
+ beta_start=args.beta_start,
42
+ beta_end=args.beta_end,
43
+ beta_schedule=args.beta_schedule)
44
+ elif args.sample_method == 'eulerdiscrete':
45
+ scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
46
+ subfolder="scheduler",
47
+ beta_start=args.beta_start,
48
+ beta_end=args.beta_end,
49
+ beta_schedule=args.beta_schedule)
50
+ elif args.sample_method == 'ddpm':
51
+ scheduler = DDPMScheduler.from_pretrained(sd_path,
52
+ subfolder="scheduler",
53
+ beta_start=args.beta_start,
54
+ beta_end=args.beta_end,
55
+ beta_schedule=args.beta_schedule)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ videogen_pipeline = VideoGenPipeline(vae=vae,
60
+ text_encoder=text_encoder_one,
61
+ tokenizer=tokenizer_one,
62
+ scheduler=scheduler,
63
+ unet=unet).to(device)
64
+ videogen_pipeline.enable_xformers_memory_efficient_attention()
65
+
66
+ if not os.path.exists(args.output_folder):
67
+ os.makedirs(args.output_folder)
68
+
69
+ video_grids = []
70
+ for prompt in args.text_prompt:
71
+ print('Processing the ({}) prompt'.format(prompt))
72
+ videos = videogen_pipeline(prompt,
73
+ video_length=args.video_length,
74
+ height=args.image_size[0],
75
+ width=args.image_size[1],
76
+ num_inference_steps=args.num_sampling_steps,
77
+ guidance_scale=args.guidance_scale).video
78
+ imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0
79
+
80
+ print('save path {}'.format(args.output_folder))
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument("--config", type=str, default="")
85
+ args = parser.parse_args()
86
+
87
+ main(OmegaConf.load(args.config))
88
+
base/pipelines/sample.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=6
2
+ python pipelines/sample.py --config configs/sample.yaml
base/text_to_video/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import torchvision
5
+
6
+ from pipelines.pipeline_videogen import VideoGenPipeline
7
+
8
+ from download import find_model
9
+ from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
10
+ from diffusers.models import AutoencoderKL
11
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
12
+ from omegaconf import OmegaConf
13
+
14
+ import os, sys
15
+ sys.path.append(os.path.split(sys.path[0])[0])
16
+ from models import get_models
17
+ import imageio
18
+
19
+ config_path = "./base/configs/sample.yaml"
20
+ args = OmegaConf.load("./base/configs/sample.yaml")
21
+
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+
24
+ def model_t2v_fun(args):
25
+ # sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
26
+ sd_path = args.pretrained_path
27
+ unet = get_models(args, sd_path).to(device, dtype=torch.float16)
28
+ state_dict = find_model("./pretrained_models/lavie_base.pt")
29
+ # state_dict = find_model("./pretrained_models/lavie_base.pt")
30
+ unet.load_state_dict(state_dict)
31
+
32
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
33
+ tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
34
+ text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
35
+ unet.eval()
36
+ vae.eval()
37
+ text_encoder_one.eval()
38
+ scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler", beta_start=args.beta_start, beta_end=args.beta_end, beta_schedule=args.beta_schedule)
39
+ return VideoGenPipeline(vae=vae, text_encoder=text_encoder_one, tokenizer=tokenizer_one, scheduler=scheduler, unet=unet)
40
+
41
+ def setup_seed(seed):
42
+ torch.manual_seed(seed)
43
+ torch.cuda.manual_seed_all(seed)
44
+
45
+
base/text_to_video/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.38 kB). View file
 
base/try.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ with gr.Blocks() as demo:
4
+ prompt = gr.Textbox(label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in")
5
+ demo.launch(server_name="0.0.0.0")
environment.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: lavie
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.11.3
7
+ - pytorch=2.0.1
8
+ - pytorch-cuda=11.7
9
+ - torchvision=0.15.2
10
+ - pip:
11
+ - accelerate==0.19.0
12
+ - av==10.0.0
13
+ - decord==0.6.0
14
+ - diffusers[torch]==0.16.0
15
+ - einops==0.6.1
16
+ - ffmpeg==1.4
17
+ - imageio==2.31.1
18
+ - imageio-ffmpeg==0.4.9
19
+ - pandas==2.0.1
20
+ - timm==0.6.13
21
+ - tqdm==4.65.0
22
+ - transformers==4.28.1
23
+ - xformers==0.0.20
24
+ - omegaconf==2.3.0
25
+ - natsort==8.4.0
26
+ - rotary_embedding_torch
27
+ - gradio==4.3.0
interpolation/configs/sample.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ args:
2
+ input_folder: "../res/base/"
3
+ pretrained_path: "../pretrained_models"
4
+ output_folder: "../res/interpolation/"
5
+ seed_list:
6
+ - 3418
7
+
8
+ fps_list:
9
+ - 24
10
+
11
+ # model config:
12
+ model: TSR
13
+ num_frames: 61
14
+ image_size: [320, 512]
15
+ num_sampling_steps: 50
16
+ vae: mse
17
+ use_timecross_transformer: False
18
+ frame_interval: 1
19
+
20
+ # sample config:
21
+ seed: 0
22
+ cfg_scale: 4.0
23
+ run_time: 12
24
+ use_compile: False
25
+ enable_xformers_memory_efficient_attention: True
26
+ num_sample: 1
27
+
28
+ additional_prompt: ", 4k."
29
+ negative_prompt: "None"
30
+ do_classifier_free_guidance: True
31
+ use_ddim_sample_loop: True
32
+
33
+ researve_frame: 3
34
+ mask_type: "tsr"
35
+ use_concat: True
36
+ copy_no_mask: True
interpolation/datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from datasets import video_transforms
interpolation/datasets/video_transforms.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numbers
4
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+
6
+ def _is_tensor_video_clip(clip):
7
+ if not torch.is_tensor(clip):
8
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
9
+
10
+ if not clip.ndimension() == 4:
11
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
12
+
13
+ return True
14
+
15
+
16
+ def to_tensor(clip):
17
+ """
18
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
19
+ permute the dimensions of clip tensor
20
+ Args:
21
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
22
+ Return:
23
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
24
+ """
25
+ _is_tensor_video_clip(clip)
26
+ if not clip.dtype == torch.uint8:
27
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
28
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
29
+ return clip.float() / 255.0
30
+
31
+
32
+ def resize(clip, target_size, interpolation_mode):
33
+ if len(target_size) != 2:
34
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
35
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
36
+
37
+
38
+ class ToTensorVideo:
39
+ """
40
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
41
+ permute the dimensions of clip tensor
42
+ """
43
+
44
+ def __init__(self):
45
+ pass
46
+
47
+ def __call__(self, clip):
48
+ """
49
+ Args:
50
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
51
+ Return:
52
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
53
+ """
54
+ return to_tensor(clip)
55
+
56
+ def __repr__(self) -> str:
57
+ return self.__class__.__name__
58
+
59
+
60
+ class ResizeVideo:
61
+ '''
62
+ Resize to the specified size
63
+ '''
64
+ def __init__(
65
+ self,
66
+ size,
67
+ interpolation_mode="bilinear",
68
+ ):
69
+ if isinstance(size, tuple):
70
+ if len(size) != 2:
71
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
72
+ self.size = size
73
+ else:
74
+ self.size = (size, size)
75
+
76
+ self.interpolation_mode = interpolation_mode
77
+
78
+
79
+ def __call__(self, clip):
80
+ """
81
+ Args:
82
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
83
+ Returns:
84
+ torch.tensor: scale resized video clip.
85
+ size is (T, C, h, w)
86
+ """
87
+ clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
88
+ return clip_resize
89
+
90
+ def __repr__(self) -> str:
91
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
92
+
93
+
94
+ class TemporalRandomCrop(object):
95
+ """Temporally crop the given frame indices at a random location.
96
+
97
+ Args:
98
+ size (int): Desired length of frames will be seen in the model.
99
+ """
100
+
101
+ def __init__(self, size):
102
+ self.size = size
103
+
104
+ def __call__(self, total_frames):
105
+ rand_end = max(0, total_frames - self.size - 1)
106
+ begin_index = random.randint(0, rand_end)
107
+ end_index = min(begin_index + self.size, total_frames)
108
+ return begin_index, end_index
109
+
interpolation/diffusion/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ # learn_sigma=True,
17
+ learn_sigma=False, # for unet
18
+ rescale_learned_sigmas=False,
19
+ diffusion_steps=1000
20
+ ):
21
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
+ if use_kl:
23
+ loss_type = gd.LossType.RESCALED_KL
24
+ elif rescale_learned_sigmas:
25
+ loss_type = gd.LossType.RESCALED_MSE
26
+ else:
27
+ loss_type = gd.LossType.MSE
28
+ if timestep_respacing is None or timestep_respacing == "":
29
+ timestep_respacing = [diffusion_steps]
30
+ return SpacedDiffusion(
31
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
+ betas=betas,
33
+ model_mean_type=(
34
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35
+ ),
36
+ model_var_type=(
37
+ (
38
+ gd.ModelVarType.FIXED_LARGE
39
+ if not sigma_small
40
+ else gd.ModelVarType.FIXED_SMALL
41
+ )
42
+ if not learn_sigma
43
+ else gd.ModelVarType.LEARNED_RANGE
44
+ ),
45
+ loss_type=loss_type
46
+ # rescale_timesteps=rescale_timesteps,
47
+ )
interpolation/diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs