multimodalart HF staff commited on
Commit
12e6773
1 Parent(s): df2ad60

Swap to JAX

Browse files
Files changed (1) hide show
  1. app.py +79 -84
app.py CHANGED
@@ -1,53 +1,37 @@
1
  import gradio as gr
2
- import cv2
3
- import torch
4
- import os
5
- from imwatermark import WatermarkEncoder
6
- import numpy as np
7
- from PIL import Image
8
- import re
9
  from datasets import load_dataset
10
- from diffusers import DiffusionPipeline, EulerDiscreteScheduler
11
 
12
- from share_btn import community_icon_html, loading_icon_html, share_js
 
 
13
 
14
- REPO_ID = "stabilityai/stable-diffusion-2"
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- wm = "SDV2"
18
- wm_encoder = WatermarkEncoder()
19
- wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
20
- def put_watermark(img, wm_encoder=None):
21
- if wm_encoder is not None:
22
- img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
23
- img = wm_encoder.encode(img, 'dwtDct')
24
- img = Image.fromarray(img[:, :, ::-1])
25
- return img
26
 
27
- repo_id = "stabilityai/stable-diffusion-2"
28
- scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler", prediction_type="v_prediction")
29
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16", scheduler=scheduler)
30
- pipe = pipe.to(device)
31
- pipe.enable_xformers_memory_efficient_attention()
32
 
33
- #If you have duplicated this Space or is running locally, you can remove this snippet
34
- if "HUGGING_FACE_HUB_TOKEN" in os.environ:
35
- word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
36
- word_list = word_list_dataset["train"]['text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- def infer(prompt, samples, steps, scale, seed):
39
- #If you have duplicated this Space or is running locally, you can remove this snippet
40
- if "HUGGING_FACE_HUB_TOKEN" in os.environ:
41
- for filter in word_list:
42
- if re.search(rf"\b{filter}\b", prompt):
43
- raise gr.Error("Unsafe content found. Please try again with different prompts.")
44
- generator = torch.Generator(device=device).manual_seed(seed)
45
- images = pipe(prompt, width=768, height=768, num_inference_steps=steps, guidance_scale=scale, num_images_per_prompt=samples, generator=generator).images
46
- images_watermarked = []
47
- for image in images:
48
- image = put_watermark(image, wm_encoder)
49
- images_watermarked.append(image)
50
- return images_watermarked
51
 
52
  css = """
53
  .gradio-container {
@@ -176,41 +160,42 @@ block = gr.Blocks(css=css)
176
  examples = [
177
  [
178
  'A high tech solarpunk utopia in the Amazon rainforest',
179
- 4,
180
- 25,
181
- 9,
182
- 1024,
183
  ],
184
  [
185
  'A pikachu fine dining with a view to the Eiffel Tower',
186
- 4,
187
- 25,
188
- 9,
189
- 1024,
190
  ],
191
  [
192
  'A mecha robot in a favela in expressionist style',
193
- 4,
194
- 25,
195
- 9,
196
- 1024,
197
  ],
198
  [
199
  'an insect robot preparing a delicious meal',
200
- 4,
201
- 25,
202
- 9,
203
- 1024,
204
  ],
205
  [
206
  "A small cabin on top of a snowy mountain in the style of Disney, artstation",
207
- 4,
208
- 25,
209
- 9,
210
- 1024,
211
  ],
212
  ]
213
 
 
214
  with block:
215
  gr.HTML(
216
  """
@@ -297,34 +282,44 @@ with block:
297
  label="Generated images", show_label=False, elem_id="gallery"
298
  ).style(grid=[2], height="auto")
299
 
300
-
301
-
302
- with gr.Accordion("Custom options", open=False):
303
- samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
304
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=25, step=1)
305
- scale = gr.Slider(
306
- label="Guidance Scale", minimum=0, maximum=50, value=9, step=0.1
307
- )
308
- seed = gr.Slider(
309
- label="Seed",
310
- minimum=0,
311
- maximum=2147483647,
312
- step=1,
313
- randomize=True,
314
- )
315
-
316
- with gr.Group():
317
  with gr.Group(elem_id="share-btn-container"):
318
  community_icon = gr.HTML(community_icon_html)
319
  loading_icon = gr.HTML(loading_icon_html)
320
  share_button = gr.Button("Share to community", elem_id="share-btn")
321
 
322
- ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, steps, scale, seed], outputs=[gallery], cache_examples=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  ex.dataset.headers = [""]
324
 
325
- text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=[gallery])
326
- btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=[gallery])
327
 
 
 
 
 
 
 
 
 
 
 
