import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
import io
from PIL import Image
import os
from cryptography.fernet import Fernet
from google.cloud import storage
import pinecone
import json
import uuid
import pandas as pd

# Pinecone environment used by the initial connection and by every reconnect.
PINECONE_ENV = "us-west1-gcp"

# decrypt Storage Cloud credentials
fernet = Fernet(os.environ['DECRYPTION_KEY'])
with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())
# then save creds to file; the storage client below locates them through
# GOOGLE_APPLICATION_CREDENTIALS.
# NOTE(review): this leaves decrypted secrets on disk; consider building the
# client from the in-memory dict instead.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    fp.write(json.dumps(creds, indent=4))
# connect to Cloud Storage
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('hf-diffusion-images')

# get api key for pinecone auth
PINECONE_KEY = os.environ['PINECONE_KEY']
index_id = "hf-diffusion"

# init connection to pinecone
pinecone.init(
    api_key=PINECONE_KEY,
    environment=PINECONE_ENV
)
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")
index = pinecone.Index(index_id)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using '{device}' device...")

# init all of the models and move them to a given GPU
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=os.environ['HF_AUTH']
)
pipe.to(device)

# Placeholder shown in the gallery for stored images that fail to decode.
missing_im = Image.open('missing.png')
# A stored result scoring above this counts as a cache hit in set_images.
threshold = 0.85


def _reconnect_index():
    """Re-initialise the Pinecone client and rebind the module-level ``index``.

    Rebinding the global (rather than creating a throwaway local handle) means
    later calls use the fresh connection instead of failing on the old one.

    Returns:
        The new ``pinecone.Index`` handle.
    """
    global index
    pinecone.init(api_key=PINECONE_KEY, environment=PINECONE_ENV)
    index = pinecone.Index(index_id)
    return index


def _query_index(embeds, top_k: int):
    """Query Pinecone for the ``top_k`` nearest stored vectors to ``embeds``.

    Metadata is included in the result. If the first attempt fails, the
    connection is re-initialised and the query is retried once.

    Raises:
        ValueError: wrapping the exception from the retry, if that also fails.
    """
    try:
        print("Try query pinecone")
        xc = index.query(embeds, top_k=top_k, include_metadata=True)
    except Exception as e:
        print(f"Error during query: {e}")
        print("Try reinitialize Pinecone connection")
        try:
            xc = _reconnect_index().query(
                embeds, top_k=top_k, include_metadata=True
            )
        except Exception as err:
            raise ValueError(err) from err
    print("query successful")
    return xc


