File size: 8,184 Bytes
628dd10
99caaea
 
 
 
 
2335e48
99caaea
 
83798fc
f873689
 
99caaea
2335e48
 
 
 
 
 
 
 
 
99caaea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a908ef9
 
99caaea
 
 
1a67055
99caaea
 
 
4a10f8f
f873689
4a10f8f
99caaea
 
 
 
 
 
 
 
 
ea8a830
99caaea
23aa474
ea8a830
23aa474
ea8a830
23aa474
 
 
ea8a830
23aa474
4918b4b
23aa474
ea8a830
4918b4b
ea8a830
23aa474
 
99caaea
 
 
f873689
99caaea
f873689
2a6b233
 
f873689
99caaea
4a10f8f
 
 
 
 
 
 
 
 
 
49bbcba
4a10f8f
a414223
ea8a830
a414223
ea8a830
a414223
49bbcba
 
 
 
 
 
 
 
ea8a830
49bbcba
ea8a830
49bbcba
 
ea8a830
49bbcba
 
ea8a830
49bbcba
ea8a830
49bbcba
 
a414223
49bbcba
4a10f8f
 
 
 
99caaea
 
 
 
 
 
4a10f8f
 
 
 
 
 
 
 
 
 
 
99caaea
ea8a830
99caaea
23aa474
ea8a830
23aa474
 
 
 
 
4918b4b
23aa474
ea8a830
4918b4b
23aa474
 
99caaea
 
 
4a10f8f
 
99caaea
ea8a830
4a10f8f
99caaea
ea8a830
99caaea
ea8a830
99caaea
 
ea8a830
4a10f8f
 
ea8a830
4a10f8f
 
ea8a830
99caaea
4a10f8f
 
99caaea
 
 
 
 
 
 
4a10f8f
 
 
f873689
4a10f8f
 
 
 
 
 
 
ea8a830
4a10f8f
 
99caaea
 
 
 
 
 
 
 
 
 
 
 
 
f682c4d
99caaea
 
 
 
f873689
 
 
99caaea
 
2a6b233
 
 
 
99caaea
 
 
 
 
 
4a10f8f
 
 
 
99caaea
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
import io
from PIL import Image
import os
from cryptography.fernet import Fernet
from google.cloud import storage
import pinecone
import json
import uuid
import pandas as pd

# Decrypt the Cloud Storage credentials. The Fernet key is read from the
# DECRYPTION_KEY environment variable (a KeyError is raised if it is unset).
fernet = Fernet(os.environ['DECRYPTION_KEY'])

with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
    creds = json.loads(fernet.decrypt(encrypted).decode())
# Write the decrypted credentials to a JSON file so that its path can be
# exported through GOOGLE_APPLICATION_CREDENTIALS below.
# NOTE(review): this leaves the decrypted credentials in plaintext in the
# working directory for the lifetime of the deployment.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    fp.write(json.dumps(creds, indent=4))
# Connect to Cloud Storage; the client is presumably configured from the
# credentials file named in the environment variable set on the next line.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('hf-diffusion-images')
    
# API key for Pinecone authentication (KeyError if PINECONE_KEY is unset).
PINECONE_KEY = os.environ['PINECONE_KEY']

# Name of the Pinecone index holding the prompt embeddings and metadata.
index_id = "hf-diffusion"

# Initialise the Pinecone connection and fail fast if the index is missing.
pinecone.init(
    api_key=PINECONE_KEY,
    environment="us-west1-gcp"
)
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")

index = pinecone.Index(index_id)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using '{device}' device...")

# Load the Stable Diffusion pipeline (authenticated with the HF_AUTH
# environment variable) and move it to the selected device.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=os.environ['HF_AUTH']
)
pipe.to(device)

# Placeholder shown by prompt_image() in place of a stored image that fails
# the test_image() check.
missing_im = Image.open('missing.png')
# set_images() treats a search as a hit when any match score exceeds this
# value; otherwise a new image is generated.
threshold = 0.85