328
  share_button.click(
329
  None,
330
  [],
@@ -334,7 +329,7 @@ with block:
334
  gr.HTML(
335
  """
336
  <div class="footer">
337
- <p>Model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">Stability AI</a> - Gradio Demo by 🤗 Hugging Face using the <a href="https://github.com/huggingface/diffusers" style="text-decoration: underline;" target="_blank">🧨 diffusers library</a>
338
  </p>
339
  </div>
340
  <div class="acknowledgments">
@@ -346,4 +341,4 @@ Despite how impressive being able to turn text into image is, beware to the fact
346
  """
347
  )
348
 
349
- block.queue(concurrency_count=1, max_size=50).launch(max_threads=150)
 
1
  import gradio as gr
 
 
 
 
 
 
 
2
  from datasets import load_dataset
3
+ from PIL import Image
4
 
5
+ import re
6
+ import os
7
+ import requests
8
 
 
 
9
 
10
+ from share_btn import community_icon_html, loading_icon_html, share_js
 
 
 
 
 
 
 
 
11
 
12
+ word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
13
+ word_list = word_list_dataset["train"]['text']
 
 
 
14
 
15
+ is_gpu_busy = False
16
+ def infer(prompt):
17
+ global is_gpu_busy
18
+ samples = 4
19
+ steps = 50
20
+ scale = 7.5
21
+ for filter in word_list:
22
+ if re.search(rf"\b{filter}\b", prompt):
23
+ raise gr.Error("Unsafe content found. Please try again with different prompts.")
24
+
25
+ images = []
26
+ url = os.getenv('JAX_BACKEND_URL')
27
+ payload = {'prompt': prompt}
28
+ images_request = requests.post(url, json = payload)
29
+ for image in images_request.json()["images"]:
30
+ image_b64 = (f"data:image/jpeg;base64,{image}")
31
+ images.append(image_b64)
32
+
33
+ return images
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  css = """
37
  .gradio-container {
 
160
  examples = [
161
  [
162
  'A high tech solarpunk utopia in the Amazon rainforest',
163
+ # 4,
164
+ # 45,
165
+ # 7.5,
166
+ # 1024,
167
  ],
168
  [
169
  'A pikachu fine dining with a view to the Eiffel Tower',
170
+ # 4,
171
+ # 45,
172
+ # 7,
173
+ # 1024,
174
  ],
175
  [
176
  'A mecha robot in a favela in expressionist style',
177
+ # 4,
178
+ # 45,
179
+ # 7,
180
+ # 1024,
181
  ],
182
  [
183
  'an insect robot preparing a delicious meal',
184
+ # 4,
185
+ # 45,
186
+ # 7,
187
+ # 1024,
188
  ],
189
  [
190
  "A small cabin on top of a snowy mountain in the style of Disney, artstation",
191
+ # 4,
192
+ # 45,
193
+ # 7,
194
+ # 1024,
195
  ],
196
  ]
197
 
198
+
199
  with block:
200
  gr.HTML(
201
  """
 
282
  label="Generated images", show_label=False, elem_id="gallery"
283
  ).style(grid=[2], height="auto")
284
 
285
+ with gr.Group(elem_id="container-advanced-btns"):
286
+ #advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  with gr.Group(elem_id="share-btn-container"):
288
  community_icon = gr.HTML(community_icon_html)
289
  loading_icon = gr.HTML(loading_icon_html)
290
  share_button = gr.Button("Share to community", elem_id="share-btn")
291
 
292
+ #with gr.Row(elem_id="advanced-options"):
293
+ # gr.Markdown("Advanced settings are temporarily unavailable")
294
+ # samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
295
+ # steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
296
+ # scale = gr.Slider(
297
+ # label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
298
+ # )
299
+ # seed = gr.Slider(
300
+ # label="Seed",
301
+ # minimum=0,
302
+ # maximum=2147483647,
303
+ # step=1,
304
+ # randomize=True,
305
+ # )
306
+
307
+ ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=False)
308
  ex.dataset.headers = [""]
309
 
310
+ text.submit(infer, inputs=text, outputs=[gallery], postprocess=False)
311
+ btn.click(infer, inputs=text, outputs=[gallery], postprocess=False)
312
 
313
+ #advanced_button.click(
314
+ # None,
315
+ # [],
316
+ # text,
317
+ # _js="""
318
+ # () => {
319
+ # const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
320
+ # options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
321
+ # }""",
322
+ #)
323
  share_button.click(
324
  None,
325
  [],
 
329
  gr.HTML(
330
  """
331
  <div class="footer">
332
+ <p>Model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a> - backend running JAX on TPUs due to generous support of <a href="https://sites.research.google/trc/about/" style="text-decoration: underline;" target="_blank">Google TRC program</a> - Gradio Demo by 🤗 Hugging Face
333
  </p>
334
  </div>
335
  <div class="acknowledgments">
 
341
  """
342
  )
343
 
344
+ block.queue(concurrency_count=24, max_size=40).launch(max_threads=150)