def encode_text(text: str):
    """Embed ``text`` with the diffusion pipeline's own text encoder.

    Returns:
        The pooled embedding of the prompt as a plain list of floats (the
        form passed to Pinecone query/upsert in this file).
    """
    # Truncate to the tokenizer's max length so long prompts cannot overflow
    # the text encoder's position embeddings.
    text_inputs = pipe.tokenizer(
        text, truncation=True, return_tensors='pt'
    ).to(device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    return text_embeds.pooler_output.cpu().tolist()[0]


def prompt_query(text: str):
    """Build the suggestions table of stored prompts similar to ``text``.

    Fetches the 30 nearest stored prompts and rounds scores to 2 decimals.
    Duplicate prompts are dropped, keeping the first occurrence in Pinecone's
    result order. Prompts of 7 characters or fewer are dropped, and at most
    5 rows are returned.

    Returns:
        ``pandas.DataFrame`` with columns ``Similarity`` and ``Prompt``.
    """
    print(f"Running prompt_query('{text}')")
    embeds = encode_text(text)
    xc = _query_index(embeds, top_k=30)
    matches = xc['matches']
    prompts = [match['metadata']['prompt'] for match in matches]
    scores = [round(match['score'], 2) for match in matches]
    # deduplicate while preserving order
    df = pd.DataFrame({'Similarity': scores, 'Prompt': prompts})
    df = df.drop_duplicates(subset='Prompt', keep='first')
    df = df[df['Prompt'].str.len() > 7].head()
    return df


def diffuse(text: str):
    """Generate an image for ``text`` and store it for future searches.

    The image is uploaded to Cloud Storage as ``images/<uuid>.png``. The
    prompt embedding and metadata are then upserted to Pinecone under the
    same uuid. Storage and indexing failures are logged, not raised.

    Returns:
        The generated PIL image, or ``{}`` if the safety checker flagged it.
    """
    # diffuse
    out = pipe(text)
    # nsfw_content_detected may be None when no safety check was performed.
    if out.nsfw_content_detected and any(out.nsfw_content_detected):
        return {}

    _id = str(uuid.uuid4())
    im = out.images[0]
    filename = f'{_id}.png'
    blob_name = f'images/{_id}.png'
    # Local copy so it can be pushed with upload_from_filename.
    im.save(filename, format='png')
    added_gcp = False
    try:
        # push to storage
        print("try push to Cloud Storage")
        blob = bucket.blob(blob_name)
        print("try upload_from_filename")
        blob.upload_from_filename(filename)
        added_gcp = True
        # add embedding and metadata to Pinecone
        embeds = encode_text(text)
        meta = {
            'prompt': text,
            'image_url': blob_name
        }
        try:
            print("now try upsert to pinecone")
            index.upsert([(_id, embeds, meta)])
            print("upsert successful")
        except Exception:
            try:
                print("hit exception, now trying to reinit Pinecone connection")
                fresh_index = _reconnect_index()
                print(f"reconnected to pinecone '{index_id}' index")
                fresh_index.upsert([(_id, embeds, meta)])
                print("upsert successful")
            except Exception as e:
                print(f"PINECONE_ERROR: {e}")
    except Exception as e:
        print(f"ERROR: New image not uploaded due to error with {'Pinecone' if added_gcp else 'Cloud Storage'}: {e}")
    finally:
        # delete local file
        os.remove(filename)
    return im


def get_image(url: str):
    """Download the Cloud Storage object at ``url`` and open it as a PIL image."""
    blob = bucket.blob(url).download_as_string()
    blob_bytes = io.BytesIO(blob)
    im = Image.open(blob_bytes)
    return im


def test_image(_id, image):
    """Check that ``image`` fully decodes; purge its record if it does not.

    Encoding the image forces PIL to load the pixel data. An ``OSError``
    (e.g. a truncated file) deletes vector ``_id`` from Pinecone and
    ``images/<_id>.png`` from Cloud Storage.

    Returns:
        True if the image is readable, False if it was purged.
    """
    try:
        # Encode in memory rather than to a shared 'tmp.png'. This avoids races
        # between concurrent requests, and a disk-level OSError (disk full,
        # permissions) can no longer make us delete a valid image.
        image.save(io.BytesIO(), format='png')
        return True
    except OSError:
        # delete corrupted file from pinecone and cloud
        index.delete(ids=[_id])
        bucket.blob(f"images/{_id}.png").delete()
        print(f"DELETED '{_id}'")
        return False


def prompt_image(text: str):
    """Fetch stored images whose prompts are nearest to ``text``.

    Queries the 9 nearest stored prompts. Each matching image is downloaded
    and validated with ``test_image``. Images that fail validation are shown
    as ``missing_im``. Objects that cannot be downloaded or opened are skipped.

    Returns:
        ``(images, scores)``. ``scores`` holds the similarity of every match
        returned by Pinecone, including any whose image was skipped.
    """
    print(f"prompt_image('{text}')")
    embeds = encode_text(text)
    xc = _query_index(embeds, top_k=9)
    matches = xc['matches']
    scores = [match['score'] for match in matches]
    images = []
    print("Begin looping through (ids, image_urls)")
    for match in matches:
        _id = match['id']
        image_url = match['metadata']['image_url']
        try:
            print("download_as_string from GCP")
            blob = bucket.blob(image_url).download_as_string()
            print("downloaded successfully")
            im = Image.open(io.BytesIO(blob))
            print("image opened successfully")
            accessible = test_image(_id, im)
        except (ValueError, OSError) as e:
            # Undecodable (PIL.UnidentifiedImageError is an OSError) or
            # undownloadable object: skip it instead of failing the search.
            print(f"{type(e).__name__}: '{image_url}'")
            continue
        if accessible:
            images.append(im)
            print("image accessible")
        else:
            images.append(missing_im)
            print("image NOT accessible")
    return images, scores


# __APP FUNCTIONS__

def set_suggestion(text: str):
    """Return a TextArea update whose value is the first element of ``text``."""
    return gr.TextArea.update(value=text[0])


def set_images(text: str):
    """Show cached images for ``text``, generating a new one on a cache miss.

    A cache hit is any stored result scoring above ``threshold``. On a miss,
    a new image is diffused and stored, and the gallery is re-queried.
    """
    images, scores = prompt_image(text)
    if any(score > threshold for score in scores):
        print("MATCH FOUND")
    else:
        print("NO MATCH FOUND")
        diffuse(text)
        print(f"diffusion for '{text}' complete")
        images, _ = prompt_image(text)
    return gr.Gallery.update(value=images)


# __CREATE APP__

demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Dream Cacher
        """
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(
                value="A person surfing",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            # NOTE(review): gr.Dataframe's initial-value parameter is named
            # 'value'; confirm 'values' is accepted by the installed gradio.
            suggestions = gr.Dataframe(
                values=[],
                headers=['Similarity', 'Prompt']
            )
            # event listener for change in prompt
            prompt.change(
                prompt_query, prompt, suggestions,
                show_progress=False
            )
        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
    # search event listening (errors from set_images surface when the
    # button is clicked, not at registration time)
    search.click(set_images, prompt, pics)

demo.launch()