KingNish commited on
Commit
085c378
1 Parent(s): c996b7d

faster (testing)

Browse files

made faster by
1. Model Loading and Device Placement: The model (pipe) is now loaded only once outside the generate function. This saves significant time on each generation.
2. Resolution Binning: Enabled by default in the options, this technique speeds up generation and reduces VRAM usage, especially for larger images.
3. Torch Compile (Experimental): Added the option to use torch.compile which might further improve performance on compatible hardware. This is highly dependent on your setup.
4. CPU Offloading (Experimental): Allows offloading parts of the model to CPU RAM, potentially enabling larger image generation or batch sizes if your GPU VRAM is limited.
5. Batch Generation: Added the ability to generate multiple images in a single pass (controlled by BATCH_SIZE environment variable or defaulting to 1). This significantly reduces overhead when generating more than one image.

Files changed (1) hide show
  1. app.py +59 -37
app.py CHANGED
@@ -10,27 +10,33 @@ import spaces
10
  import torch
11
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
12
 
13
- if not torch.cuda.is_available():
14
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
15
-
16
- MAX_SEED = np.iinfo(np.int32).max
17
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
18
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
19
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
20
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
21
 
 
22
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- if torch.cuda.is_available():
25
- pipe = StableDiffusionXLPipeline.from_pretrained(
26
- "sd-community/sdxl-flash",
27
- torch_dtype=torch.float16,
28
- use_safetensors=True,
29
- add_watermarker=False
30
- )
31
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
32
- pipe.to("cuda")
33
-
34
 
35
  def save_image(img):
36
  unique_name = str(uuid.uuid4()) + ".png"
@@ -53,51 +59,60 @@ def generate(
53
  guidance_scale: float = 3,
54
  num_inference_steps: int = 30,
55
  randomize_seed: bool = False,
56
- use_resolution_binning: bool = True,
 
57
  progress=gr.Progress(track_tqdm=True),
58
  ):
59
- pipe.to(device)
60
  seed = int(randomize_seed_fn(seed, randomize_seed))
61
- generator = torch.Generator().manual_seed(seed)
62
 
 
63
  options = {
64
- "prompt":prompt,
65
- "negative_prompt":negative_prompt,
66
- "width":width,
67
- "height":height,
68
- "guidance_scale":guidance_scale,
69
- "num_inference_steps":num_inference_steps,
70
- "generator":generator,
71
- "use_resolution_binning":use_resolution_binning,
72
- "output_type":"pil",
73
-
74
  }
75
-
76
- images = pipe(**options).images
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  image_paths = [save_image(img) for img in images]
79
  return image_paths, seed
80
 
81
-
82
  examples = [
83
  "a cat eating a piece of cheese",
84
  "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
85
  "Ironman VS Hulk, ultrarealistic",
86
  "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
87
- "An alien holding sign board contain word 'Flash', futuristic, neonpunk",
88
  "Kids going to school, Anime style"
89
  ]
90
 
91
  css = '''
92
- .gradio-container{max-width: 560px !important}
93
  h1{text-align:center}
94
  footer {
95
  visibility: hidden
96
  }
97
  '''
 
98
  with gr.Blocks(css=css) as demo:
99
- gr.Markdown("""# SDXL Flash
100
- ### First Image processing takes time then images generate faster.""")
101
  with gr.Group():
102
  with gr.Row():
103
  prompt = gr.Text(
@@ -108,8 +123,15 @@ with gr.Blocks(css=css) as demo:
108
  container=False,
109
  )
110
  run_button = gr.Button("Run", scale=0)
111
- result = gr.Gallery(label="Result", columns=1)
112
  with gr.Accordion("Advanced options", open=False):
 
 
 
 
 
 
 
113
  with gr.Row():
114
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
115
  negative_prompt = gr.Text(
@@ -164,7 +186,6 @@ with gr.Blocks(css=css) as demo:
164
  inputs=prompt,
165
  outputs=[result, seed],
166
  fn=generate,
167
- cache_examples=CACHE_EXAMPLES,
168
  )
169
 
170
  use_negative_prompt.change(
@@ -191,6 +212,7 @@ with gr.Blocks(css=css) as demo:
191
  guidance_scale,
192
  num_inference_steps,
193
  randomize_seed,
 
194
  ],
