valhalla committed on
Commit
47c5fb9
1 Parent(s): bcfe90f

Create server.py

Browse files
Files changed (1) hide show
  1. server.py +76 -0
server.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import base64
4
+ from io import BytesIO
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+
7
+ from fastapi import FastAPI
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+ import clip
12
+ from dalle.models import Dalle
13
+ from dalle.utils.utils import clip_score, download
14
+
15
print("Loading models...")
app = FastAPI()


# Checkpoint archive URL and the per-user cache directory used for it.
url = "https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz"
root = os.path.expanduser("~/.cache/minDALLE")

# Archive file name, and the same name with the ".tar.gz" suffix removed.
filename = os.path.basename(url)
pathname = filename[: len(filename) - len(".tar.gz")]

# NOTE(review): download_target is computed but not read anywhere in the
# visible code.
download_target, result_path = (os.path.join(root, leaf) for leaf in (filename, pathname))

# Only call download() when the expected cache entry is missing; in that case
# whatever path download() returns is used instead.
result_path = result_path if os.path.exists(result_path) else download(url, root)


# Both models are placed on the CPU.
device = "cpu"

# Load the checkpoint from the local path resolved above.
model = Dalle.from_pretrained(result_path)
model.to(device=device)

# CLIP model and its image preprocessing callable, used for re-ranking.
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.to(device=device)

print("Models loaded !")
36
+
37
+
38
@app.get("/")
def read_root():
    """Return a fixed banner for the root URL."""
    # NOTE(review): this is a one-element *set* literal, not a dict; kept
    # as-is so the response is unchanged. Confirm the intended payload shape.
    banner = {"minDALL-E!"}
    return banner
41
+
42
+
43
# NOTE(review): "/{generate}" declares a path parameter that the handler never
# reads, so this route presumably matches any single-segment path rather than
# only "/generate". Left untouched to avoid breaking existing callers; confirm.
@app.get("/{generate}")
def generate(prompt):
    """Generate images for ``prompt`` and return them base64-encoded.

    Args:
        prompt: Text prompt forwarded to ``sample``.

    Returns:
        A dict with a single key ``"images"`` holding one ``to_base64``
        result per image returned by ``sample``, in the same order.
    """
    encoded = []
    for picture in sample(prompt):
        encoded.append(to_base64(picture))
    return {"images": encoded}
48
+
49
+
50
def sample(prompt):
    """Sample candidate images for ``prompt`` and return them as PIL images.

    Three candidates are drawn from the module-level ``model``, reordered by
    the indices returned from ``clip_score``, and converted to ``PIL.Image``
    objects.

    Args:
        prompt: Text prompt passed to both the sampler and ``clip_score``.

    Returns:
        A list of PIL images in the order given by ``clip_score``
        (presumably best match first -- confirm against ``clip_score``).
    """
    # Sampling
    generated = model.sampling(
        prompt=prompt,
        top_k=256,
        top_p=None,
        softmax_temperature=1.0,
        num_candidates=3,
        device=device,
    )
    candidates = generated.cpu().numpy()
    # Move axis 1 to the end (presumably channels-first -> channels-last,
    # the layout Image.fromarray expects -- confirm).
    candidates = np.transpose(candidates, (0, 2, 3, 1))

    # CLIP Re-ranking
    order = clip_score(
        prompt=prompt,
        images=candidates,
        model_clip=model_clip,
        preprocess_clip=preprocess_clip,
        device=device,
    )
    ranked = candidates[order]

    # Scale by 255 and truncate to uint8 before handing each array to PIL
    # (assumes values lie in [0, 1] -- confirm against model.sampling).
    return [Image.fromarray((frame * 255).astype(np.uint8)) for frame in ranked]
71
+
72
+
73
def to_base64(pil_image):
    """Serialize an image to JPEG and return it base64-encoded.

    Args:
        pil_image: Object exposing ``save(fp, format=...)``; in this file it
            is a ``PIL.Image`` produced by ``sample``.

    Returns:
        ``bytes`` containing the base64 encoding of the JPEG data.
    """
    with BytesIO() as stream:
        pil_image.save(stream, format="JPEG")
        jpeg_bytes = stream.getvalue()
    return base64.b64encode(jpeg_bytes)