fffiloni committed on
Commit
e02c605
1 Parent(s): 59fcb71

Upload 20 files

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ asset/gradio_example.png filter=lfs diff=lfs merge=lfs -text
+ data/example/1.png filter=lfs diff=lfs merge=lfs -text
+ data/example/2.png filter=lfs diff=lfs merge=lfs -text
+ data/example/3.png filter=lfs diff=lfs merge=lfs -text
+ data/example/4.png filter=lfs diff=lfs merge=lfs -text
+ data/example/5.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Ruixiang JIANG
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
LICENSE.md ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Ruixiang JIANG
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
asset/gradio_example.png ADDED

Git LFS Details

  • SHA256: 58215c200a2b2de8ef5629910947a195e64aac42767aa3725b6b0353882caa55
  • Pointer size: 132 Bytes
  • Size of remote file: 1.62 MB
data/example/1.png ADDED

Git LFS Details

  • SHA256: 5387fef7da6e30189251ab85b1aad1e63d92813bf23e642013ac31cf37380355
  • Pointer size: 132 Bytes
  • Size of remote file: 1.77 MB
data/example/2.png ADDED

Git LFS Details

  • SHA256: 858b63dd2e04066c7dc94d159434bf141e1292b7e302959415aad6ebb6e1c25b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
data/example/3.png ADDED

Git LFS Details

  • SHA256: d7e7970cfecb021639fe8b3ad31f5887e5aae66cc1ca878c8bd74789fa575eac
  • Pointer size: 132 Bytes
  • Size of remote file: 1.92 MB
data/example/4.png ADDED

Git LFS Details

  • SHA256: 3ee3290acb89ff35a6c75b899a0c70197030e2a707926e32e842ba22aa11d54a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.94 MB
data/example/5.png ADDED

Git LFS Details

  • SHA256: 2b3a723983a90b5ccb7f4ab6e1dc99b38cf284d195c57c7a0d41da30c68dfbbe
  • Pointer size: 132 Bytes
  • Size of remote file: 1.85 MB
data/example/annotation.json ADDED
@@ -0,0 +1,27 @@
+ [
+     {
+         "image_path": "data/example/1.png",
+         "source_prompt": "",
+         "target_prompt": "A B&W pencil sketch, detailed cross-hatching"
+     },
+     {
+         "image_path": "data/example/2.png",
+         "source_prompt": "",
+         "target_prompt": "American comic, western style"
+     },
+     {
+         "image_path": "data/example/3.png",
+         "source_prompt": "",
+         "target_prompt": "Starry Night style painting by Van Gogh"
+     },
+     {
+         "image_path": "data/example/4.png",
+         "source_prompt": "",
+         "target_prompt": "Cubism painting, detailed."
+     },
+     {
+         "image_path": "data/example/5.png",
+         "source_prompt": "",
+         "target_prompt": "painting by Edvard Munch, The Scream"
+     }
+ ]
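
Each entry pairs an example image with an empty source prompt and a style target prompt. For reference, a minimal sketch of how dataset mode in injection_main.py (added below in this commit) consumes the file; the loop body is abbreviated:

import json

annotation = json.load(open("data/example/annotation.json"))
for entry in annotation:
    image_path = entry["image_path"]      # content image to invert
    src_prompt = entry["source_prompt"]   # "" here, used for DDIM inversion
    tgt_prompt = entry["target_prompt"]   # style description for stylization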
data/example/log.csv ADDED
File without changes
environment.yml ADDED
@@ -0,0 +1,17 @@
+ name: gaussian_splatting
+ channels:
+   - pytorch
+   - conda-forge
+   - defaults
+ dependencies:
+   - cudatoolkit=11.6
+   - plyfile=0.8.1
+   - python=3.7.13
+   - pip=22.3.1
+   - pytorch=1.12.1
+   - torchaudio=0.12.1
+   - torchvision=0.13.1
+   - tqdm
+   - pip:
+     - submodules/diff-gaussian-rasterization
+     - submodules/simple-knn
example_config.yaml ADDED
@@ -0,0 +1,23 @@
+ exp_name: example
+ batch_size: 1
+ num_steps: 50
+ start_step: 0
+ out_path: out/
+ seed: 10
+ share_attn_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8]
+ share_resnet_layers: [0, 1, 2, 3]
+
+ share_attn: true
+ share_cross_attn: true
+ share_query: true
+ share_key: true
+ share_value: false
+ use_adain: true
+ use_content_anchor: true
+ disentangle: true
+ resnet_mode: hidden
+
+ annotation: data/example/annotation.json
+ style_cfg_scale: 7.5
+ tau_attn: 1
+ tau_feat: 1
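
As a usage sketch (assuming the file names shipped in this commit), dataset mode loads this YAML with OmegaConf and forwards the sharing flags to style_image_with_inversion in injection_main.py:

# Roughly what `python injection_main.py --mode dataset --config example_config.yaml` starts from
import json
from omegaconf import OmegaConf

cfg = OmegaConf.load("example_config.yaml")
entries = json.load(open(cfg.annotation))   # data/example/annotation.json
print(cfg.num_steps, cfg.style_cfg_scale, cfg.share_attn_layers, len(entries))
# each entry is then stylized with the flags above (share_attn, share_resnet_layers, use_adain, ...)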
injection_main.py ADDED
@@ -0,0 +1,739 @@
1
+ # %%
2
+ import argparse, os
3
+
4
+
5
+ import torch
6
+ import requests
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+ from io import BytesIO
11
+ from tqdm.auto import tqdm
12
+ from matplotlib import pyplot as plt
13
+ from torchvision import transforms as tfms
14
+ from diffusers import (
15
+ StableDiffusionPipeline,
16
+ DDIMScheduler,
17
+ DiffusionPipeline,
18
+ StableDiffusionXLPipeline,
19
+ )
20
+ from diffusers.image_processor import VaeImageProcessor
21
+ import torch
22
+ import torch.nn as nn
23
+ import torchvision
24
+ import torchvision.transforms as transforms
25
+ from torchvision.utils import save_image
26
+ import argparse
27
+ import PIL.Image as Image
28
+ from torchvision.utils import make_grid
29
+ import numpy
30
+ from diffusers.schedulers import DDIMScheduler
31
+ import torch.nn.functional as F
32
+ from models import attn_injection
33
+ from omegaconf import OmegaConf
34
+ from typing import List, Tuple
35
+
36
+ import omegaconf
37
+ import utils.exp_utils
38
+ import json
39
+
40
+ device = torch.device("cuda")
41
+
42
+
43
+ def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
44
+ # Tokenize text and get embeddings
45
+ text_inputs = tokenizer(
46
+ prompt,
47
+ padding="max_length",
48
+ max_length=tokenizer.model_max_length,
49
+ truncation=True,
50
+ return_tensors="pt",
51
+ )
52
+ text_input_ids = text_inputs.input_ids
53
+
54
+ with torch.no_grad():
55
+ prompt_embeds = text_encoder(
56
+ text_input_ids.to(device),
57
+ output_hidden_states=True,
58
+ )
59
+
60
+ pooled_prompt_embeds = prompt_embeds[0]
61
+ prompt_embeds = prompt_embeds.hidden_states[-2]
62
+ if prompt == "":
63
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
64
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
65
+ return negative_prompt_embeds, negative_pooled_prompt_embeds
66
+ return prompt_embeds, pooled_prompt_embeds
67
+
68
+
69
+ def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: str):
70
+ device = model._execution_device
71
+ (
72
+ prompt_embeds,
73
+ pooled_prompt_embeds,
74
+ ) = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device)
75
+ (
76
+ prompt_embeds_2,
77
+ pooled_prompt_embeds_2,
78
+ ) = _get_text_embeddings(prompt, model.tokenizer_2, model.text_encoder_2, device)
79
+ prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)
80
+ text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
81
+ add_time_ids = model._get_add_time_ids(
82
+ (1024, 1024), (0, 0), (1024, 1024), torch.float16, text_encoder_projection_dim
83
+ ).to(device)
84
+ # repeat the time ids for each prompt
85
+ add_time_ids = add_time_ids.repeat(len(prompt), 1)
86
+ added_cond_kwargs = {
87
+ "text_embeds": pooled_prompt_embeds_2,
88
+ "time_ids": add_time_ids,
89
+ }
90
+ return added_cond_kwargs, prompt_embeds
91
+
92
+
93
+ def _encode_text_sdxl_with_negative(
94
+ model: StableDiffusionXLPipeline, prompt: List[str]
95
+ ):
96
+
97
+ B = len(prompt)
98
+ added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt)
99
+ added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl(
100
+ model, ["" for _ in range(B)]
101
+ )
102
+ prompt_embeds = torch.cat(
103
+ (
104
+ prompt_embeds_uncond,
105
+ prompt_embeds,
106
+ )
107
+ )
108
+ added_cond_kwargs = {
109
+ "text_embeds": torch.cat(
110
+ (added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"])
111
+ ),
112
+ "time_ids": torch.cat(
113
+ (added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"])
114
+ ),
115
+ }
116
+ return added_cond_kwargs, prompt_embeds
117
+
118
+
119
+ # Sample function (regular DDIM)
120
+ @torch.no_grad()
121
+ def sample(
122
+ pipe,
123
+ prompt,
124
+ start_step=0,
125
+ start_latents=None,
126
+ intermediate_latents=None,
127
+ guidance_scale=3.5,
128
+ num_inference_steps=30,
129
+ num_images_per_prompt=1,
130
+ do_classifier_free_guidance=True,
131
+ negative_prompt="",
132
+ device=device,
133
+ ):
134
+ negative_prompt = [""] * len(prompt)
135
+ # Encode prompt
136
+ if isinstance(pipe, StableDiffusionPipeline):
137
+ text_embeddings = pipe._encode_prompt(
138
+ prompt,
139
+ device,
140
+ num_images_per_prompt,
141
+ do_classifier_free_guidance,
142
+ negative_prompt,
143
+ )
144
+ added_cond_kwargs = None
145
+ elif isinstance(pipe, StableDiffusionXLPipeline):
146
+ added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
147
+ pipe, prompt
148
+ )
149
+
150
+ # Set num inference steps
151
+ pipe.scheduler.set_timesteps(num_inference_steps, device=device)
152
+
153
+ # Create a random starting point if we don't have one already
154
+ if start_latents is None:
155
+ start_latents = torch.randn(1, 4, 64, 64, device=device)
156
+ start_latents *= pipe.scheduler.init_noise_sigma
157
+
158
+ latents = start_latents.clone()
159
+
160
+ latents = latents.repeat(len(prompt), 1, 1, 1)
161
+ # assume that the first latent is used for reconstruction
162
+ for i in tqdm(range(start_step, num_inference_steps)):
163
+ latents[0] = intermediate_latents[(-i + 1)]
164
+ t = pipe.scheduler.timesteps[i]
165
+
166
+ # Expand the latents if we are doing classifier free guidance
167
+ latent_model_input = (
168
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
169
+ )
170
+ latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
171
+
172
+ # Predict the noise residual
173
+ noise_pred = pipe.unet(
174
+ latent_model_input,
175
+ t,
176
+ encoder_hidden_states=text_embeddings,
177
+ added_cond_kwargs=added_cond_kwargs,
178
+ ).sample
179
+
180
+ # Perform guidance
181
+ if do_classifier_free_guidance:
182
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
183
+ noise_pred = noise_pred_uncond + guidance_scale * (
184
+ noise_pred_text - noise_pred_uncond
185
+ )
186
+ latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
187
+
188
+ # Post-processing
189
+ images = pipe.decode_latents(latents)
190
+ images = pipe.numpy_to_pil(images)
191
+
192
+ return images
193
+
194
+
195
+ # Sample function (regular DDIM), but disentangle the content and style
196
+ @torch.no_grad()
197
+ def sample_disentangled(
198
+ pipe,
199
+ prompt,
200
+ start_step=0,
201
+ start_latents=None,
202
+ intermediate_latents=None,
203
+ guidance_scale=3.5,
204
+ num_inference_steps=30,
205
+ num_images_per_prompt=1,
206
+ do_classifier_free_guidance=True,
207
+ use_content_anchor=True,
208
+ negative_prompt="",
209
+ device=device,
210
+ ):
211
+ negative_prompt = [""] * len(prompt)
212
+ vae_decoder = VaeImageProcessor(vae_scale_factor=pipe.vae.config.scaling_factor)
213
+ # Encode prompt
214
+ if isinstance(pipe, StableDiffusionPipeline):
215
+ text_embeddings = pipe._encode_prompt(
216
+ prompt,
217
+ device,
218
+ num_images_per_prompt,
219
+ do_classifier_free_guidance,
220
+ negative_prompt,
221
+ )
222
+ added_cond_kwargs = None
223
+ elif isinstance(pipe, StableDiffusionXLPipeline):
224
+ added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
225
+ pipe, prompt
226
+ )
227
+
228
+ # Set num inference steps
229
+ pipe.scheduler.set_timesteps(num_inference_steps, device=device)
230
+ # save
231
+
232
+ latent_shape = (
233
+ (1, 4, 64, 64) if isinstance(pipe, StableDiffusionPipeline) else (1, 4, 64, 64)
234
+ )
235
+ generative_latent = torch.randn(latent_shape, device=device)
236
+ generative_latent *= pipe.scheduler.init_noise_sigma
237
+
238
+ latents = start_latents.clone()
239
+
240
+ latents = latents.repeat(len(prompt), 1, 1, 1)
241
+ # randomly initialize the latent of the uncontrolled-style branch (index 1) for generation
242
+
243
+ latents[1] = generative_latent
244
+ # assume that the first latent is used for reconstruction
245
+ for i in tqdm(range(start_step, num_inference_steps), desc="Stylizing"):
246
+
247
+ if use_content_anchor:
248
+ latents[0] = intermediate_latents[(-i + 1)]
249
+ t = pipe.scheduler.timesteps[i]
250
+
251
+ # Expand the latents if we are doing classifier free guidance
252
+ latent_model_input = (
253
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
254
+ )
255
+ latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
256
+
257
+ # Predict the noise residual
258
+ noise_pred = pipe.unet(
259
+ latent_model_input,
260
+ t,
261
+ encoder_hidden_states=text_embeddings,
262
+ added_cond_kwargs=added_cond_kwargs,
263
+ ).sample
264
+
265
+ # Perform guidance
266
+ if do_classifier_free_guidance:
267
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
268
+ noise_pred = noise_pred_uncond + guidance_scale * (
269
+ noise_pred_text - noise_pred_uncond
270
+ )
271
+
272
+ latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
273
+
274
+ # Post-processing
275
+ # images = vae_decoder.postprocess(latents)
276
+ pipe.vae.to(dtype=torch.float32)
277
+ latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
278
+ latents = 1 / pipe.vae.config.scaling_factor * latents
279
+ images = pipe.vae.decode(latents, return_dict=False)[0]
280
+ images = (images / 2 + 0.5).clamp(0, 1)
281
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
282
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
283
+ images = pipe.numpy_to_pil(images)
284
+ if isinstance(pipe, StableDiffusionXLPipeline):
285
+ pipe.vae.to(dtype=torch.float16)
286
+
287
+ return images
288
+
289
+
290
+ ## Inversion
291
+ @torch.no_grad()
292
+ def invert(
293
+ pipe,
294
+ start_latents,
295
+ prompt,
296
+ guidance_scale=3.5,
297
+ num_inference_steps=50,
298
+ num_images_per_prompt=1,
299
+ do_classifier_free_guidance=True,
300
+ negative_prompt="",
301
+ device=device,
302
+ ):
303
+
304
+ # Encode prompt
305
+ if isinstance(pipe, StableDiffusionPipeline):
306
+ text_embeddings = pipe._encode_prompt(
307
+ prompt,
308
+ device,
309
+ num_images_per_prompt,
310
+ do_classifier_free_guidance,
311
+ negative_prompt,
312
+ )
313
+ added_cond_kwargs = None
314
+ latents = start_latents.clone().detach()
315
+ elif isinstance(pipe, StableDiffusionXLPipeline):
316
+ added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
317
+ pipe, [prompt]
318
+ ) # Latents are now the specified start latents
319
+ latents = start_latents.clone().detach().half()
320
+
321
+ # We'll keep a list of the inverted latents as the process goes on
322
+ intermediate_latents = []
323
+
324
+ # Set num inference steps
325
+ pipe.scheduler.set_timesteps(num_inference_steps, device=device)
326
+
327
+ # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
328
+ timesteps = reversed(pipe.scheduler.timesteps)
329
+
330
+ for i in tqdm(
331
+ range(1, num_inference_steps),
332
+ total=num_inference_steps - 1,
333
+ desc="DDIM Inversion",
334
+ ):
335
+
336
+ # We'll skip the final iteration
337
+ if i >= num_inference_steps - 1:
338
+ continue
339
+
340
+ t = timesteps[i]
341
+
342
+ # Expand the latents if we are doing classifier free guidance
343
+ latent_model_input = (
344
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
345
+ )
346
+ latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
347
+
348
+ # Predict the noise residual
349
+ noise_pred = pipe.unet(
350
+ latent_model_input,
351
+ t,
352
+ encoder_hidden_states=text_embeddings,
353
+ added_cond_kwargs=added_cond_kwargs,
354
+ ).sample
355
+
356
+ # Perform guidance
357
+ if do_classifier_free_guidance:
358
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
359
+ noise_pred = noise_pred_uncond + guidance_scale * (
360
+ noise_pred_text - noise_pred_uncond
361
+ )
362
+
363
+ current_t = max(0, t.item() - (1000 // num_inference_steps)) # t
364
+ next_t = t # min(999, t.item() + (1000//num_inference_steps)) # t+1
365
+ alpha_t = pipe.scheduler.alphas_cumprod[current_t]
366
+ alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]
367
+
368
+ # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents))
369
+ latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (
370
+ alpha_t_next.sqrt() / alpha_t.sqrt()
371
+ ) + (1 - alpha_t_next).sqrt() * noise_pred
372
+
373
+ # Store
374
+ intermediate_latents.append(latents)
375
+
376
+ return torch.cat(intermediate_latents)
377
+
378
+
379
+ def style_image_with_inversion(
380
+ pipe,
381
+ input_image,
382
+ input_image_prompt,
383
+ style_prompt,
384
+ num_steps=100,
385
+ start_step=30,
386
+ guidance_scale=3.5,
387
+ disentangle=False,
388
+ share_attn=False,
389
+ share_cross_attn=False,
390
+ share_resnet_layers=[0, 1],
391
+ share_attn_layers=[],
392
+ c2s_layers=[0, 1],
393
+ share_key=True,
394
+ share_query=True,
395
+ share_value=False,
396
+ use_adain=True,
397
+ use_content_anchor=True,
398
+ output_dir: str = None,
399
+ resnet_mode: str = None,
400
+ return_intermediate=False,
401
+ intermediate_latents=None,
402
+ ):
403
+ with torch.no_grad():
404
+ pipe.vae.to(dtype=torch.float32)
405
+ latent = pipe.vae.encode(input_image.to(device) * 2 - 1)
406
+ # latent = pipe.vae.encode(input_image.to(device))
407
+ l = pipe.vae.config.scaling_factor * latent.latent_dist.sample()
408
+ if isinstance(pipe, StableDiffusionXLPipeline):
409
+ pipe.vae.to(dtype=torch.float16)
410
+ if intermediate_latents is None:
411
+ inverted_latents = invert(
412
+ pipe, l, input_image_prompt, num_inference_steps=num_steps
413
+ )
414
+ else:
415
+ inverted_latents = intermediate_latents
416
+
417
+ attn_injection.register_attention_processors(
418
+ pipe,
419
+ base_dir=output_dir,
420
+ resnet_mode=resnet_mode,
421
+ attn_mode="artist" if disentangle else "pnp",
422
+ disentangle=disentangle,
423
+ share_resblock=True,
424
+ share_attn=share_attn,
425
+ share_cross_attn=share_cross_attn,
426
+ share_resnet_layers=share_resnet_layers,
427
+ share_attn_layers=share_attn_layers,
428
+ share_key=share_key,
429
+ share_query=share_query,
430
+ share_value=share_value,
431
+ use_adain=use_adain,
432
+ c2s_layers=c2s_layers,
433
+ )
434
+
435
+ if disentangle:
436
+ final_im = sample_disentangled(
437
+ pipe,
438
+ style_prompt,
439
+ start_latents=inverted_latents[-(start_step + 1)][None],
440
+ intermediate_latents=inverted_latents,
441
+ start_step=start_step,
442
+ num_inference_steps=num_steps,
443
+ guidance_scale=guidance_scale,
444
+ use_content_anchor=use_content_anchor,
445
+ )
446
+ else:
447
+ final_im = sample(
448
+ pipe,
449
+ style_prompt,
450
+ start_latents=inverted_latents[-(start_step + 1)][None],
451
+ intermediate_latents=inverted_latents,
452
+ start_step=start_step,
453
+ num_inference_steps=num_steps,
454
+ guidance_scale=guidance_scale,
455
+ )
456
+
457
+ # unset the attention processors
458
+ attn_injection.unset_attention_processors(
459
+ pipe,
460
+ unset_share_attn=True,
461
+ unset_share_resblock=True,
462
+ )
463
+ if return_intermediate:
464
+ return final_im, inverted_latents
465
+ return final_im
466
+
467
+
468
+ if __name__ == "__main__":
469
+
470
+ # Load a pipeline
471
+ pipe = StableDiffusionPipeline.from_pretrained(
472
+ "stabilityai/stable-diffusion-2-1-base"
473
+ ).to(device)
474
+
475
+ # pipe = DiffusionPipeline.from_pretrained(
476
+ # # "playgroundai/playground-v2-1024px-aesthetic",
477
+ # torch_dtype=torch.float16,
478
+ # use_safetensors=True,
479
+ # add_watermarker=False,
480
+ # variant="fp16",
481
+ # )
482
+ # pipe.to("cuda")
483
+
484
+ # Set up a DDIM scheduler
485
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
486
+
487
+ parser = argparse.ArgumentParser(description="Stable Diffusion with OmegaConf")
488
+ parser.add_argument(
489
+ "--config", type=str, default="config.yaml", help="Path to the config file"
490
+ )
491
+ parser.add_argument(
492
+ "--mode",
493
+ type=str,
494
+ default="dataset",
495
+ choices=["dataset", "cli", "app"],
496
+ help="Path to the config file",
497
+ )
498
+ parser.add_argument(
499
+ "--image_dir", type=str, default="test.png", help="Path to the image"
500
+ )
501
+ parser.add_argument(
502
+ "--prompt",
503
+ type=str,
504
+ default="an impressionist painting",
505
+ help="Stylization prompt",
506
+ )
507
+ # mode = "single_control_content"
508
+ args = parser.parse_args()
509
+ config_dir = args.config
510
+ mode = args.mode
511
+ # mode = "dataset"
512
+ out_name = ["content_delegation", "style_delegation", "style_out"]
513
+
514
+ if mode == "dataset":
515
+ cfg = OmegaConf.load(config_dir)
516
+
517
+ base_output_path = cfg.out_path
518
+ if not os.path.exists(cfg.out_path):
519
+ os.makedirs(cfg.out_path)
520
+ base_output_path = os.path.join(base_output_path, cfg.exp_name)
521
+
522
+ experiment_output_path = utils.exp_utils.make_unique_experiment_path(
523
+ base_output_path
524
+ )
525
+
526
+ # Save the experiment configuration
527
+ config_file_path = os.path.join(experiment_output_path, "config.yaml")
528
+ omegaconf.OmegaConf.save(cfg, config_file_path)
529
+
530
+ # Seed all
531
+
532
+ annotation = json.load(open(cfg.annotation))
533
+ with open(os.path.join(experiment_output_path, "annotation.json"), "w") as f:
534
+ json.dump(annotation, f)
535
+ for i, entry in enumerate(annotation):
536
+ utils.exp_utils.seed_all(cfg.seed)
537
+ image_path = entry["image_path"]
538
+ src_prompt = entry["source_prompt"]
539
+ tgt_prompt = entry["target_prompt"]
540
+ resolution = 512  # the same working resolution is used for both SD and SDXL here
541
+ input_image = utils.exp_utils.get_processed_image(
542
+ image_path, device, resolution
543
+ )
544
+
545
+ prompt_in = [
546
+ src_prompt, # reconstruction
547
+ tgt_prompt, # uncontrolled style
548
+ "", # controlled style
549
+ ]
550
+
551
+ imgs = style_image_with_inversion(
552
+ pipe,
553
+ input_image,
554
+ src_prompt,
555
+ style_prompt=prompt_in,
556
+ num_steps=cfg.num_steps,
557
+ start_step=cfg.start_step,
558
+ guidance_scale=cfg.style_cfg_scale,
559
+ disentangle=cfg.disentangle,
560
+ resnet_mode=cfg.resnet_mode,
561
+ share_attn=cfg.share_attn,
562
+ share_cross_attn=cfg.share_cross_attn,
563
+ share_resnet_layers=cfg.share_resnet_layers,
564
+ share_attn_layers=cfg.share_attn_layers,
565
+ share_key=cfg.share_key,
566
+ share_query=cfg.share_query,
567
+ share_value=cfg.share_value,
568
+ use_content_anchor=cfg.use_content_anchor,
569
+ use_adain=cfg.use_adain,
570
+ output_dir=experiment_output_path,
571
+ )
572
+
573
+ for j, img in enumerate(imgs):
574
+ img.save(f"{experiment_output_path}/out_{i}_{out_name[j]}.png")
575
+ print(
576
+ f"Image saved as {experiment_output_path}/out_{i}_{out_name[j]}.png"
577
+ )
578
+ elif mode == "cli":
579
+ cfg = OmegaConf.load(config_dir)
580
+ utils.exp_utils.seed_all(cfg.seed)
581
+ image = utils.exp_utils.get_processed_image(args.image_dir, device, 512)
582
+ tgt_prompt = args.prompt
583
+ src_prompt = ""
584
+ prompt_in = [
585
+ "", # reconstruction
586
+ tgt_prompt, # uncontrolled style
587
+ "", # controlled style
588
+ ]
589
+ out_dir = "./out"
590
+ os.makedirs(out_dir, exist_ok=True)
591
+ imgs = style_image_with_inversion(
592
+ pipe,
593
+ image,
594
+ src_prompt,
595
+ style_prompt=prompt_in,
596
+ num_steps=cfg.num_steps,
597
+ start_step=cfg.start_step,
598
+ guidance_scale=cfg.style_cfg_scale,
599
+ disentangle=cfg.disentangle,
600
+ resnet_mode=cfg.resnet_mode,
601
+ share_attn=cfg.share_attn,
602
+ share_cross_attn=cfg.share_cross_attn,
603
+ share_resnet_layers=cfg.share_resnet_layers,
604
+ share_attn_layers=cfg.share_attn_layers,
605
+ share_key=cfg.share_key,
606
+ share_query=cfg.share_query,
607
+ share_value=cfg.share_value,
608
+ use_content_anchor=cfg.use_content_anchor,
609
+ use_adain=cfg.use_adain,
610
+ output_dir=out_dir,
611
+ )
612
+ image_base_name = os.path.basename(args.image_dir).split(".")[0]
613
+ for j, img in enumerate(imgs):
614
+ img.save(f"{out_dir}/{image_base_name}_out_{out_name[j]}.png")
615
+ print(f"Image saved as {out_dir}/{image_base_name}_out_{out_name[j]}.png")
616
+ elif mode == "app":
617
+ # gradio
618
+ import gradio as gr
619
+
620
+ def style_transfer_app(
621
+ prompt,
622
+ image,
623
+ cfg_scale=7.5,
624
+ num_content_layers=4,
625
+ num_style_layers=9,
626
+ seed=0,
627
+ progress=gr.Progress(track_tqdm=True),
628
+ ):
629
+ utils.exp_utils.seed_all(seed)
630
+ image = utils.exp_utils.process_image(image, device, 512)
631
+
632
+ tgt_prompt = prompt
633
+ src_prompt = ""
634
+ prompt_in = [
635
+ "", # reconstruction
636
+ tgt_prompt, # uncontrolled style
637
+ "", # controlled style
638
+ ]
639
+
640
+ share_resnet_layers = (
641
+ list(range(num_content_layers)) if num_content_layers != 0 else None
642
+ )
643
+ share_attn_layers = (
644
+ list(range(num_style_layers)) if num_style_layers != 0 else None
645
+ )
646
+ imgs = style_image_with_inversion(
647
+ pipe,
648
+ image,
649
+ src_prompt,
650
+ style_prompt=prompt_in,
651
+ num_steps=50,
652
+ start_step=0,
653
+ guidance_scale=cfg_scale,
654
+ disentangle=True,
655
+ resnet_mode="hidden",
656
+ share_attn=True,
657
+ share_cross_attn=True,
658
+ share_resnet_layers=share_resnet_layers,
659
+ share_attn_layers=share_attn_layers,
660
+ share_key=True,
661
+ share_query=True,
662
+ share_value=False,
663
+ use_content_anchor=True,
664
+ use_adain=True,
665
+ output_dir="./",
666
+ )
667
+
668
+ return imgs[2]
669
+
670
+ # load examples
671
+ examples = []
672
+ annotation = json.load(open("data/example/annotation.json"))
673
+ for entry in annotation:
674
+ image = utils.exp_utils.get_processed_image(
675
+ entry["image_path"], device, 512
676
+ )
677
+ image = transforms.ToPILImage()(image[0])
678
+
679
+ examples.append([entry["target_prompt"], image, None, None, None])
680
+
681
+ text_input = gr.Textbox(
682
+ value="An impressionist painting",
683
+ label="Text Prompt",
684
+ info="Describe the style you want to apply to the image, do not include the description of the image content itself",
685
+ lines=2,
686
+ placeholder="Enter a text prompt",
687
+ )
688
+ image_input = gr.Image(
689
+ height="80%",
690
+ width="80%",
691
+ label="Content image (will be resized to 512x512)",
692
+ interactive=True,
693
+ )
694
+ cfg_slider = gr.Slider(
695
+ 0,
696
+ 15,
697
+ value=7.5,
698
+ label="Classifier Free Guidance (CFG) Scale",
699
+ info="higher values give more style, 7.5 should be good for most cases",
700
+ )
701
+ content_slider = gr.Slider(
702
+ 0,
703
+ 9,
704
+ value=4,
705
+ step=1,
706
+ label="Number of content control layer",
707
+ info="higher values make it more similar to original image. Default to control first 4 layers",
708
+ )
709
+ style_slider = gr.Slider(
710
+ 0,
711
+ 9,
712
+ value=9,
713
+ step=1,
714
+ label="Number of style control layer",
715
+ info="higher values make it more similar to target style. Default to control first 9 layers, usually not necessary to change.",
716
+ )
717
+ seed_slider = gr.Slider(
718
+ 0,
719
+ 100,
720
+ value=0,
721
+ step=1,
722
+ label="Seed",
723
+ info="Random seed for the model",
724
+ )
725
+ app = gr.Interface(
726
+ fn=style_transfer_app,
727
+ inputs=[
728
+ text_input,
729
+ image_input,
730
+ cfg_slider,
731
+ content_slider,
732
+ style_slider,
733
+ seed_slider,
734
+ ],
735
+ outputs=["image"],
736
+ title="Artist Interactive Demo",
737
+ examples=examples,
738
+ )
739
+ app.launch()
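
A note on the batch convention used throughout this script: every sampling call runs on a three-prompt batch, and the decoded images map to out_name = ["content_delegation", "style_delegation", "style_out"]; the Gradio app returns only the last one. A small illustration (the style prompt here is a placeholder, not a fixed value):

prompt_in = [
    "",                           # index 0: reconstruction of the input (content anchor)
    "an impressionist painting",  # index 1: uncontrolled stylization
    "",                           # index 2: controlled stylization, i.e. the final output
]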
lpipsPyTorch/__init__.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+
+ from .modules.lpips import LPIPS
+
+
+ def lpips(x: torch.Tensor,
+           y: torch.Tensor,
+           net_type: str = 'alex',
+           version: str = '0.1'):
+     r"""Function that measures
+     Learned Perceptual Image Patch Similarity (LPIPS).
+
+     Arguments:
+         x, y (torch.Tensor): the input tensors to compare.
+         net_type (str): the network type to compare the features:
+                         'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+         version (str): the version of LPIPS. Default: 0.1.
+     """
+     device = x.device
+     criterion = LPIPS(net_type, version).to(device)
+     return criterion(x, y)
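
A usage sketch of this helper (assuming image batches scaled to [-1, 1], the convention of the original LPIPS weights; this is not stated explicitly in the file):

import torch
from lpipsPyTorch import lpips

x = torch.rand(1, 3, 256, 256) * 2 - 1
y = torch.rand(1, 3, 256, 256) * 2 - 1
score = lpips(x, y, net_type='vgg')  # lower means more perceptually similar
print(score.item())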
lpipsPyTorch/modules/lpips.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import torch.nn as nn
+
+ from .networks import get_network, LinLayers
+ from .utils import get_state_dict
+
+
+ class LPIPS(nn.Module):
+     r"""Creates a criterion that measures
+     Learned Perceptual Image Patch Similarity (LPIPS).
+
+     Arguments:
+         net_type (str): the network type to compare the features:
+                         'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+         version (str): the version of LPIPS. Default: 0.1.
+     """
+     def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+         assert version in ['0.1'], 'v0.1 is only supported now'
+
+         super(LPIPS, self).__init__()
+
+         # pretrained network
+         self.net = get_network(net_type)
+
+         # linear layers
+         self.lin = LinLayers(self.net.n_channels_list)
+         self.lin.load_state_dict(get_state_dict(net_type, version))
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor):
+         feat_x, feat_y = self.net(x), self.net(y)
+
+         diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+         res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+         return torch.sum(torch.cat(res, 0), 0, True)
lpipsPyTorch/modules/networks.py ADDED
@@ -0,0 +1,96 @@
+ from typing import Sequence
+
+ from itertools import chain
+
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+
+ from .utils import normalize_activation
+
+
+ def get_network(net_type: str):
+     if net_type == 'alex':
+         return AlexNet()
+     elif net_type == 'squeeze':
+         return SqueezeNet()
+     elif net_type == 'vgg':
+         return VGG16()
+     else:
+         raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+ class LinLayers(nn.ModuleList):
+     def __init__(self, n_channels_list: Sequence[int]):
+         super(LinLayers, self).__init__([
+             nn.Sequential(
+                 nn.Identity(),
+                 nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+             ) for nc in n_channels_list
+         ])
+
+         for param in self.parameters():
+             param.requires_grad = False
+
+
+ class BaseNet(nn.Module):
+     def __init__(self):
+         super(BaseNet, self).__init__()
+
+         # register buffer
+         self.register_buffer(
+             'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+         self.register_buffer(
+             'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+     def set_requires_grad(self, state: bool):
+         for param in chain(self.parameters(), self.buffers()):
+             param.requires_grad = state
+
+     def z_score(self, x: torch.Tensor):
+         return (x - self.mean) / self.std
+
+     def forward(self, x: torch.Tensor):
+         x = self.z_score(x)
+
+         output = []
+         for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+             x = layer(x)
+             if i in self.target_layers:
+                 output.append(normalize_activation(x))
+             if len(output) == len(self.target_layers):
+                 break
+         return output
+
+
+ class SqueezeNet(BaseNet):
+     def __init__(self):
+         super(SqueezeNet, self).__init__()
+
+         self.layers = models.squeezenet1_1(True).features
+         self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+         self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+         self.set_requires_grad(False)
+
+
+ class AlexNet(BaseNet):
+     def __init__(self):
+         super(AlexNet, self).__init__()
+
+         self.layers = models.alexnet(True).features
+         self.target_layers = [2, 5, 8, 10, 12]
+         self.n_channels_list = [64, 192, 384, 256, 256]
+
+         self.set_requires_grad(False)
+
+
+ class VGG16(BaseNet):
+     def __init__(self):
+         super(VGG16, self).__init__()
+
+         self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
+         self.target_layers = [4, 9, 16, 23, 30]
+         self.n_channels_list = [64, 128, 256, 512, 512]
+
+         self.set_requires_grad(False)
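
For orientation, each backbone returns a list of unit-normalized activations, one per entry in target_layers, whose channel counts are given by n_channels_list. A small sketch (pretrained weights are downloaded on first use):

import torch
from lpipsPyTorch.modules.networks import AlexNet

net = AlexNet()
feats = net(torch.randn(1, 3, 64, 64))
print([f.shape[1] for f in feats])  # [64, 192, 384, 256, 256]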
lpipsPyTorch/modules/utils.py ADDED
@@ -0,0 +1,30 @@
+ from collections import OrderedDict
+
+ import torch
+
+
+ def normalize_activation(x, eps=1e-10):
+     norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+     return x / (norm_factor + eps)
+
+
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+     # build url
+     url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+         + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+     # download
+     old_state_dict = torch.hub.load_state_dict_from_url(
+         url, progress=True,
+         map_location=None if torch.cuda.is_available() else torch.device('cpu')
+     )
+
+     # rename keys
+     new_state_dict = OrderedDict()
+     for key, val in old_state_dict.items():
+         new_key = key
+         new_key = new_key.replace('lin', '')
+         new_key = new_key.replace('model.', '')
+         new_state_dict[new_key] = val
+
+     return new_state_dict
models/attn_injection.py ADDED
@@ -0,0 +1,509 @@
1
+ # -*- coding : utf-8 -*-
2
+ # @FileName : attn_injection.py
3
+ # @Author : Ruixiang JIANG (Songrise)
4
+ # @Time : Mar 20, 2024
5
+ # @Github : https://github.com/songrise
6
+ # @Description: implement attention dump and attention injection for CPSD
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import functional as nnf
15
+ from diffusers.models import attention_processor
16
+ import einops
17
+ from diffusers.models import unet_2d_condition, attention, transformer_2d, resnet
18
+ from diffusers.models.unets import unet_2d_blocks
19
+
20
+ # from diffusers.models.unet_2d import CrossAttnUpBlock2D
21
+ from typing import Optional, List
22
+
23
+ T = torch.Tensor
24
+ import os
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class StyleAlignedArgs:
29
+ share_group_norm: bool = True
30
+ share_layer_norm: bool = True
31
+ share_attention: bool = True
32
+ adain_queries: bool = True
33
+ adain_keys: bool = True
34
+ adain_values: bool = False
35
+ full_attention_share: bool = False
36
+ shared_score_scale: float = 1.0
37
+ shared_score_shift: float = 0.0
38
+ only_self_level: float = 0.0
39
+
40
+
41
+ def expand_first(
42
+ feat: T,
43
+ scale=1.0,
44
+ ) -> T:
45
+ b = feat.shape[0]
46
+ feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
47
+ if scale == 1:
48
+ feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
49
+ else:
50
+ feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
51
+ feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
52
+ return feat_style.reshape(*feat.shape)
53
+
54
+
55
+ def concat_first(feat: T, dim=2, scale=1.0) -> T:
56
+ feat_style = expand_first(feat, scale=scale)
57
+ return torch.cat((feat, feat_style), dim=dim)
58
+
59
+
60
+ def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
61
+ feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
62
+ feat_mean = feat.mean(dim=-2, keepdims=True)
63
+ return feat_mean, feat_std
64
+
65
+
66
+ def adain(feat: T) -> T:
67
+ feat_mean, feat_std = calc_mean_std(feat)
68
+ feat_style_mean = expand_first(feat_mean)
69
+ feat_style_std = expand_first(feat_std)
70
+ feat = (feat - feat_mean) / feat_std
71
+ feat = feat * feat_style_std + feat_style_mean
72
+ return feat
73
+
74
+
75
+ def my_adain(feat: T) -> T:
76
+ batch_size = feat.shape[0] // 2
77
+ feat_mean, feat_std = calc_mean_std(feat)
78
+ feat_uncond_content, feat_cond_content = feat[0], feat[batch_size]
79
+
80
+ feat_style_mean = torch.stack((feat_mean[1], feat_mean[batch_size + 1])).unsqueeze(
81
+ 1
82
+ )
83
+ feat_style_mean = feat_style_mean.expand(2, batch_size, *feat_mean.shape[1:])
84
+ feat_style_mean = feat_style_mean.reshape(*feat_mean.shape) # (6, D)
85
+
86
+ feat_style_std = torch.stack((feat_std[1], feat_std[batch_size + 1])).unsqueeze(1)
87
+ feat_style_std = feat_style_std.expand(2, batch_size, *feat_std.shape[1:])
88
+ feat_style_std = feat_style_std.reshape(*feat_std.shape)
89
+
90
+ feat = (feat - feat_mean) / feat_std
91
+ feat = feat * feat_style_std + feat_style_mean
92
+ feat[0] = feat_uncond_content
93
+ feat[batch_size] = feat_cond_content
94
+ return feat
95
+
96
+
97
+ class DefaultAttentionProcessor(nn.Module):
98
+
99
+ def __init__(self):
100
+ super().__init__()
101
+ # self.processor = attention_processor.AttnProcessor2_0()
102
+ self.processor = attention_processor.AttnProcessor() # for torch 1.11.0
103
+
104
+ def __call__(
105
+ self,
106
+ attn: attention_processor.Attention,
107
+ hidden_states,
108
+ encoder_hidden_states=None,
109
+ attention_mask=None,
110
+ **kwargs,
111
+ ):
112
+ return self.processor(
113
+ attn, hidden_states, encoder_hidden_states, attention_mask
114
+ )
115
+
116
+
117
+ class ArtistAttentionProcessor(DefaultAttentionProcessor):
118
+ def __init__(
119
+ self,
120
+ inject_query: bool = True,
121
+ inject_key: bool = True,
122
+ inject_value: bool = True,
123
+ use_adain: bool = False,
124
+ name: str = None,
125
+ use_content_to_style_injection=False,
126
+ ):
127
+ super().__init__()
128
+
129
+ self.inject_query = inject_query
130
+ self.inject_key = inject_key
131
+ self.inject_value = inject_value
132
+ self.share_enabled = True
133
+ self.use_adain = use_adain
134
+
135
+ self.__custom_name = name
136
+ self.content_to_style_injection = use_content_to_style_injection
137
+
138
+ def __call__(
139
+ self,
140
+ attn: attention_processor.Attention,
141
+ hidden_states: torch.FloatTensor,
142
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
143
+ attention_mask: Optional[torch.FloatTensor] = None,
144
+ temb: Optional[torch.FloatTensor] = None,
145
+ scale: float = 1.0,
146
+ ) -> torch.Tensor:
147
+ #######Code from original attention impl
148
+ residual = hidden_states
149
+
150
+ # args = () if USE_PEFT_BACKEND else (scale,)
151
+ args = ()
152
+
153
+ if attn.spatial_norm is not None:
154
+ hidden_states = attn.spatial_norm(hidden_states, temb)
155
+
156
+ input_ndim = hidden_states.ndim
157
+
158
+ if input_ndim == 4:
159
+ batch_size, channel, height, width = hidden_states.shape
160
+ hidden_states = hidden_states.view(
161
+ batch_size, channel, height * width
162
+ ).transpose(1, 2)
163
+
164
+ batch_size, sequence_length, _ = (
165
+ hidden_states.shape
166
+ if encoder_hidden_states is None
167
+ else encoder_hidden_states.shape
168
+ )
169
+ attention_mask = attn.prepare_attention_mask(
170
+ attention_mask, sequence_length, batch_size
171
+ )
172
+
173
+ if attn.group_norm is not None:
174
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
175
+ 1, 2
176
+ )
177
+
178
+ query = attn.to_q(hidden_states, *args)
179
+
180
+ if encoder_hidden_states is None:
181
+ encoder_hidden_states = hidden_states
182
+ elif attn.norm_cross:
183
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
184
+ encoder_hidden_states
185
+ )
186
+
187
+ key = attn.to_k(encoder_hidden_states, *args)
188
+ value = attn.to_v(encoder_hidden_states, *args)
189
+ ######## inject begins here, here we assume the style image is always the 2nd instance in batch
190
+ batch_size = query.shape[0] // 2 # divide 2 since CFG is used
191
+ if self.share_enabled and batch_size > 1: # when == 1, no need to inject,
192
+ ref_q_uncond, ref_q_cond = query[1, ...].unsqueeze(0), query[
193
+ batch_size + 1, ...
194
+ ].unsqueeze(0)
195
+ ref_k_uncond, ref_k_cond = key[1, ...].unsqueeze(0), key[
196
+ batch_size + 1, ...
197
+ ].unsqueeze(0)
198
+
199
+ ref_v_uncond, ref_v_cond = value[1, ...].unsqueeze(0), value[
200
+ batch_size + 1, ...
201
+ ].unsqueeze(0)
202
+ if self.inject_query:
203
+ if self.use_adain:
204
+ query = my_adain(query)
205
+
206
+ if self.content_to_style_injection:
207
+ content_v_uncond = value[0, ...].unsqueeze(0)
208
+ content_v_cond = value[batch_size, ...].unsqueeze(0)
209
+ query[1] = content_v_uncond
210
+ query[batch_size + 1] = content_v_cond
211
+ else:
212
+ query[2] = ref_q_uncond
213
+ query[batch_size + 2] = ref_q_cond
214
+ if self.inject_key:
215
+ if self.use_adain:
216
+ key = my_adain(key)
217
+ else:
218
+ key[2] = ref_k_uncond
219
+ key[batch_size + 2] = ref_k_cond
220
+
221
+ if self.inject_value:
222
+ if self.use_adain:
223
+ value = my_adain(value)
224
+ else:
225
+ value[2] = ref_v_uncond
226
+ value[batch_size + 2] = ref_v_cond
227
+
228
+ query = attn.head_to_batch_dim(query)
229
+ key = attn.head_to_batch_dim(key)
230
+ value = attn.head_to_batch_dim(value)
231
+
232
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
233
+
234
+ # inject here, swap the attention map
235
+ hidden_states = torch.bmm(attention_probs, value)
236
+ hidden_states = attn.batch_to_head_dim(hidden_states)
237
+
238
+ # linear proj
239
+ hidden_states = attn.to_out[0](hidden_states, *args)
240
+ # dropout
241
+ hidden_states = attn.to_out[1](hidden_states)
242
+
243
+ if input_ndim == 4:
244
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
245
+ batch_size, channel, height, width
246
+ )
247
+
248
+ if attn.residual_connection:
249
+ hidden_states = hidden_states + residual
250
+
251
+ hidden_states = hidden_states / attn.rescale_output_factor
252
+
253
+ return hidden_states
254
+
255
+
256
+ class ArtistResBlockWrapper(nn.Module):
257
+
258
+ def __init__(
259
+ self, block: resnet.ResnetBlock2D, injection_method: str, name: str = None
260
+ ):
261
+ super().__init__()
262
+ self.block = block
263
+ self.output_scale_factor = self.block.output_scale_factor
264
+ self.injection_method = injection_method
265
+ self.name = name
266
+
267
+ def forward(
268
+ self,
269
+ input_tensor: torch.FloatTensor,
270
+ temb: torch.FloatTensor,
271
+ scale: float = 1.0,
272
+ ):
273
+ if self.injection_method == "hidden":
274
+ feat = self.block(
275
+ input_tensor, temb, scale
276
+ ) # when disentangle, feat should be [recon, uncontrolled style, controlled style]
277
+ batch_size = feat.shape[0] // 2
278
+ if batch_size == 1:
279
+ return feat
280
+
281
+ # the features of the reconstruction
282
+ recon_feat_uncond, recon_feat_cond = feat[0, ...].unsqueeze(0), feat[
283
+ batch_size, ...
284
+ ].unsqueeze(0)
285
+ # residual
286
+ input_tensor = self.block.conv_shortcut(input_tensor)
287
+ input_content_uncond, input_content_cond = input_tensor[0, ...].unsqueeze(
288
+ 0
289
+ ), input_tensor[batch_size, ...].unsqueeze(0)
290
+ # since feat = (input + h) / scale
291
+ recon_feat_uncond, recon_feat_cond = (
292
+ recon_feat_uncond * self.output_scale_factor,
293
+ recon_feat_cond * self.output_scale_factor,
294
+ )
295
+ h_content_uncond, h_content_cond = (
296
+ recon_feat_uncond - input_content_uncond,
297
+ recon_feat_cond - input_content_cond,
298
+ )
299
+ # only share the h, the residual is not shared
300
+ h_shared = torch.cat(
301
+ ([h_content_uncond] * batch_size) + ([h_content_cond] * batch_size),
302
+ dim=0,
303
+ )
304
+
305
+ output_feat_shared = (input_tensor + h_shared) / self.output_scale_factor
306
+ # do not inject the feat for the 2nd instance, which is uncontrolled style
307
+ output_feat_shared[1] = feat[1]
308
+ output_feat_shared[batch_size + 1] = feat[batch_size + 1]
309
+ # uncomment to not inject content to controlled style
310
+ # output_feat_shared[2] = feat[2]
311
+ # output_feat_shared[batch_size + 2] = feat[batch_size + 2]
312
+ return output_feat_shared
313
+ else:
314
+ raise NotImplementedError(f"Unknown injection method {self.injection_method}")
315
+
316
+
317
+ class SharedResBlockWrapper(nn.Module):
318
+ def __init__(self, block: resnet.ResnetBlock2D):
319
+ super().__init__()
320
+ self.block = block
321
+ self.output_scale_factor = self.block.output_scale_factor
322
+ self.share_enabled = True
323
+
324
+ def forward(
325
+ self,
326
+ input_tensor: torch.FloatTensor,
327
+ temb: torch.FloatTensor,
328
+ scale: float = 1.0,
329
+ ):
330
+ if self.share_enabled:
331
+ feat = self.block(input_tensor, temb, scale)
332
+ batch_size = feat.shape[0] // 2
333
+ if batch_size == 1:
334
+ return feat
335
+
336
+ # the features of the reconstruction
337
+ feat_uncond, feat_cond = feat[0, ...].unsqueeze(0), feat[
338
+ batch_size, ...
339
+ ].unsqueeze(0)
340
+ # residual
341
+ input_tensor = self.block.conv_shortcut(input_tensor)
342
+ input_content_uncond, input_content_cond = input_tensor[0, ...].unsqueeze(
343
+ 0
344
+ ), input_tensor[batch_size, ...].unsqueeze(0)
345
+ # since feat = (input + h) / scale
346
+ feat_uncond, feat_cond = (
347
+ feat_uncond * self.output_scale_factor,
348
+ feat_cond * self.output_scale_factor,
349
+ )
350
+ h_content_uncond, h_content_cond = (
351
+ feat_uncond - input_content_uncond,
352
+ feat_cond - input_content_cond,
353
+ )
354
+ # only share the h, the residual is not shared
355
+ h_shared = torch.cat(
356
+ ([h_content_uncond] * batch_size) + ([h_content_cond] * batch_size),
357
+ dim=0,
358
+ )
359
+ output_shared = (input_tensor + h_shared) / self.output_scale_factor
360
+ return output_shared
361
+ else:
362
+ return self.block(input_tensor, temb, scale)
363
+
364
+
365
+
366
+
367
+ def register_attention_processors(
368
+ pipe,
369
+ base_dir: str = None,
370
+ disentangle: bool = False,
371
+ attn_mode: str = "artist",
372
+ resnet_mode: str = "hidden",
373
+ share_resblock: bool = True,
374
+ share_attn: bool = True,
375
+ share_cross_attn: bool = False,
376
+ share_attn_layers: Optional[int] = None,
377
+ share_resnet_layers: Optional[int] = None,
378
+ c2s_layers: Optional[int] = [0, 1],
379
+ share_query: bool = True,
380
+ share_key: bool = True,
381
+ share_value: bool = True,
382
+ use_adain: bool = False,
383
+ ):
384
+ unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
385
+ if isinstance(pipe, StableDiffusionPipeline):
386
+ up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[
387
+ 1:
388
+ ] # skip the first block, which is UpBlock2D
389
+ elif isinstance(pipe, StableDiffusionXLPipeline):
390
+ up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[:-1]
391
+ layer_idx_attn = 0
392
+ layer_idx_resnet = 0
393
+ for block in up_blocks:
394
+ # each block should have 3 transformer layer
395
+ # transformer_layer : transformer_2d.Transformer2DModel
396
+ if share_resblock:
397
+ if share_resnet_layers is not None:
398
+ resnet_wrappers = []
399
+ resnets = block.resnets
400
+ for resnet_block in resnets:
401
+ if layer_idx_resnet not in share_resnet_layers:
402
+ resnet_wrappers.append(
403
+ resnet_block
404
+ ) # use original implementation
405
+ else:
406
+ if disentangle:
407
+ resnet_wrappers.append(
408
+ ArtistResBlockWrapper(
409
+ resnet_block,
410
+ injection_method=resnet_mode,
411
+ name=f"layer_{layer_idx_resnet}",
412
+ )
413
+ )
414
+ print(
415
+ f"Disentangle resnet {resnet_mode} set for layer {layer_idx_resnet}"
416
+ )
417
+ else:
418
+ resnet_wrappers.append(SharedResBlockWrapper(resnet_block))
419
+ print(
420
+ f"Share resnet feature set for layer {layer_idx_resnet}"
421
+ )
422
+
423
+ layer_idx_resnet += 1
424
+ block.resnets = nn.ModuleList(
425
+ resnet_wrappers
426
+ ) # actually apply the change
427
+ if share_attn:
428
+ for transformer_layer in block.attentions:
429
+ transformer_block: attention.BasicTransformerBlock = (
430
+ transformer_layer.transformer_blocks[0]
431
+ )
432
+ self_attn: attention_processor.Attention = transformer_block.attn1
433
+ # cross attn does not inject
434
+ cross_attn: attention_processor.Attention = transformer_block.attn2
435
+
436
+ if attn_mode == "artist":
437
+ if (
438
+ share_attn_layers is not None
439
+ and layer_idx_attn in share_attn_layers
440
+ ):
441
+ if layer_idx_attn in c2s_layers:
442
+ content_to_style = True
443
+ else:
444
+ content_to_style = False
445
+ pnp_inject_processor = ArtistAttentionProcessor(
446
+ inject_query=share_query,
447
+ inject_key=share_key,
448
+ inject_value=share_value,
449
+ use_adain=use_adain,
450
+ name=f"layer_{layer_idx_attn}_self",
451
+ use_content_to_style_injection=content_to_style,
452
+ )
453
+ self_attn.set_processor(pnp_inject_processor)
454
+ print(
455
+ f"Disentangled Pnp inject processor set for self-attention in layer {layer_idx_attn} with c2s={content_to_style}"
456
+ )
457
+ if share_cross_attn:
458
+ cross_attn_processor = ArtistAttentionProcessor(
459
+ inject_query=False,
460
+ inject_key=True,
461
+ inject_value=True,
462
+ use_adain=False,
463
+ name=f"layer_{layer_idx_attn}_cross",
464
+ )
465
+ cross_attn.set_processor(cross_attn_processor)
466
+ print(
467
+ f"Disentangled Pnp inject processor set for cross-attention in layer {layer_idx_attn}"
468
+ )
469
+ layer_idx_attn += 1
470
+
471
+
472
+ def unset_attention_processors(
473
+ pipe,
474
+ unset_share_attn: bool = False,
475
+ unset_share_resblock: bool = False,
476
+ ):
477
+ unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
478
+ if isinstance(pipe, StableDiffusionPipeline):
479
+ up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[
480
+ 1:
481
+ ] # skip the first block, which is UpBlock2D
482
+ elif isinstance(pipe, StableDiffusionXLPipeline):
483
+ up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[:-1]
484
+ block_idx = 1
485
+ layer_idx = 0
486
+ for block in up_blocks:
487
+ if unset_share_resblock:
488
+ resnet_origs = []
489
+ resnets = block.resnets
490
+ for resnet_block in resnets:
491
+ if isinstance(resnet_block, SharedResBlockWrapper) or isinstance(
492
+ resnet_block, ArtistResBlockWrapper
493
+ ):
494
+ resnet_origs.append(resnet_block.block)
495
+ else:
496
+ resnet_origs.append(resnet_block)
497
+ block.resnets = nn.ModuleList(resnet_origs)
498
+ if unset_share_attn:
499
+ for transformer_layer in block.attentions:
500
+ layer_idx += 1
501
+ transformer_block: attention.BasicTransformerBlock = (
502
+ transformer_layer.transformer_blocks[0]
503
+ )
504
+ self_attn: attention_processor.Attention = transformer_block.attn1
505
+ cross_attn: attention_processor.Attention = transformer_block.attn2
506
+ self_attn.set_processor(DefaultAttentionProcessor())
507
+ cross_attn.set_processor(DefaultAttentionProcessor())
508
+ block_idx += 1
509
+ layer_idx = 0
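
The intended call pattern mirrors style_image_with_inversion in injection_main.py; the values below are taken from example_config.yaml and are illustrative, not requirements of this module:

from models import attn_injection

attn_injection.register_attention_processors(
    pipe,                      # a StableDiffusionPipeline with a DDIM scheduler
    attn_mode="artist",
    resnet_mode="hidden",
    disentangle=True,
    share_resblock=True,
    share_attn=True,
    share_cross_attn=True,
    share_resnet_layers=[0, 1, 2, 3],
    share_attn_layers=list(range(9)),
    share_query=True,
    share_key=True,
    share_value=False,
    use_adain=True,
)
# ... sample on the [reconstruction, free style, controlled style] batch ...
attn_injection.unset_attention_processors(
    pipe, unset_share_attn=True, unset_share_resblock=True
)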
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ clip==1.0
+ diffusers==0.26.3
+ einops==0.8.0
+ gradio==4.39.0
+ matplotlib==3.5.2
+ numpy==1.22.4
+ omegaconf==2.3.0
+ Pillow==10.4.0
+ Requests==2.32.3
+ torch==1.11.0+cu113
+ torchvision==0.12.0+cu113
+ tqdm==4.61.2
utils/exp_utils.py ADDED
@@ -0,0 +1,93 @@
+ import uuid
+ import os
+ import PIL.Image as Image
+ import torch
+ import numpy as np
+ from torchvision import transforms
+ import torch.nn.functional as F
+ import torchvision
+
+
+ def make_unique_experiment_path(base_dir: str) -> str:
+     """
+     Create a unique directory in the base directory, named as the least unused number.
+     return: path to the unique directory
+     """
+     if not os.path.exists(base_dir):
+         os.makedirs(base_dir)
+
+     # List all existing directories
+     existing_dirs = [
+         d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
+     ]
+
+     # Convert directory names to integers, filter out non-numeric names
+     existing_numbers = sorted([int(d) for d in existing_dirs if d.isdigit()])
+
+     # Find the least unused number
+     experiment_id = 1
+     for number in existing_numbers:
+         if number != experiment_id:
+             break
+         experiment_id += 1
+
+     # Create the new directory
+     experiment_output_path = os.path.join(base_dir, str(experiment_id))
+     os.makedirs(experiment_output_path)
+
+     return experiment_output_path
+
+
+ def get_processed_image(image_dir: str, device, resolution) -> torch.Tensor:
+     src_img = Image.open(image_dir)
+     src_img = transforms.ToTensor()(src_img).unsqueeze(0).to(device)
+
+     h, w = src_img.shape[-2:]
+     src_img_512 = torchvision.transforms.functional.pad(
+         src_img, ((resolution - w) // 2,), fill=0, padding_mode="constant"
+     )
+     input_image = F.interpolate(
+         src_img, (resolution, resolution), mode="bilinear", align_corners=False
+     )
+     # drop alpha channel if it exists
+     if input_image.shape[1] == 4:
+         input_image = input_image[:, :3]
+
+     return input_image
+
+
+ def process_image(image, device, resolution) -> torch.Tensor:
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+     src_img = image
+     src_img = transforms.ToTensor()(src_img).unsqueeze(0).to(device)
+
+     h, w = src_img.shape[-2:]
+     src_img_512 = torchvision.transforms.functional.pad(
+         src_img, ((resolution - w) // 2,), fill=0, padding_mode="constant"
+     )
+     input_image = F.interpolate(
+         src_img, (resolution, resolution), mode="bilinear", align_corners=False
+     )
+     # drop alpha channel if it exists
+     if input_image.shape[1] == 4:
+         input_image = input_image[:, :3]
+
+     return input_image
+
+
+ def seed_all(seed: int):
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     g_cpu = torch.Generator(device="cpu")
+     g_cpu.manual_seed(42)
+
+
+ def dump_tensor(tensor, filename):
+     with open(filename, "wb") as f:  # open in binary mode so torch.save can write to it
+         torch.save(tensor, f)
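
A quick usage sketch with the example assets from this commit (the (1, 3, 512, 512) shape follows from the resize and the alpha-channel drop above; a CUDA device is assumed, as in injection_main.py):

import torch
from utils import exp_utils

exp_utils.seed_all(10)
img = exp_utils.get_processed_image("data/example/1.png", torch.device("cuda"), 512)
print(img.shape)  # torch.Size([1, 3, 512, 512])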