def encode_text(text: str):
    """Embed *text* with the pipeline's text encoder.

    The text is tokenized with the pipeline tokenizer, run through the
    pipeline text encoder on ``device`` and the pooled output of the first
    (only) sequence is returned.

    Args:
        text: The prompt to embed.

    Returns:
        The pooled text embedding as a plain Python list of floats.
    """
    tokenizer = pipe.tokenizer
    # Truncate to the tokenizer's maximum length so that very long prompts
    # are shortened instead of being handed to the encoder at full length.
    text_inputs = tokenizer(
        text,
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors='pt',
    ).to(device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        encoded = pipe.text_encoder(**text_inputs)
    return encoded.pooler_output.cpu().tolist()[0]

def prompt_query(text: str):
    """Suggest previously stored prompts that are similar to *text*.

    Embeds *text* with ``encode_text`` and asks the Pinecone index for the
    30 nearest entries. If the query fails, the Pinecone connection is
    re-initialised and the query is retried once.

    Args:
        text: The prompt currently entered by the user.

    Returns:
        A ``pandas.DataFrame`` with the columns ``Similarity`` (match score
        rounded to two decimals) and ``Prompt``. Duplicate prompts are
        dropped (first occurrence kept), prompts of seven characters or
        fewer are filtered out, and at most five rows are returned. The
        frame is empty when the index returns no matches.

    Raises:
        ValueError: If the retried query also fails; the underlying error is
            chained as the cause.
    """
    print(f"Running prompt_query('{text}')")
    embeds = encode_text(text)
    try:
        print("Try query pinecone")
        xc = index.query(embeds, top_k=30, include_metadata=True)
        print("query successful")
    except Exception as first_err:
        print(f"Error during query: {first_err}")
        # reinitialize connection and retry once
        print("Try reinitialize Pinecone connection")
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        index2 = pinecone.Index(index_id)
        try:
            print("Now try querying pinecone again")
            xc = index2.query(embeds, top_k=30, include_metadata=True)
            print("query successful")
        except Exception as retry_err:
            raise ValueError(retry_err) from retry_err

    matches = xc['matches']
    if not matches:
        # Nothing to suggest. Return an empty frame directly rather than
        # applying the string filter below to an empty, untyped column.
        return pd.DataFrame(columns=['Similarity', 'Prompt'])

    df = pd.DataFrame({
        'Similarity': [round(match['score'], 2) for match in matches],
        'Prompt': [match['metadata']['prompt'] for match in matches],
    })
    # deduplicate while preserving order
    df = df.drop_duplicates(subset='Prompt', keep='first')
    # drop very short prompts, then keep the top five suggestions
    return df[df['Prompt'].str.len() > 7].head()

def diffuse(text: str):
    """Generate an image for *text* and try to cache it for later searches.

    The image is written to a temporary ``<uuid>.png`` in the working
    directory, uploaded to the Cloud Storage bucket under ``images/`` and
    its prompt embedding plus metadata are upserted into Pinecone (with one
    reconnect-and-retry on failure). Upload and upsert are best effort:
    failures are logged and do not prevent the image from being returned.
    The temporary file is always removed.

    Args:
        text: The prompt to generate an image for.

    Returns:
        The generated image, or an empty dict ``{}`` when the pipeline flags
        the output as NSFW.
    """
    out = pipe(text)
    nsfw_flags = out.nsfw_content_detected
    # NOTE(review): the flags are presumably None when the pipeline runs
    # without a safety checker; guard so that case does not raise in any().
    if nsfw_flags is not None and any(nsfw_flags):
        return {}

    _id = str(uuid.uuid4())
    im = out.images[0]
    local_path = f'{_id}.png'
    im.save(local_path, format='png')
    added_gcp = False
    try:
        # push to storage
        print("try push to Cloud Storage")
        blob = bucket.blob(f'images/{_id}.png')
        print("try upload_from_filename")
        blob.upload_from_filename(local_path)
        added_gcp = True
        # add embedding and metadata to Pinecone
        embeds = encode_text(text)
        meta = {
            'prompt': text,
            'image_url': f'images/{_id}.png',
        }
        try:
            print("now try upsert to pinecone")
            index.upsert([(_id, embeds, meta)])
            print("upsert successful")
        except Exception as first_err:
            # Log the first failure instead of discarding it, then retry
            # once on a fresh connection.
            print(f"PINECONE_ERROR: {first_err}")
            try:
                print("hit exception, now trying to reinit Pinecone connection")
                pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
                index2 = pinecone.Index(index_id)
                print(f"reconnected to pinecone '{index_id}' index")
                index2.upsert([(_id, embeds, meta)])
                print("upsert successful")
            except Exception as retry_err:
                print(f"PINECONE_ERROR: {retry_err}")
    except Exception as err:
        print(f"ERROR: New image not uploaded due to error with {'Pinecone' if added_gcp else 'Cloud Storage'}")
        # Surface the underlying cause, which was previously swallowed.
        print(f"ERROR: {err!r}")
    finally:
        # delete local file, whatever happened above
        os.remove(local_path)
    return im

def get_image(url: str):
    """Download the blob stored at *url* in the bucket and open it as an image.

    Args:
        url: Blob name inside the Cloud Storage bucket.

    Returns:
        The image opened from the downloaded bytes.
    """
    raw = bucket.blob(url).download_as_string()
    return Image.open(io.BytesIO(raw))

def test_image(_id, image):
    """Check that *image* can be encoded; purge it from the stores if not.

    The image is encoded as PNG into an in-memory buffer. If that raises
    ``OSError`` the entry is treated as corrupted: its vector is deleted
    from the Pinecone index and ``images/<_id>.png`` is deleted from the
    bucket.

    Args:
        _id: Identifier of the entry in Pinecone / Cloud Storage.
        image: The opened image to validate.

    Returns:
        ``True`` if the image could be encoded, ``False`` if it was found to
        be corrupted and has been deleted.
    """
    try:
        # Encode into memory rather than a shared 'tmp.png' so concurrent
        # requests cannot clobber each other's file and nothing is left on
        # disk.
        image.save(io.BytesIO(), format='PNG')
    except OSError:
        # delete corrupted file from pinecone and cloud
        index.delete(ids=[_id])
        bucket.blob(f"images/{_id}.png").delete()
        print(f"DELETED '{_id}'")
        return False
    return True

def prompt_image(text: str):
    """Fetch the stored images whose prompts are closest to *text*.

    Embeds *text* with ``encode_text`` and asks the Pinecone index for the
    nine nearest entries (re-initialising the connection and retrying once
    on failure). Each matched image is downloaded from the bucket and
    validated with ``test_image``.

    Args:
        text: The prompt to search for.

    Returns:
        A tuple ``(images, scores)``. ``scores`` holds one score per match.
        ``images`` holds the downloaded images, with ``missing_im``
        substituted for any image that fails validation or cannot be opened.
        An image whose retrieval raises ``ValueError`` is skipped, so
        ``images`` may be shorter than ``scores``.

    Raises:
        ValueError: If the retried query also fails; the underlying error is
            chained as the cause.
    """
    print(f"prompt_image('{text}')")
    embeds = encode_text(text)
    try:
        print("try query pinecone")
        xc = index.query(embeds, top_k=9, include_metadata=True)
    except Exception as first_err:
        print(f"Error during query: {first_err}")
        # reinitialize connection and retry once
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        index2 = pinecone.Index(index_id)
        try:
            print("try query pinecone after reinit")
            xc = index2.query(embeds, top_k=9, include_metadata=True)
        except Exception as retry_err:
            raise ValueError(retry_err) from retry_err

    matches = xc['matches']
    image_urls = [match['metadata']['image_url'] for match in matches]
    scores = [match['score'] for match in matches]
    ids = [match['id'] for match in matches]
    images = []
    print("Begin looping through (ids, image_urls)")
    for _id, image_url in zip(ids, image_urls):
        try:
            print("download_as_string from GCP")
            blob = bucket.blob(image_url).download_as_string()
            print("downloaded successfully")
            im = Image.open(io.BytesIO(blob))
            print("image opened successfully")
            if test_image(_id, im):
                images.append(im)
                print("image accessible")
            else:
                images.append(missing_im)
                print("image NOT accessible")
        except ValueError:
            print(f"ValueError: '{image_url}'")
        except OSError as err:
            # Opening unreadable image data raises an OSError; show the
            # placeholder instead of aborting the whole search.
            print(f"OSError: '{image_url}': {err}")
            images.append(missing_im)
    return images, scores

# __APP FUNCTIONS__

def set_suggestion(text: str):
    """Build a TextArea update whose value is the first element of *text*.

    NOTE(review): *text* is annotated as ``str`` but is indexed with ``[0]``;
    the caller may pass a sequence of suggestions instead - confirm.
    """
    first = text[0]
    return gr.TextArea.update(value=first)

def set_images(text: str):
    """Return a Gallery update for *text*, generating a new image if needed.

    Looks up stored images with ``prompt_image``. If no match score exceeds
    ``threshold``, a new image is generated with ``diffuse`` and the lookup
    is repeated so that the gallery reflects the refreshed results.
    """
    images, scores = prompt_image(text)
    if any(score > threshold for score in scores):
        print("MATCH FOUND")
    else:
        print("NO MATCH FOUND")
        diffuse(text)
        print(f"diffusion for '{text}' complete")
        # search again now that the new image may have been stored
        images, _ = prompt_image(text)
    return gr.Gallery.update(value=images)

# __CREATE APP__
# Build the Gradio Blocks UI: a prompt/suggestions column on the left and an
# image gallery on the right, then launch the app.
demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Dream Cacher
        """
    )
    with gr.Row():
        # input column: prompt box, search button and prompt suggestions
        with gr.Column():
            prompt = gr.TextArea(
                value="A person surfing",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            # NOTE(review): confirm that `values` (rather than `value`) is
            # the keyword the installed Gradio version expects here.
            suggestions = gr.Dataframe(
                values=[],
                headers=['Similarity', 'Prompt']
            )
            # On every change of the prompt text, refresh the suggestions
            # table with the output of prompt_query.
            prompt.change(
                prompt_query, prompt, suggestions,
                show_progress=False
            )
            
        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
            # Clicking "Search!" runs set_images on the prompt and shows the
            # result in the gallery.
            # NOTE(review): this try/except only wraps the registration of
            # the click listener; an OSError raised later while set_images
            # handles a click would presumably not be caught here - confirm.
            try:
                search.click(set_images, prompt, pics)
            except OSError:
                print("OSError")

demo.launch()