195
  outputs=[result, seed],
196
  api_name="run",
 
10
  import torch
11
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
12
 
13
+ # Use environment variables for flexibility
14
+ MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
 
 
 
15
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
16
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
17
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
18
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # Allow generating multiple images at once
19
 
20
+ # Determine device and load model outside of function for efficiency
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
+ pipe = StableDiffusionXLPipeline.from_pretrained(
23
+ MODEL_ID,
24
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
25
+ use_safetensors=True,
26
+ add_watermarker=False,
27
+ variant="fp16" if torch.cuda.is_available() else None,
28
+ ).to(device)
29
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
30
+
31
+ # Torch compile for potential speedup (experimental)
32
+ if USE_TORCH_COMPILE:
33
+ pipe.compile()
34
+
35
+ # CPU offloading for larger RAM capacity (experimental)
36
+ if ENABLE_CPU_OFFLOAD:
37
+ pipe.enable_model_cpu_offload()
38
 
39
+ MAX_SEED = np.iinfo(np.int32).max
 
 
 
 
 
 
 
 
 
40
 
41
  def save_image(img):
42
  unique_name = str(uuid.uuid4()) + ".png"
 
59
  guidance_scale: float = 3,
60
  num_inference_steps: int = 30,
61
  randomize_seed: bool = False,
62
+ use_resolution_binning: bool = True,
63
+ num_images: int = 1, # Number of images to generate
64
  progress=gr.Progress(track_tqdm=True),
65
  ):
 
66
  seed = int(randomize_seed_fn(seed, randomize_seed))
67
+ generator = torch.Generator(device=device).manual_seed(seed)
68
 
69
+ # Improved options handling
70
  options = {
71
+ "prompt": [prompt] * num_images,
72
+ "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
73
+ "width": width,
74
+ "height": height,
75
+ "guidance_scale": guidance_scale,
76
+ "num_inference_steps": num_inference_steps,
77
+ "generator": generator,
78
+ "output_type": "pil",
 
 
79
  }
80
+
81
+ # Use resolution binning for faster generation with less VRAM usage
82
+ if use_resolution_binning:
83
+ options["use_resolution_binning"] = True
84
+
85
+ # Generate images potentially in batches
86
+ images = []
87
+ for i in range(0, num_images, BATCH_SIZE):
88
+ batch_options = options.copy()
89
+ batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
90
+ if "negative_prompt" in batch_options:
91
+ batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
92
+ images.extend(pipe(**batch_options).images)
93
 
94
  image_paths = [save_image(img) for img in images]
95
  return image_paths, seed
96
 
 
97
  examples = [
98
  "a cat eating a piece of cheese",
99
  "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
100
  "Ironman VS Hulk, ultrarealistic",
101
  "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
102
+ "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
103
  "Kids going to school, Anime style"
104
  ]
105
 
106
  css = '''
107
+ .gradio-container{max-width: 700px !important}
108
  h1{text-align:center}
109
  footer {
110
  visibility: hidden
111
  }
112
  '''
113
+
114
  with gr.Blocks(css=css) as demo:
115
+ gr.Markdown("""# SDXL Flash""")
 
116
  with gr.Group():
117
  with gr.Row():
118
  prompt = gr.Text(
 
123
  container=False,
124
  )
125
  run_button = gr.Button("Run", scale=0)
126
+ result = gr.Gallery(label="Result", columns=2, show_label=False)
127
  with gr.Accordion("Advanced options", open=False):
128
+ num_images = gr.Slider(
129
+ label="Number of Images",
130
+ minimum=1,
131
+ maximum=4,
132
+ step=1,
133
+ value=1,
134
+ )
135
  with gr.Row():
136
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
137
  negative_prompt = gr.Text(
 
186
  inputs=prompt,
187
  outputs=[result, seed],
188
  fn=generate,
 
189
  )
190
 
191
  use_negative_prompt.change(
 
212
  guidance_scale,
213
  num_inference_steps,
214
  randomize_seed,
215
+ num_images
216
  ],
217
  outputs=[result, seed],
218
  api_name="run",