turboedit committed on
Commit
59f949f
•
1 Parent(s): 9304f1f

Upload folder using huggingface_hub

Files changed (5)
  1. app.py +136 -0
  2. config.py +140 -0
  3. my_run.py +476 -0
  4. resize.py +18 -0
  5. utils.py +1356 -0
app.py ADDED
@@ -0,0 +1,136 @@
1
+ from __future__ import annotations
2
+
3
+ import gradio as gr
4
+ import spaces
5
+ from PIL import Image
6
+ import torch
7
+
8
+ from my_run import run as run_model
9
+
10
+
11
+ DESCRIPTION = '''# Turbo Edit
12
+ '''
13
+
14
+ @spaces.GPU
15
+ def main_pipeline(
16
+ input_image: str,
17
+ src_prompt: str,
18
+ tgt_prompt: str,
19
+ seed: int,
20
+ w1: float,
21
+ # w2: float,
22
+ ):
23
+
24
+ w2 = 1.0
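+ # w2 is fixed to 1.0 here; only w1 is exposed as a slider in the UI below.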
25
+ res_image = run_model(input_image, src_prompt, tgt_prompt, seed, w1, w2)
26
+
27
+ return res_image
28
+
29
+
30
+ with gr.Blocks(css='app/style.css') as demo:
31
+ gr.Markdown(DESCRIPTION)
32
+
33
+ gr.HTML(
34
+ '''<a href="https://huggingface.co/spaces/garibida/ReNoise-Inversion?duplicate=true">
35
+ <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to run privately without waiting in queue''')
36
+
37
+ with gr.Row():
38
+ with gr.Column():
39
+ input_image = gr.Image(
40
+ label="Input image",
41
+ type="filepath",
42
+ height=512,
43
+ width=512
44
+ )
45
+ src_prompt = gr.Text(
46
+ label='Source Prompt',
47
+ max_lines=1,
48
+ placeholder='Source Prompt',
49
+ )
50
+ tgt_prompt = gr.Text(
51
+ label='Target Prompt',
52
+ max_lines=1,
53
+ placeholder='Target Prompt',
54
+ )
55
+ with gr.Accordion("Advanced Options", open=False):
56
+ seed = gr.Slider(
57
+ label='seed',
58
+ minimum=0,
59
+ maximum=16*1024,
60
+ value=7865,
61
+ step=1
62
+ )
63
+ w1 = gr.Slider(
64
+ label='w',
65
+ minimum=1.0,
66
+ maximum=3.0,
67
+ value=1.5,
68
+ step=0.05
69
+ )
70
+ # w2 = gr.Slider(
71
+ # label='w2',
72
+ # minimum=1.0,
73
+ # maximum=3.0,
74
+ # value=1.0,
75
+ # step=0.05
76
+ # )
77
+
78
+ run_button = gr.Button('Edit')
79
+ with gr.Column():
80
+ # result = gr.Gallery(label='Result')
81
+ result = gr.Image(
82
+ label="Result",
83
+ type="pil",
84
+ height=512,
85
+ width=512
86
+ )
87
+
88
+ examples = [
89
+ [
90
+ "demo_im/WhatsApp Image 2024-05-17 at 17.32.53.jpeg", #input_image
91
+ "a painting of a white cat sleeping on a lotus flower", #src_prompt
92
+ "a painting of a white cat sleeping on a lotus flower", #tgt_prompt
93
+ 4759, #seed
94
+ 1.0, #w1
95
+ # 1.1, #w2
96
+ ],
97
+ [
98
+ "demo_im/pexels-pixabay-458976.less.png", #input_image
99
+ "a squirrel standing in the grass", #src_prompt
100
+ "a squirrel standing in the grass", #tgt_prompt
101
+ 6128, #seed
102
+ 1.25, #w1
103
+ # 1.1, #w2
104
+ ],
105
+ ]
106
+
107
+ gr.Examples(examples=examples,
108
+ inputs=[
109
+ input_image,
110
+ src_prompt,
111
+ tgt_prompt,
112
+ seed,
113
+ w1,
114
+ # w2,
115
+ ],
116
+ outputs=[
117
+ result
118
+ ],
119
+ fn=main_pipeline,
120
+ cache_examples=True)
121
+
122
+
123
+ inputs = [
124
+ input_image,
125
+ src_prompt,
126
+ tgt_prompt,
127
+ seed,
128
+ w1,
129
+ # w2,
130
+ ]
131
+ outputs = [
132
+ result
133
+ ]
134
+ run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)
135
+
136
+ demo.queue(max_size=50).launch(share=True, max_threads=100)
config.py ADDED
@@ -0,0 +1,140 @@
1
+ from ml_collections import config_dict
2
+ import yaml
3
+ from diffusers.schedulers import (
4
+ DDIMScheduler,
5
+ EulerAncestralDiscreteScheduler,
6
+ EulerDiscreteScheduler,
7
+ DDPMScheduler,
8
+ )
9
+ from utils import (
10
+ deterministic_ddim_step,
11
+ deterministic_ddpm_step,
12
+ deterministic_euler_step,
13
+ deterministic_non_ancestral_euler_step,
14
+ )
15
+
16
+ BREAKDOWNS = ["x_t_c_hat", "x_t_hat_c", "no_breakdown", "x_t_hat_c_with_zeros"]
17
+ SCHEDULERS = ["ddpm", "ddim", "euler", "euler_non_ancestral"]
18
+ MODELS = [
19
+ "stabilityai/sdxl-turbo",
20
+ "stabilityai/stable-diffusion-xl-base-1.0",
21
+ "CompVis/stable-diffusion-v1-4",
22
+ ]
23
+
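+ # Effective number of denoising steps: num_steps_inversion - step_start when no explicit
+ # timestep list is given, otherwise len(cfg.timesteps); plus one if a final clean step is used.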
24
+ def get_num_steps_actual(cfg):
25
+ return (
26
+ cfg.num_steps_inversion
27
+ - cfg.step_start
28
+ + (1 if cfg.clean_step_timestep > 0 else 0)
29
+ if cfg.timesteps is None
30
+ else len(cfg.timesteps) + (1 if cfg.clean_step_timestep > 0 else 0)
31
+ )
32
+
33
+
34
+ def get_config(args):
35
+ if args.config_from_file and args.config_from_file != "":
36
+ with open(args.config_from_file, "r") as f:
37
+ cfg = config_dict.ConfigDict(yaml.safe_load(f))
38
+
39
+ num_steps_actual = get_num_steps_actual(cfg)
40
+
41
+ else:
42
+ cfg = config_dict.ConfigDict()
43
+
44
+ cfg.seed = 2
45
+ cfg.self_r = 0.5
46
+ cfg.cross_r = 0.9
47
+ cfg.eta = 1
48
+ cfg.scheduler_type = SCHEDULERS[0]
49
+
50
+ cfg.num_steps_inversion = 50 # timesteps: 999, 799, 599, 399, 199
51
+ cfg.step_start = 20
52
+ cfg.timesteps = None
53
+ cfg.noise_timesteps = None
54
+ num_steps_actual = get_num_steps_actual(cfg)
55
+ cfg.ws1 = [2] * num_steps_actual
56
+ cfg.ws2 = [1] * num_steps_actual
57
+ cfg.real_cfg_scale = 0
58
+ cfg.real_cfg_scale_save = 0
59
+ cfg.breakdown = BREAKDOWNS[1]
60
+ cfg.noise_shift_delta = 1
61
+ cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
62
+
63
+ cfg.clean_step_timestep = 0
64
+
65
+ cfg.model = MODELS[1]
66
+
67
+ if cfg.scheduler_type == "ddim":
68
+ cfg.scheduler_class = DDIMScheduler
69
+ cfg.step_function = deterministic_ddim_step
70
+ elif cfg.scheduler_type == "ddpm":
71
+ cfg.scheduler_class = DDPMScheduler
72
+ cfg.step_function = deterministic_ddpm_step
73
+ elif cfg.scheduler_type == "euler":
74
+ cfg.scheduler_class = EulerAncestralDiscreteScheduler
75
+ cfg.step_function = deterministic_euler_step
76
+ elif cfg.scheduler_type == "euler_non_ancestral":
77
+ cfg.scheduler_class = EulerDiscreteScheduler
78
+ cfg.step_function = deterministic_non_ancestral_euler_step
79
+ else:
80
+ raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}")
81
+
82
+ with cfg.ignore_type():
83
+ if isinstance(cfg.max_norm_zs, (int, float)):
84
+ cfg.max_norm_zs = [cfg.max_norm_zs] * num_steps_actual
85
+
86
+ if isinstance(cfg.ws1, (int, float)):
87
+ cfg.ws1 = [cfg.ws1] * num_steps_actual
88
+
89
+ if isinstance(cfg.ws2, (int, float)):
90
+ cfg.ws2 = [cfg.ws2] * num_steps_actual
91
+
92
+ if not hasattr(cfg, "update_eta"):
93
+ cfg.update_eta = False
94
+
95
+ if not hasattr(cfg, "save_timesteps"):
96
+ cfg.save_timesteps = None
97
+
98
+ if not hasattr(cfg, "scheduler_timesteps"):
99
+ cfg.scheduler_timesteps = None
100
+
101
+ assert (
102
+ cfg.scheduler_type == "ddpm" or cfg.timesteps is None
103
+ ), "timesteps must be None for ddim/euler"
104
+
105
+ assert (
106
+ len(cfg.max_norm_zs) == num_steps_actual
107
+ ), f"len(cfg.max_norm_zs) ({len(cfg.max_norm_zs)}) != num_steps_actual ({num_steps_actual})"
108
+
109
+ assert (
110
+ len(cfg.ws1) == num_steps_actual
111
+ ), f"len(cfg.ws1) ({len(cfg.ws1)}) != num_steps_actual ({num_steps_actual})"
112
+
113
+ assert (
114
+ len(cfg.ws2) == num_steps_actual
115
+ ), f"len(cfg.ws2) ({len(cfg.ws2)}) != num_steps_actual ({num_steps_actual})"
116
+
117
+ assert cfg.noise_timesteps is None or len(cfg.noise_timesteps) == (
118
+ num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
119
+ ), f"len(cfg.noise_timesteps) ({len(cfg.noise_timesteps)}) != num_steps_actual ({num_steps_actual})"
120
+
121
+ assert cfg.save_timesteps is None or len(cfg.save_timesteps) == (
122
+ num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
123
+ ), f"len(cfg.save_timesteps) ({len(cfg.save_timesteps)}) != num_steps_actual ({num_steps_actual})"
124
+
125
+ return cfg
126
+
127
+
128
+ def get_config_name(config, args):
129
+ if args.folder_name is not None and args.folder_name != "":
130
+ return args.folder_name
131
+ timesteps_str = (
132
+ f"step_start {config.step_start}"
133
+ if config.timesteps is None
134
+ else f"timesteps {config.timesteps}"
135
+ )
136
+ return f"""\
137
+ ws1 {config.ws1[0]} ws2 {config.ws2[0]} real_cfg_scale {config.real_cfg_scale} {timesteps_str} \
138
+ real_cfg_scale_save {config.real_cfg_scale_save} seed {config.seed} max_norm_zs {config.max_norm_zs[-1]} noise_shift_delta {config.noise_shift_delta} \
139
+ scheduler_type {config.scheduler_type} fp16 {args.fp16}\
140
+ """
my_run.py ADDED
@@ -0,0 +1,476 @@
1
+ from diffusers import AutoPipelineForImage2Image
2
+ from diffusers import DDPMScheduler
3
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
4
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
5
+ import torch
6
+ from PIL import Image
7
+
8
+ num_steps_inversion = 5
9
+ strength = 0.8
10
+ generator = None
11
+ device = "cuda"
12
+ image_path = "edit_dataset/01.jpg"
13
+ src_prompt = "butterfly perched on purple flower"
14
+ tgt_prompt = "dragonfly perched on purple flower"
15
+ ws1 = [1.5, 1.5, 1.5, 1.5]
16
+ ws2 = [1, 1, 1, 1]
17
+
18
+ def encode_image(image, pipe):
19
+ image = pipe.image_processor.preprocess(image)
20
+ image = image.to(device=device, dtype=pipeline.dtype)
21
+
22
+ if pipe.vae.config.force_upcast:
23
+ image = image.float()
24
+ pipe.vae.to(dtype=torch.float32)
25
+
26
+ if isinstance(generator, list):
27
+ init_latents = [
28
+ retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
29
+ for i in range(1)
30
+ ]
31
+ init_latents = torch.cat(init_latents, dim=0)
32
+ else:
33
+ init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
34
+
35
+ if pipe.vae.config.force_upcast:
36
+ pipe.vae.to(pipeline.dtype)
37
+
38
+ init_latents = init_latents.to(pipeline.dtype)
39
+ init_latents = pipe.vae.config.scaling_factor * init_latents
40
+
41
+ return init_latents.to(dtype=torch.float16)
42
+
43
+ # def create_xts(scheduler, timesteps, x_0, noise_shift_delta=1, generator=None):
44
+ # noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
45
+ # noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]
46
+ # noise_timesteps = noise_timesteps[:3]
47
+
48
+ # x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
49
+ # noise = torch.randn(x_0_expanded.size(), generator=generator, device="cpu", dtype=x_0.dtype).to(x_0.device)
50
+ # x_ts = scheduler.add_noise(x_0_expanded, noise, torch.IntTensor(noise_timesteps))
51
+ # x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
52
+ # x_ts += [x_0]
53
+ # return x_ts
54
+
55
+ def deterministic_ddpm_step(
56
+ model_output: torch.FloatTensor,
57
+ timestep,
58
+ sample: torch.FloatTensor,
59
+ eta,
60
+ use_clipped_model_output,
61
+ generator,
62
+ variance_noise,
63
+ return_dict,
64
+ scheduler,
65
+ ):
66
+ """
67
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
68
+ process from the learned model outputs (most often the predicted noise).
69
+
70
+ Args:
71
+ model_output (`torch.FloatTensor`):
72
+ The direct output from learned diffusion model.
73
+ timestep (`float`):
74
+ The current discrete timestep in the diffusion chain.
75
+ sample (`torch.FloatTensor`):
76
+ A current instance of a sample created by the diffusion process.
77
+ generator (`torch.Generator`, *optional*):
78
+ A random number generator.
79
+ return_dict (`bool`, *optional*, defaults to `True`):
80
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
81
+
82
+ Returns:
83
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
84
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
85
+ tuple is returned where the first element is the sample tensor.
86
+
87
+ """
88
+ t = timestep
89
+
90
+ prev_t = scheduler.previous_timestep(t)
91
+
92
+ if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
93
+ "learned",
94
+ "learned_range",
95
+ ]:
96
+ model_output, predicted_variance = torch.split(
97
+ model_output, sample.shape[1], dim=1
98
+ )
99
+ else:
100
+ predicted_variance = None
101
+
102
+ # 1. compute alphas, betas
103
+ alpha_prod_t = scheduler.alphas_cumprod[t]
104
+ alpha_prod_t_prev = (
105
+ scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
106
+ )
107
+ beta_prod_t = 1 - alpha_prod_t
108
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
109
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
110
+ current_beta_t = 1 - current_alpha_t
111
+
112
+ # 2. compute predicted original sample from predicted noise also called
113
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
114
+ if scheduler.config.prediction_type == "epsilon":
115
+ pred_original_sample = (
116
+ sample - beta_prod_t ** (0.5) * model_output
117
+ ) / alpha_prod_t ** (0.5)
118
+ elif scheduler.config.prediction_type == "sample":
119
+ pred_original_sample = model_output
120
+ elif scheduler.config.prediction_type == "v_prediction":
121
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
122
+ beta_prod_t**0.5
123
+ ) * model_output
124
+ else:
125
+ raise ValueError(
126
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
127
+ " `v_prediction` for the DDPMScheduler."
128
+ )
129
+
130
+ # 3. Clip or threshold "predicted x_0"
131
+ if scheduler.config.thresholding:
132
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
133
+ elif scheduler.config.clip_sample:
134
+ pred_original_sample = pred_original_sample.clamp(
135
+ -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
136
+ )
137
+
138
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
139
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
140
+ pred_original_sample_coeff = (
141
+ alpha_prod_t_prev ** (0.5) * current_beta_t
142
+ ) / beta_prod_t
143
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
144
+
145
+ # 5. Compute predicted previous sample µ_t
146
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
147
+ pred_prev_sample = (
148
+ pred_original_sample_coeff * pred_original_sample
149
+ + current_sample_coeff * sample
150
+ )
151
+
152
+ return pred_prev_sample
153
+
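+ # normalize clamps the norm of z_t to max_norm_zs[i]; it returns the (possibly rescaled)
+ # tensor and the scale coefficient applied (1 when max_norm_zs[i] < 0 or no clamping occurs).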
154
+ def normalize(
155
+ z_t,
156
+ i,
157
+ max_norm_zs,
158
+ ):
159
+ max_norm = max_norm_zs[i]
160
+ if max_norm < 0:
161
+ return z_t, 1
162
+
163
+ norm = torch.norm(z_t)
164
+ if norm < max_norm:
165
+ return z_t, 1
166
+
167
+ coeff = max_norm / norm
168
+ z_t = z_t * coeff
169
+ return z_t, coeff
170
+
171
+ def step_save_latents(
172
+ self,
173
+ model_output: torch.FloatTensor,
174
+ timestep: int,
175
+ sample: torch.FloatTensor,
176
+ eta: float = 0.0,
177
+ use_clipped_model_output: bool = False,
178
+ generator=None,
179
+ variance_noise= None,
180
+ return_dict: bool = True,
181
+ ):
182
+
183
+ timestep_index = self._inner_index
184
+ next_timestep_index = timestep_index + 1
185
+ u_hat_t = deterministic_ddpm_step(
186
+ model_output=model_output,
187
+ timestep=timestep,
188
+ sample=sample,
189
+ eta=eta,
190
+ use_clipped_model_output=use_clipped_model_output,
191
+ generator=generator,
192
+ variance_noise=variance_noise,
193
+ return_dict=False,
194
+ scheduler=self,
195
+ )
196
+ x_t_minus_1 = self.x_ts[timestep_index]
197
+ self.x_ts_c_hat.append(u_hat_t)
198
+
199
+ z_t = x_t_minus_1 - u_hat_t
200
+ self.latents.append(z_t)
201
+
202
+ z_t, _ = normalize(z_t, timestep_index, [-1, -1, -1, 15.5])
203
+ x_t_minus_1_predicted = u_hat_t + z_t
204
+
205
+ if not return_dict:
206
+ return (x_t_minus_1_predicted,)
207
+
208
+ return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
209
+
210
+ def step_use_latents(
211
+ self,
212
+ model_output: torch.FloatTensor,
213
+ timestep: int,
214
+ sample: torch.FloatTensor,
215
+ eta: float = 0.0,
216
+ use_clipped_model_output: bool = False,
217
+ generator=None,
218
+ variance_noise= None,
219
+ return_dict: bool = True,
220
+ ):
221
+ print(f'_inner_index: {self._inner_index}')
222
+ timestep_index = self._inner_index
223
+ next_timestep_index = timestep_index + 1
224
+ z_t = self.latents[timestep_index] # + 1 because latents[0] is X_T
225
+
226
+ _, normalize_coefficient = normalize(
227
+ z_t,
228
+ timestep_index,
229
+ [-1, -1, -1, 15.5],
230
+ )
231
+
232
+ if normalize_coefficient == 0:
233
+ eta = 0
234
+
235
+ # eta = normalize_coefficient
236
+
237
+ x_t_hat_c_hat = deterministic_ddpm_step(
238
+ model_output=model_output,
239
+ timestep=timestep,
240
+ sample=sample,
241
+ eta=eta,
242
+ use_clipped_model_output=use_clipped_model_output,
243
+ generator=generator,
244
+ variance_noise=variance_noise,
245
+ return_dict=False,
246
+ scheduler=self,
247
+ )
248
+
249
+ w1 = ws1[timestep_index]
250
+ w2 = ws2[timestep_index]
251
+
252
+ x_t_minus_1_exact = self.x_ts[timestep_index]
253
+ x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)
254
+
255
+ x_t_c_hat: torch.Tensor = self.x_ts_c_hat[timestep_index]
256
+
257
+ x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)
258
+
259
+ zero_index_reconstruction = 0
260
+ edit_prompts_num = (model_output.size(0) - zero_index_reconstruction) // 2
261
+ x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
262
+ edit_images_indices = (
263
+ edit_prompts_num + zero_index_reconstruction,
264
+ model_output.size(0)
265
+ )
266
+ x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
267
+ x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
268
+ x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
269
+ ]
270
+ v1 = x_t_hat_c_hat - x_t_hat_c
271
+ v2 = x_t_hat_c - normalize_coefficient * x_t_c
272
+
273
+ x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2
274
+
275
+ x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
276
+ edit_images_indices[0] : edit_images_indices[1]
277
+ ] # update x_t_hat_c to be x_t_hat_c_hat
278
+
279
+
280
+ if not return_dict:
281
+ return (x_t_minus_1,)
282
+
283
+ return DDIMSchedulerOutput(
284
+ prev_sample=x_t_minus_1,
285
+ pred_original_sample=None,
286
+ )
287
+
288
+
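+ # Custom scheduler step: batch element 0 goes through the inversion branch (step_save_latents,
+ # which stores the per-step noise maps), the remaining elements go through the editing branch
+ # (step_use_latents, which consumes them); the results are concatenated back into one batch.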
289
+ class myDDPMScheduler(DDPMScheduler):
290
+ def step(
291
+ self,
292
+ model_output: torch.FloatTensor,
293
+ timestep: int,
294
+ sample: torch.FloatTensor,
295
+ eta: float = 0.0,
296
+ use_clipped_model_output: bool = False,
297
+ generator=None,
298
+ variance_noise= None,
299
+ return_dict: bool = True,
300
+ ):
301
+ print(f"timestep: {timestep}")
302
+
303
+ res_inv = step_save_latents(
304
+ self,
305
+ model_output[:1, :, :, :],
306
+ timestep,
307
+ sample[:1, :, :, :],
308
+ eta,
309
+ use_clipped_model_output,
310
+ generator,
311
+ variance_noise,
312
+ return_dict,
313
+ )
314
+
315
+ res_inf = step_use_latents(
316
+ self,
317
+ model_output[1:, :, :, :],
318
+ timestep,
319
+ sample[1:, :, :, :],
320
+ eta,
321
+ use_clipped_model_output,
322
+ generator,
323
+ variance_noise,
324
+ return_dict,
325
+ )
326
+
327
+ self._inner_index+=1
328
+
329
+ res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
330
+ return res
331
+
332
+
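+ # Module-level setup: load the SDXL-Turbo image-to-image pipeline in fp16 and replace its
+ # scheduler with a DDPM scheduler so the inversion/editing step can be patched onto it later.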
333
+ pipeline = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", safety_checker = None)
334
+ pipeline = pipeline.to("cuda")
335
+ pipeline.scheduler = DDPMScheduler.from_pretrained( # type: ignore
336
+ 'stabilityai/sdxl-turbo',
337
+ subfolder="scheduler",
338
+ # cache_dir="/home/joberant/NLP_2223/giladd/test_dir/sdxl-turbo/models_cache",
339
+ )
340
+ # pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
341
+
342
+ denoising_start = 0.2
343
+ timesteps, num_inference_steps = retrieve_timesteps(
344
+ pipeline.scheduler, num_steps_inversion, device, None
345
+ )
346
+ timesteps, num_inference_steps = pipeline.get_timesteps(
347
+ num_inference_steps=num_inference_steps,
348
+ device=device,
349
+ denoising_start=denoising_start,
350
+ strength=0,
351
+ )
352
+ timesteps = timesteps.type(torch.int64)
353
+ from functools import partial
354
+
355
+ timesteps = [torch.tensor(t) for t in timesteps.tolist()]
356
+ pipeline.__call__ = partial(
357
+ pipeline.__call__,
358
+ num_inference_steps=num_steps_inversion,
359
+ guidance_scale=0,
360
+ generator=generator,
361
+ denoising_start=denoising_start,
362
+ strength=0,
363
+ )
364
+
365
+ # timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_steps_inversion, device, None)
366
+ # timesteps, num_inference_steps = pipeline.get_timesteps(num_inference_steps=num_inference_steps, device=device, strength=strength)
367
+
368
+
369
+ from utils import get_ddpm_inversion_scheduler, create_xts
370
+
371
+
372
+
373
+ from config import get_config, get_config_name
374
+ import argparse
375
+
376
+ # parser = argparse.ArgumentParser()
377
+ # parser.add_argument("--images_paths", type=str, default=None)
378
+ # parser.add_argument("--images_folder", type=str, default=None)
379
+ # parser.set_defaults(force_use_cpu=False)
380
+ # parser.add_argument("--force_use_cpu", action="store_true")
381
+ # parser.add_argument("--folder_name", type=str, default='test_measure_time')
382
+ # parser.add_argument("--config_from_file", type=str, default='run_configs/noise_shift_guidance_1_5.yaml')
383
+ # parser.set_defaults(save_intermediate_results=False)
384
+ # parser.add_argument("--save_intermediate_results", action="store_true")
385
+ # parser.add_argument("--batch_size", type=int, default=None)
386
+ # parser.set_defaults(skip_p_to_p=False)
387
+ # parser.add_argument("--skip_p_to_p", action="store_true", default=True)
388
+ # parser.set_defaults(only_p_to_p=False)
389
+ # parser.add_argument("--only_p_to_p", action="store_true")
390
+ # parser.set_defaults(fp16=False)
391
+ # parser.add_argument("--fp16", action="store_true", default=False)
392
+ # parser.add_argument("--prompts_file", type=str, default='dataset_measure_time/dataset.json')
393
+ # parser.add_argument("--images_in_prompts_file", type=str, default=None)
394
+ # parser.add_argument("--seed", type=int, default=2)
395
+ # parser.add_argument("--time_measure_n", type=int, default=1)
396
+
397
+ # args = parser.parse_args()
398
+ class Object(object):
399
+ pass
400
+
401
+ args = Object()
402
+ args.images_paths = None
403
+ args.images_folder = None
404
+ args.force_use_cpu = False
405
+ args.folder_name = 'test_measure_time'
406
+ args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
407
+ args.save_intermediate_results = False
408
+ args.batch_size = None
409
+ args.skip_p_to_p = True
410
+ args.only_p_to_p = False
411
+ args.fp16 = False
412
+ args.prompts_file = 'dataset_measure_time/dataset.json'
413
+ args.images_in_prompts_file = None
414
+ args.seed = 986
415
+ args.time_measure_n = 1
416
+
417
+
418
+ assert (
419
+ args.batch_size is None or args.save_intermediate_results is False
420
+ ), "save_intermediate_results is not implemented for batch_size > 1"
421
+
422
+ config = get_config(args)
423
+
424
+
425
+
426
+
427
+
428
+ # latent = latents[0].expand(3, -1, -1, -1)
429
+ # prompt = [src_prompt, src_prompt, tgt_prompt]
430
+
431
+ # image = pipeline.__call__(image=latent, prompt=prompt, eta=1).images
432
+
433
+ # for i, im in enumerate(image):
434
+ # im.save(f"output_{i}.png")
435
+
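+ # run() encodes the input image into latents, builds the noised latents x_ts, patches the
+ # scheduler via get_ddpm_inversion_scheduler, and calls the pipeline on [src, src, tgt]
+ # prompts; the third output image is the edited result.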
436
+ def run(image_path, src_prompt, tgt_prompt, seed, w1, w2):
437
+ generator = torch.Generator().manual_seed(seed)
438
+ x_0_image = Image.open(image_path).convert("RGB").resize((512, 512), Image.LANCZOS)
439
+ x_0 = encode_image(x_0_image, pipeline)
440
+ # x_ts = create_xts(pipeline.scheduler, timesteps, x_0, noise_shift_delta=1, generator=generator)
441
+ x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
442
+ x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
443
+ latents = [x_ts[0]]
444
+ x_ts_c_hat = [None]
445
+ config.ws1 = [w1] * 4
446
+ config.ws2 = [w2] * 4
447
+ pipeline.scheduler = get_ddpm_inversion_scheduler(
448
+ pipeline.scheduler,
449
+ config.step_function,
450
+ config,
451
+ timesteps,
452
+ config.save_timesteps,
453
+ latents,
454
+ x_ts,
455
+ x_ts_c_hat,
456
+ args.save_intermediate_results,
457
+ pipeline,
458
+ x_0,
459
+ v1s_images := [],
460
+ v2s_images := [],
461
+ deltas_images := [],
462
+ v1_x0s := [],
463
+ v2_x0s := [],
464
+ deltas_x0s := [],
465
+ "res12",
466
+ image_name="im_name",
467
+ time_measure_n=args.time_measure_n,
468
+ )
469
+ latent = latents[0].expand(3, -1, -1, -1)
470
+ prompt = [src_prompt, src_prompt, tgt_prompt]
471
+ image = pipeline.__call__(image=latent, prompt=prompt, eta=1).images
472
+ return image[2]
473
+
474
+ if __name__ == "__main__":
475
+ res = run(image_path, src_prompt, tgt_prompt, args.seed, 1.5, 1.0)
476
+ res.save("output.png")
resize.py ADDED
@@ -0,0 +1,18 @@
1
+ from PIL import Image
2
+
3
+ def resize_image(input_path, output_path, new_size):
4
+ # Open the image
5
+ image = Image.open(input_path)
6
+
7
+ # Resize the image
8
+ resized_image = image.resize(new_size)
9
+
10
+ # Save the resized image
11
+ resized_image.save(output_path)
12
+
13
+ # Example usage
14
+ input_path = "demo_im/pexels-pixabay-458976.png"
15
+ output_path = "demo_im/pexels-pixabay-458976.less.png"
16
+ new_size = (512, 512)
17
+
18
+ resize_image(input_path, output_path, new_size)
utils.py ADDED
@@ -0,0 +1,1356 @@
1
+ import itertools
2
+ from typing import List, Optional, Union
3
+ import PIL
4
+ import PIL.Image
5
+ import torch
6
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
7
+ from diffusers.utils import make_image_grid
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import os
10
+ from diffusers.utils import (
11
+ logging,
12
+ USE_PEFT_BACKEND,
13
+ scale_lora_layers,
14
+ unscale_lora_layers,
15
+ )
16
+ from diffusers.loaders import (
17
+ StableDiffusionXLLoraLoaderMixin,
18
+ )
19
+ from diffusers.image_processor import VaeImageProcessor
20
+
21
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
+
23
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
24
+ from diffusers import DiffusionPipeline
25
+
26
+
27
+ VECTOR_DATA_FOLDER = "vector_data"
28
+ VECTOR_DATA_DICT = "vector_data"
29
+
30
+
31
+ def encode_image(image: PIL.Image, pipe: DiffusionPipeline):
32
+ pipe.image_processor: VaeImageProcessor = pipe.image_processor # type: ignore
33
+ image = pipe.image_processor.pil_to_numpy(image)
34
+ image = pipe.image_processor.numpy_to_pt(image)
35
+ image = image.to(pipe.device)
36
+ return (
37
+ pipe.vae.encode(
38
+ pipe.image_processor.preprocess(image),
39
+ ).latent_dist.mode()
40
+ * pipe.vae.config.scaling_factor
41
+ )
42
+
43
+
44
+ def decode_latents(latent, pipe):
45
+ latent_img = pipe.vae.decode(
46
+ latent / pipe.vae.config.scaling_factor, return_dict=False
47
+ )[0]
48
+ return pipe.image_processor.postprocess(latent_img, output_type="pil")
49
+
50
+
51
+ def get_device(argv, args=None):
52
+ import sys
53
+
54
+ def debugger_is_active():
55
+ return hasattr(sys, "gettrace") and sys.gettrace() is not None
56
+
57
+ if args:
58
+ return (
59
+ torch.device("cuda")
60
+ if (torch.cuda.is_available() and not debugger_is_active())
61
+ and not args.force_use_cpu
62
+ else torch.device("cpu")
63
+ )
64
+
65
+ return (
66
+ torch.device("cuda")
67
+ if (torch.cuda.is_available() and not debugger_is_active())
68
+ and "cpu" not in set(argv[1:])
69
+ else torch.device("cpu")
70
+ )
71
+
72
+
73
+ def deterministic_ddim_step(
74
+ model_output: torch.FloatTensor,
75
+ timestep: int,
76
+ sample: torch.FloatTensor,
77
+ eta: float = 0.0,
78
+ use_clipped_model_output: bool = False,
79
+ generator=None,
80
+ variance_noise: Optional[torch.FloatTensor] = None,
81
+ return_dict: bool = True,
82
+ scheduler=None,
83
+ ):
84
+
85
+ if scheduler.num_inference_steps is None:
86
+ raise ValueError(
87
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
88
+ )
89
+
90
+ prev_timestep = (
91
+ timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
92
+ )
93
+
94
+ # 2. compute alphas, betas
95
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
96
+ alpha_prod_t_prev = (
97
+ scheduler.alphas_cumprod[prev_timestep]
98
+ if prev_timestep >= 0
99
+ else scheduler.final_alpha_cumprod
100
+ )
101
+
102
+ beta_prod_t = 1 - alpha_prod_t
103
+
104
+ if scheduler.config.prediction_type == "epsilon":
105
+ pred_original_sample = (
106
+ sample - beta_prod_t ** (0.5) * model_output
107
+ ) / alpha_prod_t ** (0.5)
108
+ pred_epsilon = model_output
109
+ elif scheduler.config.prediction_type == "sample":
110
+ pred_original_sample = model_output
111
+ pred_epsilon = (
112
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
113
+ ) / beta_prod_t ** (0.5)
114
+ elif scheduler.config.prediction_type == "v_prediction":
115
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
116
+ beta_prod_t**0.5
117
+ ) * model_output
118
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
119
+ else:
120
+ raise ValueError(
121
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
122
+ " `v_prediction`"
123
+ )
124
+
125
+ # 4. Clip or threshold "predicted x_0"
126
+ if scheduler.config.thresholding:
127
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
128
+ elif scheduler.config.clip_sample:
129
+ pred_original_sample = pred_original_sample.clamp(
130
+ -scheduler.config.clip_sample_range,
131
+ scheduler.config.clip_sample_range,
132
+ )
133
+
134
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
135
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
136
+ variance = scheduler._get_variance(timestep, prev_timestep)
137
+ std_dev_t = eta * variance ** (0.5)
138
+
139
+ if use_clipped_model_output:
140
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
141
+ pred_epsilon = (
142
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
143
+ ) / beta_prod_t ** (0.5)
144
+
145
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
146
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
147
+ 0.5
148
+ ) * pred_epsilon
149
+
150
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
151
+ prev_sample = (
152
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
153
+ )
154
+ return prev_sample
155
+
156
+
157
+ def deterministic_euler_step(
158
+ model_output: torch.FloatTensor,
159
+ timestep: Union[float, torch.FloatTensor],
160
+ sample: torch.FloatTensor,
161
+ eta,
162
+ use_clipped_model_output,
163
+ generator,
164
+ variance_noise,
165
+ return_dict,
166
+ scheduler,
167
+ ):
168
+ """
169
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
170
+ process from the learned model outputs (most often the predicted noise).
171
+
172
+ Args:
173
+ model_output (`torch.FloatTensor`):
174
+ The direct output from learned diffusion model.
175
+ timestep (`float`):
176
+ The current discrete timestep in the diffusion chain.
177
+ sample (`torch.FloatTensor`):
178
+ A current instance of a sample created by the diffusion process.
179
+ generator (`torch.Generator`, *optional*):
180
+ A random number generator.
181
+ return_dict (`bool`):
182
+ Whether or not to return a
183
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
184
+
185
+ Returns:
186
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
187
+ If return_dict is `True`,
188
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
189
+ otherwise a tuple is returned where the first element is the sample tensor.
190
+
191
+ """
192
+
193
+ if (
194
+ isinstance(timestep, int)
195
+ or isinstance(timestep, torch.IntTensor)
196
+ or isinstance(timestep, torch.LongTensor)
197
+ ):
198
+ raise ValueError(
199
+ (
200
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
201
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
202
+ " one of the `scheduler.timesteps` as a timestep."
203
+ ),
204
+ )
205
+
206
+ if scheduler.step_index is None:
207
+ scheduler._init_step_index(timestep)
208
+
209
+ sigma = scheduler.sigmas[scheduler.step_index]
210
+
211
+ # Upcast to avoid precision issues when computing prev_sample
212
+ sample = sample.to(torch.float32)
213
+
214
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
215
+ if scheduler.config.prediction_type == "epsilon":
216
+ pred_original_sample = sample - sigma * model_output
217
+ elif scheduler.config.prediction_type == "v_prediction":
218
+ # * c_out + input * c_skip
219
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
220
+ sample / (sigma**2 + 1)
221
+ )
222
+ elif scheduler.config.prediction_type == "sample":
223
+ raise NotImplementedError("prediction_type not implemented yet: sample")
224
+ else:
225
+ raise ValueError(
226
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
227
+ )
228
+
229
+ sigma_from = scheduler.sigmas[scheduler.step_index]
230
+ sigma_to = scheduler.sigmas[scheduler.step_index + 1]
231
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
232
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
233
+
234
+ # 2. Convert to an ODE derivative
235
+ derivative = (sample - pred_original_sample) / sigma
236
+
237
+ dt = sigma_down - sigma
238
+
239
+ prev_sample = sample + derivative * dt
240
+
241
+ # Cast sample back to model compatible dtype
242
+ prev_sample = prev_sample.to(model_output.dtype)
243
+
244
+ # upon completion increase step index by one
245
+ scheduler._step_index += 1
246
+
247
+ return prev_sample
248
+
249
+
250
+ def deterministic_non_ancestral_euler_step(
251
+ model_output: torch.FloatTensor,
252
+ timestep: Union[float, torch.FloatTensor],
253
+ sample: torch.FloatTensor,
254
+ eta: float = 0.0,
255
+ use_clipped_model_output: bool = False,
256
+ s_churn: float = 0.0,
257
+ s_tmin: float = 0.0,
258
+ s_tmax: float = float("inf"),
259
+ s_noise: float = 1.0,
260
+ generator: Optional[torch.Generator] = None,
261
+ variance_noise: Optional[torch.FloatTensor] = None,
262
+ return_dict: bool = True,
263
+ scheduler=None,
264
+ ):
265
+ """
266
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
267
+ process from the learned model outputs (most often the predicted noise).
268
+
269
+ Args:
270
+ model_output (`torch.FloatTensor`):
271
+ The direct output from learned diffusion model.
272
+ timestep (`float`):
273
+ The current discrete timestep in the diffusion chain.
274
+ sample (`torch.FloatTensor`):
275
+ A current instance of a sample created by the diffusion process.
276
+ s_churn (`float`):
277
+ s_tmin (`float`):
278
+ s_tmax (`float`):
279
+ s_noise (`float`, defaults to 1.0):
280
+ Scaling factor for noise added to the sample.
281
+ generator (`torch.Generator`, *optional*):
282
+ A random number generator.
283
+ return_dict (`bool`):
284
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
285
+ tuple.
286
+
287
+ Returns:
288
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
289
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
290
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
291
+ """
292
+
293
+ if (
294
+ isinstance(timestep, int)
295
+ or isinstance(timestep, torch.IntTensor)
296
+ or isinstance(timestep, torch.LongTensor)
297
+ ):
298
+ raise ValueError(
299
+ (
300
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
301
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
302
+ " one of the `scheduler.timesteps` as a timestep."
303
+ ),
304
+ )
305
+
306
+ if not scheduler.is_scale_input_called:
307
+ logger.warning(
308
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
309
+ "See `StableDiffusionPipeline` for a usage example."
310
+ )
311
+
312
+ if scheduler.step_index is None:
313
+ scheduler._init_step_index(timestep)
314
+
315
+ # Upcast to avoid precision issues when computing prev_sample
316
+ sample = sample.to(torch.float32)
317
+
318
+ sigma = scheduler.sigmas[scheduler.step_index]
319
+
320
+ gamma = (
321
+ min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
322
+ if s_tmin <= sigma <= s_tmax
323
+ else 0.0
324
+ )
325
+
326
+ sigma_hat = sigma * (gamma + 1)
327
+
328
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
329
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
330
+ # backwards compatibility
331
+ if (
332
+ scheduler.config.prediction_type == "original_sample"
333
+ or scheduler.config.prediction_type == "sample"
334
+ ):
335
+ pred_original_sample = model_output
336
+ elif scheduler.config.prediction_type == "epsilon":
337
+ pred_original_sample = sample - sigma_hat * model_output
338
+ elif scheduler.config.prediction_type == "v_prediction":
339
+ # denoised = model_output * c_out + input * c_skip
340
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
341
+ sample / (sigma**2 + 1)
342
+ )
343
+ else:
344
+ raise ValueError(
345
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
346
+ )
347
+
348
+ # 2. Convert to an ODE derivative
349
+ derivative = (sample - pred_original_sample) / sigma_hat
350
+
351
+ dt = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat
352
+
353
+ prev_sample = sample + derivative * dt
354
+
355
+ # Cast sample back to model compatible dtype
356
+ prev_sample = prev_sample.to(model_output.dtype)
357
+
358
+ # upon completion increase step index by one
359
+ scheduler._step_index += 1
360
+
361
+ return prev_sample
362
+
363
+
364
+ def deterministic_ddpm_step(
365
+ model_output: torch.FloatTensor,
366
+ timestep: Union[float, torch.FloatTensor],
367
+ sample: torch.FloatTensor,
368
+ eta,
369
+ use_clipped_model_output,
370
+ generator,
371
+ variance_noise,
372
+ return_dict,
373
+ scheduler,
374
+ ):
375
+ """
376
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
377
+ process from the learned model outputs (most often the predicted noise).
378
+
379
+ Args:
380
+ model_output (`torch.FloatTensor`):
381
+ The direct output from learned diffusion model.
382
+ timestep (`float`):
383
+ The current discrete timestep in the diffusion chain.
384
+ sample (`torch.FloatTensor`):
385
+ A current instance of a sample created by the diffusion process.
386
+ generator (`torch.Generator`, *optional*):
387
+ A random number generator.
388
+ return_dict (`bool`, *optional*, defaults to `True`):
389
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
390
+
391
+ Returns:
392
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
393
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
394
+ tuple is returned where the first element is the sample tensor.
395
+
396
+ """
397
+ t = timestep
398
+
399
+ prev_t = scheduler.previous_timestep(t)
400
+
401
+ if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
402
+ "learned",
403
+ "learned_range",
404
+ ]:
405
+ model_output, predicted_variance = torch.split(
406
+ model_output, sample.shape[1], dim=1
407
+ )
408
+ else:
409
+ predicted_variance = None
410
+
411
+ # 1. compute alphas, betas
412
+ alpha_prod_t = scheduler.alphas_cumprod[t]
413
+ alpha_prod_t_prev = (
414
+ scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
415
+ )
416
+ beta_prod_t = 1 - alpha_prod_t
417
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
418
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
419
+ current_beta_t = 1 - current_alpha_t
420
+
421
+ # 2. compute predicted original sample from predicted noise also called
422
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
423
+ if scheduler.config.prediction_type == "epsilon":
424
+ pred_original_sample = (
425
+ sample - beta_prod_t ** (0.5) * model_output
426
+ ) / alpha_prod_t ** (0.5)
427
+ elif scheduler.config.prediction_type == "sample":
428
+ pred_original_sample = model_output
429
+ elif scheduler.config.prediction_type == "v_prediction":
430
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
431
+ beta_prod_t**0.5
432
+ ) * model_output
433
+ else:
434
+ raise ValueError(
435
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
436
+ " `v_prediction` for the DDPMScheduler."
437
+ )
438
+
439
+ # 3. Clip or threshold "predicted x_0"
440
+ if scheduler.config.thresholding:
441
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
442
+ elif scheduler.config.clip_sample:
443
+ pred_original_sample = pred_original_sample.clamp(
444
+ -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
445
+ )
446
+
447
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
448
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
449
+ pred_original_sample_coeff = (
450
+ alpha_prod_t_prev ** (0.5) * current_beta_t
451
+ ) / beta_prod_t
452
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
453
+
454
+ # 5. Compute predicted previous sample µ_t
455
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
456
+ pred_prev_sample = (
457
+ pred_original_sample_coeff * pred_original_sample
458
+ + current_sample_coeff * sample
459
+ )
460
+
461
+ return pred_prev_sample
462
+
463
+
464
+ def normalize(
465
+ z_t,
466
+ i,
467
+ max_norm_zs,
468
+ ):
469
+ max_norm = max_norm_zs[i]
470
+ if max_norm < 0:
471
+ return z_t, 1
472
+
473
+ norm = torch.norm(z_t)
474
+ if norm < max_norm:
475
+ return z_t, 1
476
+
477
+ coeff = max_norm / norm
478
+ z_t = z_t * coeff
479
+ return z_t, coeff
480
+
481
+
482
+ def find_index(timesteps, timestep):
483
+ for i, t in enumerate(timesteps):
484
+ if t == timestep:
485
+ return i
486
+ return -1
487
+
488
+ map_timestep_to_index = {
489
+ torch.tensor(799): 0,
490
+ torch.tensor(599): 1,
491
+ torch.tensor(399): 2,
492
+ torch.tensor(199): 3,
493
+ torch.tensor(799, device='cuda:0'): 0,
494
+ torch.tensor(599, device='cuda:0'): 1,
495
+ torch.tensor(399, device='cuda:0'): 2,
496
+ torch.tensor(199, device='cuda:0'): 3,
497
+ }
498
+
499
+ def step_save_latents(
500
+ self,
501
+ model_output: torch.FloatTensor,
502
+ timestep: int,
503
+ sample: torch.FloatTensor,
504
+ eta: float = 0.0,
505
+ use_clipped_model_output: bool = False,
506
+ generator=None,
507
+ variance_noise: Optional[torch.FloatTensor] = None,
508
+ return_dict: bool = True,
509
+ ):
510
+ # print(self._save_timesteps)
511
+ # timestep_index = map_timestep_to_index[timestep]
512
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
513
+ timestep_index = self._save_timesteps.index(timestep) if not self.clean_step_run else -1
514
+ next_timestep_index = timestep_index + 1 if not self.clean_step_run else -1
515
+ u_hat_t = self.step_function(
516
+ model_output=model_output,
517
+ timestep=timestep,
518
+ sample=sample,
519
+ eta=eta,
520
+ use_clipped_model_output=use_clipped_model_output,
521
+ generator=generator,
522
+ variance_noise=variance_noise,
523
+ return_dict=False,
524
+ scheduler=self,
525
+ )
526
+
527
+ x_t_minus_1 = self.x_ts[next_timestep_index]
528
+ self.x_ts_c_hat.append(u_hat_t)
529
+
530
+ z_t = x_t_minus_1 - u_hat_t
531
+ self.latents.append(z_t)
532
+
533
+ z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
534
+
535
+ x_t_minus_1_predicted = u_hat_t + z_t
536
+
537
+ if not return_dict:
538
+ return (x_t_minus_1_predicted,)
539
+
540
+ return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
541
+
542
+
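+ # Editing branch: rebuilds x_{t-1} from the stored exact x_{t-1} plus two directions,
+ # v1 (edit-prompt prediction minus source-prompt prediction) and v2 (source-prompt prediction
+ # minus the stored inversion prediction, up to the norm-clamping coefficient), weighted per
+ # step by ws1/ws2 from the config.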
543
+ def step_use_latents(
544
+ self,
545
+ model_output: torch.FloatTensor,
546
+ timestep: int,
547
+ sample: torch.FloatTensor,
548
+ eta: float = 0.0,
549
+ use_clipped_model_output: bool = False,
550
+ generator=None,
551
+ variance_noise: Optional[torch.FloatTensor] = None,
552
+ return_dict: bool = True,
553
+ ):
554
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
555
+ timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
556
+ next_timestep_index = (
557
+ timestep_index + 1 if not self.clean_step_run else -1
558
+ )
559
+ z_t = self.latents[next_timestep_index] # + 1 because latents[0] is X_T
560
+
561
+ _, normalize_coefficient = normalize(
562
+ z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
563
+ timestep_index,
564
+ self._config.max_norm_zs,
565
+ )
566
+
567
+ if normalize_coefficient == 0:
568
+ eta = 0
569
+
570
+ # eta = normalize_coefficient
571
+
572
+ x_t_hat_c_hat = self.step_function(
573
+ model_output=model_output,
574
+ timestep=timestep,
575
+ sample=sample,
576
+ eta=eta,
577
+ use_clipped_model_output=use_clipped_model_output,
578
+ generator=generator,
579
+ variance_noise=variance_noise,
580
+ return_dict=False,
581
+ scheduler=self,
582
+ )
583
+
584
+ w1 = self._config.ws1[timestep_index]
585
+ w2 = self._config.ws2[timestep_index]
586
+
587
+ x_t_minus_1_exact = self.x_ts[next_timestep_index]
588
+ x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)
589
+
590
+ x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
591
+ if self._config.breakdown == "x_t_c_hat":
592
+ raise NotImplementedError("breakdown x_t_c_hat not implemented yet")
593
+
594
+ # x_t_c_hat = x_t_c_hat.expand_as(x_t_hat_c_hat)
595
+ x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)
596
+
597
+ # if self._config.breakdown == "x_t_c_hat":
598
+ # v1 = x_t_hat_c_hat - x_t_c_hat
599
+ # v2 = x_t_c_hat - x_t_c
600
+ if (
601
+ self._config.breakdown == "x_t_hat_c"
602
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
603
+ ):
604
+ zero_index_reconstruction = 1 if not self.time_measure_n else 0
605
+ edit_prompts_num = (
606
+ (model_output.size(0) - zero_index_reconstruction) // 3
607
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
608
+ else (model_output.size(0) - zero_index_reconstruction) // 2
609
+ )
610
+ x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
611
+ edit_images_indices = (
612
+ edit_prompts_num + zero_index_reconstruction,
613
+ (
614
+ model_output.size(0)
615
+ if self._config.breakdown == "x_t_hat_c"
616
+ else zero_index_reconstruction + 2 * edit_prompts_num
617
+ ),
618
+ )
619
+ x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
620
+ x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
621
+ x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
622
+ ]
623
+ v1 = x_t_hat_c_hat - x_t_hat_c
624
+ v2 = x_t_hat_c - normalize_coefficient * x_t_c
625
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
626
+ path = os.path.join(
627
+ self.folder_name,
628
+ VECTOR_DATA_FOLDER,
629
+ self.image_name,
630
+ )
631
+ if not hasattr(self, VECTOR_DATA_DICT):
632
+ os.makedirs(path, exist_ok=True)
633
+ self.vector_data = dict()
634
+
635
+ x_t_0 = x_t_c_hat[1]
636
+ empty_prompt_indices = (1 + 2 * edit_prompts_num, 1 + 3 * edit_prompts_num)
637
+ x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]
638
+
639
+ self.vector_data[timestep.item()] = dict()
640
+ self.vector_data[timestep.item()]["x_t_hat_c"] = x_t_hat_c[
641
+ edit_images_indices[0] : edit_images_indices[1]
642
+ ]
643
+ self.vector_data[timestep.item()]["x_t_hat_0"] = x_t_hat_0
644
+ self.vector_data[timestep.item()]["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
645
+ self.vector_data[timestep.item()]["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
646
+ self.vector_data[timestep.item()]["x_t_hat_c_hat"] = x_t_hat_c_hat[
647
+ edit_images_indices[0] : edit_images_indices[1]
648
+ ]
649
+ self.vector_data[timestep.item()]["x_t_minus_1_noisy"] = x_t_minus_1_exact[
650
+ 0
651
+ ].expand_as(x_t_hat_0)
652
+ self.vector_data[timestep.item()]["x_t_minus_1_clean"] = self.x_0s[
653
+ next_timestep_index
654
+ ].expand_as(x_t_hat_0)
655
+
656
+ else: # no breakdown
657
+ v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
658
+ v2 = 0
659
+
660
+ if self.save_intermediate_results and not self.p_to_p:
661
+ delta = v1 + v2
662
+ v1_plus_x0 = self.x_0s[next_timestep_index] + v1
663
+ v2_plus_x0 = self.x_0s[next_timestep_index] + v2
664
+ delta_plus_x0 = self.x_0s[next_timestep_index] + delta
665
+
666
+ v1_images = decode_latents(v1, self.pipe)
667
+ self.v1s_images.append(v1_images)
668
+ v2_images = (
669
+ decode_latents(v2, self.pipe)
670
+ if self._config.breakdown != "no_breakdown"
671
+ else [PIL.Image.new("RGB", (1, 1))]
672
+ )
673
+ self.v2s_images.append(v2_images)
674
+ delta_images = decode_latents(delta, self.pipe)
675
+ self.deltas_images.append(delta_images)
676
+ v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
677
+ self.v1_x0s.append(v1_plus_x0_images)
678
+ v2_plus_x0_images = (
679
+ decode_latents(v2_plus_x0, self.pipe)
680
+ if self._config.breakdown != "no_breakdown"
681
+ else [PIL.Image.new("RGB", (1, 1))]
682
+ )
683
+ self.v2_x0s.append(v2_plus_x0_images)
684
+ delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
685
+ self.deltas_x0s.append(delta_plus_x0_images)
686
+
687
+ # print(f"v1 norm: {torch.norm(v1, dim=0).mean()}")
688
+ # if self._config.breakdown != "no_breakdown":
689
+ # print(f"v2 norm: {torch.norm(v2, dim=0).mean()}")
690
+ # print(f"v sum norm: {torch.norm(v1 + v2, dim=0).mean()}")
691
+
692
+ x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2
693
+
694
+ if (
695
+ self._config.breakdown == "x_t_hat_c"
696
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
697
+ ):
698
+ x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
699
+ edit_images_indices[0] : edit_images_indices[1]
700
+ ] # update x_t_hat_c to be x_t_hat_c_hat
701
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
702
+ x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
703
+ x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
704
+ )
705
+ self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
706
+ edit_images_indices[0] : edit_images_indices[1]
707
+ ]
708
+ if timestep == self._timesteps[-1]:
709
+ torch.save(
710
+ self.vector_data,
711
+ os.path.join(
712
+ path,
713
+ f"{VECTOR_DATA_DICT}.pt",
714
+ ),
715
+ )
716
+ # p_to_p_force_perfect_reconstruction
717
+ if not self.time_measure_n:
718
+ x_t_minus_1[0] = x_t_minus_1_exact[0]
719
+
720
+ if not return_dict:
721
+ return (x_t_minus_1,)
722
+
723
+ return DDIMSchedulerOutput(
724
+ prev_sample=x_t_minus_1,
725
+ pred_original_sample=None,
726
+ )
727
+
728
+
729
+
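+ # Monkey-patches the given scheduler instance: attaches the config, the latent/noise buffers,
+ # and a combined step() that runs step_save_latents on batch element 0 (inversion) and
+ # step_use_latents on the rest (editing), concatenating the results.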
730
+ def get_ddpm_inversion_scheduler(
+ scheduler,
+ step_function,
+ config,
+ timesteps,
+ save_timesteps,
+ latents,
+ x_ts,
+ x_ts_c_hat,
+ save_intermediate_results,
+ pipe,
+ x_0,
+ v1s_images,
+ v2s_images,
+ deltas_images,
+ v1_x0s,
+ v2_x0s,
+ deltas_x0s,
+ folder_name,
+ image_name,
+ time_measure_n,
+ ):
+ def step(
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ ):
+ # if scheduler.is_save:
+ # start = timer()
+ res_inv = step_save_latents(
+ scheduler,
+ model_output[:1, :, :, :],
+ timestep,
+ sample[:1, :, :, :],
+ eta,
+ use_clipped_model_output,
+ generator,
+ variance_noise,
+ return_dict,
+ )
+ # end = timer()
+ # print(f"Run Time Inv: {end - start}")
+
+ res_inf = step_use_latents(
+ scheduler,
+ model_output[1:, :, :, :],
+ timestep,
+ sample[1:, :, :, :],
+ eta,
+ use_clipped_model_output,
+ generator,
+ variance_noise,
+ return_dict,
+ )
+ # res = res_inv
+ res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
+ return res
+ # return res
+
+ scheduler.step_function = step_function
+ scheduler.is_save = True
+ scheduler._timesteps = timesteps
+ scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
+ scheduler._config = config
+ scheduler.latents = latents
+ scheduler.x_ts = x_ts
+ scheduler.x_ts_c_hat = x_ts_c_hat
+ scheduler.step = step
+ scheduler.save_intermediate_results = save_intermediate_results
+ scheduler.pipe = pipe
+ scheduler.v1s_images = v1s_images
+ scheduler.v2s_images = v2s_images
+ scheduler.deltas_images = deltas_images
+ scheduler.v1_x0s = v1_x0s
+ scheduler.v2_x0s = v2_x0s
+ scheduler.deltas_x0s = deltas_x0s
+ scheduler.clean_step_run = False
+ scheduler.x_0s = create_xts(
+ config.noise_shift_delta,
+ config.noise_timesteps,
+ config.clean_step_timestep,
+ None,
+ pipe.scheduler,
+ timesteps,
+ x_0,
+ no_add_noise=True,
+ )
+ scheduler.folder_name = folder_name
+ scheduler.image_name = image_name
+ scheduler.p_to_p = False
+ scheduler.p_to_p_replace = False
+ scheduler.time_measure_n = time_measure_n
+ return scheduler
+
+
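+ # create_grid lays out the original image followed by the edited images in one row (and a
+ # second row for prompt-to-prompt results, if any), with the prompts drawn in a strip below.
+ # Note that ImageFont.truetype("font.ttf", ...) expects a font file in the working directory.
+ # Usage sketch (assumes `images` is a non-empty list of same-size PIL images):
+ #
+ #   grid = create_grid(images, p_to_p_images=[], prompts=[src_prompt, tgt_prompt],
+ #                      original_image_path="input.png")
+ #   grid.save("grid.png")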
+ def create_grid(
+ images,
+ p_to_p_images,
+ prompts,
+ original_image_path,
+ ):
+ images_len = len(images) if len(images) > 0 else len(p_to_p_images)
+ images_size = images[0].size if len(images) > 0 else p_to_p_images[0].size
+ x_0 = Image.open(original_image_path).resize(images_size)
+
+ images_ = [x_0] + images + ([x_0] + p_to_p_images if p_to_p_images else [])
+
+ l1 = 1 if len(images) > 0 else 0
+ l2 = 1 if len(p_to_p_images) else 0
+ grid = make_image_grid(images_, rows=l1 + l2, cols=images_len + 1, resize=None)
+
+ width = images_size[0]
+ height = width // 5
+ font = ImageFont.truetype("font.ttf", width // 14)
+
+ grid1 = Image.new("RGB", size=(grid.size[0], grid.size[1] + height))
+ grid1.paste(grid, (0, 0))
+
+ draw = ImageDraw.Draw(grid1)
+
+ c_width = 0
+ for prompt in prompts:
+ if len(prompt) > 30:
+ prompt = prompt[:30] + "\n" + prompt[30:]
+ draw.text((c_width, width * 2), prompt, font=font, fill=(255, 255, 255))
+ c_width += width
+
+ return grid1
+
+
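+ # save_intermediate_results expects each argument to be a list with one entry per timestep,
+ # where every entry is itself a list of decoded PIL images (this is how the patched scheduler
+ # above accumulates them). Each quantity is written as one grid image under
+ # "<folder_name>/<original_prompt>_intermediate_results/".
+ # Call sketch (assumes `sched` is the scheduler returned by get_ddpm_inversion_scheduler):
+ #
+ #   save_intermediate_results(sched.v1s_images, sched.v2s_images, sched.deltas_images,
+ #                             sched.v1_x0s, sched.v2_x0s, sched.deltas_x0s, "output", prompt)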
+ def save_intermediate_results(
+ v1s_images,
+ v2s_images,
+ deltas_images,
+ v1_x0s,
+ v2_x0s,
+ deltas_x0s,
+ folder_name,
+ original_prompt,
+ ):
+ from diffusers.utils import make_image_grid
+
+ path = f"{folder_name}/{original_prompt}_intermediate_results/"
+ os.makedirs(path, exist_ok=True)
+ make_image_grid(
+ list(itertools.chain(*v1s_images)),
+ rows=len(v1s_images),
+ cols=len(v1s_images[0]),
+ ).save(f"{path}v1s_images.png")
+ make_image_grid(
+ list(itertools.chain(*v2s_images)),
+ rows=len(v2s_images),
+ cols=len(v2s_images[0]),
+ ).save(f"{path}v2s_images.png")
+ make_image_grid(
+ list(itertools.chain(*deltas_images)),
+ rows=len(deltas_images),
+ cols=len(deltas_images[0]),
+ ).save(f"{path}deltas_images.png")
+ make_image_grid(
+ list(itertools.chain(*v1_x0s)),
+ rows=len(v1_x0s),
+ cols=len(v1_x0s[0]),
+ ).save(f"{path}v1_x0s.png")
+ make_image_grid(
+ list(itertools.chain(*v2_x0s)),
+ rows=len(v2_x0s),
+ cols=len(v2_x0s[0]),
+ ).save(f"{path}v2_x0s.png")
+ make_image_grid(
+ list(itertools.chain(*deltas_x0s)),
+ rows=len(deltas_x0s[0]),
+ cols=len(deltas_x0s),
+ ).save(f"{path}deltas_x0s.png")
+ for i, image in enumerate(list(itertools.chain(*deltas_x0s))):
+ image.save(f"{path}deltas_x0s_{i}.png")
+
+
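+ # prepare_latents_no_add_noise only VAE-encodes the image and applies the scaling factor, so it
+ # returns clean x_0 latents; the noising of the diffusers original is dropped and handled
+ # separately (see create_xts below). The `timestep` argument is kept for signature compatibility
+ # but is no longer used. Usage sketch (assumes `image` is a preprocessed image tensor in [-1, 1]):
+ #
+ #   x_0 = prepare_latents_no_add_noise(pipe, image, None, batch_size=1,
+ #                                      num_images_per_prompt=1, dtype=pipe.dtype, device=pipe.device)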
+ # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.py and removed the add_noise line
+ def prepare_latents_no_add_noise(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_images_per_prompt,
+ dtype,
+ device,
+ generator=None,
+ ):
+ from diffusers.utils import deprecate
+
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.retrieve_latents(
+ self.vae.encode(image[i : i + 1]), generator=generator[i]
+ )
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.retrieve_latents(
+ self.vae.encode(image), generator=generator
+ )
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate(
+ "len(prompt) != len(image)",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ # get latents
+ latents = init_latents
+
+ return latents
+
+
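+ # encode_prompt_empty_prompt_zeros_sdxl follows the stock SDXL encode_prompt, with one change:
+ # when the pipeline config has force_zeros_for_empty_prompt, the token-level and pooled
+ # embeddings of every prompt that is exactly "" are zeroed out, in addition to the stock zeroing
+ # of the negative embeddings. It returns the usual tuple
+ # (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds).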
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt_empty_prompt_zeros_sdxl(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = (
+ [self.tokenizer, self.tokenizer_2]
+ if self.tokenizer is not None
+ else [self.tokenizer_2]
+ )
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2]
+ if self.text_encoder is not None
+ else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(
+ prompt, padding="longest", return_tensors="pt"
+ ).input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
+ -1
+ ] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(
+ untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device), output_hidden_states=True
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ if self.config.force_zeros_for_empty_prompt:
+ prompt_embeds[[i for i in range(len(prompt)) if prompt[i] == ""]] = 0
+ pooled_prompt_embeds[
+ [i for i in range(len(prompt)) if prompt[i] == ""]
+ ] = 0
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = (
+ negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ )
+ if (
+ do_classifier_free_guidance
+ and negative_prompt_embeds is None
+ and zero_out_negative_prompt
+ ):
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = (
+ batch_size * [negative_prompt]
+ if isinstance(negative_prompt, str)
+ else negative_prompt
+ )
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2]
+ if isinstance(negative_prompt_2, str)
+ else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(
+ uncond_tokens, tokenizers, text_encoders
+ ):
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=self.text_encoder_2.dtype, device=device
+ )
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=self.unet.dtype, device=device
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
+ 1, num_images_per_prompt, 1
+ )
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
+ 1, num_images_per_prompt
+ ).view(bs_embed * num_images_per_prompt, -1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+
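+ # create_xts builds the list of per-timestep noised latents consumed by the inversion scheduler.
+ # For illustration, with timesteps = [799, 599, 399, 199] and noise_shift_delta = 1:
+ # noising_delta = 200, so noise_timesteps = [599, 399, 199, -1]; the non-positive entry is
+ # dropped, three noised latents are created with scheduler.add_noise, and the remaining slot
+ # plus one trailing entry are filled with x_0 (and one more x_0 if clean_step_timestep > 0).
+ # With no_add_noise=True the noise tensor is all zeros, so the list holds rescaled copies of x_0.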
+ def create_xts(
+ noise_shift_delta,
+ noise_timesteps,
+ clean_step_timestep,
+ generator,
+ scheduler,
+ timesteps,
+ x_0,
+ no_add_noise=False,
+ ):
+ if noise_timesteps is None:
+ noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
+ noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]
+
+ first_x_0_idx = len(noise_timesteps)
+ for i in range(len(noise_timesteps)):
+ if noise_timesteps[i] <= 0:
+ first_x_0_idx = i
+ break
+
+ noise_timesteps = noise_timesteps[:first_x_0_idx]
+
+ x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
+ noise = (
+ torch.randn(x_0_expanded.size(), generator=generator, device="cpu").to(
+ x_0.device
+ )
+ if not no_add_noise
+ else torch.zeros_like(x_0_expanded)
+ )
+ x_ts = scheduler.add_noise(
+ x_0_expanded,
+ noise,
+ torch.IntTensor(noise_timesteps),
+ )
+ x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
+ x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
+ x_ts += [x_0]
+ if clean_step_timestep > 0:
+ x_ts += [x_0]
+ return x_ts
+
+
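+ # This add_noise variant differs from the stock DDPMScheduler.add_noise by taking two timestep
+ # tensors: the signal is scaled with alpha_bar at image_timesteps while the noise is scaled with
+ # alpha_bar at noise_timesteps, i.e.
+ #   noisy = sqrt(alpha_bar[t_img]) * x_0 + sqrt(1 - alpha_bar[t_noise]) * eps
+ # so the two scales may come from different timesteps.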
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ image_timesteps: torch.IntTensor,
+ noise_timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+ # for the subsequent add_noise calls
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+ image_timesteps = image_timesteps.to(original_samples.device)
+ noise_timesteps = noise_timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[image_timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[noise_timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = (
+ sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ )
+ return noisy_samples
+
+
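+ # Local variant of diffusers.utils.make_image_grid that takes an explicit per-cell `size`
+ # (width, height); `resize` optionally resizes every image to a resize x resize square first.
+ # When `size` is omitted it falls back to the size of the first image.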
+ def make_image_grid(
+ images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None, size=None
+ ) -> PIL.Image.Image:
+ """
+ Prepares a single grid of images. Useful for visualization purposes.
+ """
+ assert len(images) == rows * cols
+
+ if resize is not None:
+ images = [img.resize((resize, resize)) for img in images]
+
+ if size is None:
+ # fall back to the first image's size when no explicit cell size is given
+ size = images[0].size
+ w, h = size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(images):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
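+ # Usage sketch (assumes six same-size PIL images in `my_images`):
+ #
+ #   grid = make_image_grid(my_images, rows=2, cols=3, size=my_images[0].size)
+ #   grid.save("grid.png")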