Spaces:
Runtime error
Runtime error
File size: 3,150 Bytes
4d72a29 0ea4415 4d72a29 b442155 0ea4415 4d72a29 091b9da 4d72a29 091b9da 4d72a29 0ea4415 4d72a29 0ea4415 4d72a29 0ea4415 4d72a29 0ea4415 b442155 4d72a29 0ea4415 a200d93 405665c b442155 4d72a29 b442155 0ea4415 749bafa 0ea4415 4d72a29 0ea4415 4d72a29 0ea4415 749bafa c55fdff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import base64
import html
import os
import time
from io import BytesIO
from multiprocessing import Process

import requests
import streamlit as st
from PIL import Image
def start_server():
    """Run the uvicorn model server (``server:app``) in the foreground.

    The server listens on 0.0.0.0:8080 with two worker processes. ``os.system``
    blocks until the uvicorn command exits, so this is meant to be the target
    of a background ``Process`` (see ``load_models``).
    """
    command = " ".join(
        ["uvicorn", "server:app", "--port", "8080", "--host", "0.0.0.0", "--workers", "2"]
    )
    os.system(command)
def load_models(timeout=None):
    """Make sure the model server is listening on port 8080, starting it if needed.

    If the port is free, ``start_server`` is launched in a daemon child process.
    The port is then polled once per second until it accepts connections.
    On success (or if the server was already up), ``models_loaded`` is set to
    True in the Streamlit session state.

    Args:
        timeout: Optional maximum number of seconds to wait for the server to
            start listening. ``None`` (the default) waits indefinitely, as long
            as the server process is still alive.

    If the server process exits before binding the port, or the timeout
    elapses, an error is shown and the script run is halted with ``st.stop()``.
    ``models_loaded`` is left False so the next rerun tries again.
    """
    if is_port_in_use(8080):
        st.success("Model server already running...")
    else:
        with st.spinner(text="Loading models, please wait..."):
            proc = Process(target=start_server, args=(), daemon=True)
            proc.start()
            deadline = None if timeout is None else time.monotonic() + timeout
            while not is_port_in_use(8080):
                # start_server blocks in os.system for the server's whole lifetime,
                # so the child exiting means uvicorn died before binding the port.
                # Without this check the loop would spin forever.
                if not proc.is_alive():
                    st.error("Model server exited before it started listening on port 8080.")
                    st.stop()
                if deadline is not None and time.monotonic() > deadline:
                    st.error(f"Model server did not start within {timeout} seconds.")
                    st.stop()
                time.sleep(1)
        st.success("Model server started.")
    st.session_state["models_loaded"] = True
def is_port_in_use(port, host="127.0.0.1"):
    """Return True if a TCP server accepts connections on ``host:port``.

    Args:
        port: TCP port number to probe.
        host: Address to connect to. Defaults to the IPv4 loopback address.
            "0.0.0.0" is a bind wildcard and not a portable *connect*
            destination; it only happens to reach localhost on Linux.

    Returns:
        True when the connection succeeds, False when ``connect_ex`` reports
        an error (e.g. connection refused because nothing is listening).
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        return probe.connect_ex((host, port)) == 0
def generate(prompt):
    """Ask the local model server for images matching *prompt* and decode them.

    Args:
        prompt: Free-form text typed by the user.

    Returns:
        A list of PIL images, one per base64-encoded string in the "images"
        field of the server's JSON response.

    Raises:
        requests.RequestException: on connection failure or a 4xx/5xx status.
        KeyError: if the JSON response has no "images" key.
    """
    # Let requests URL-encode the prompt. Interpolating it into the URL breaks
    # on spaces, "&", "#", "?", "+" and other reserved characters.
    response = requests.get(
        "http://127.0.0.1:8080/correct",
        params={"prompt": prompt},
    )
    response.raise_for_status()
    encoded_images = response.json()["images"]
    return [Image.open(BytesIO(base64.b64decode(img))) for img in encoded_images]
# Per-session flag recording whether the model server has been confirmed up.
if "models_loaded" not in st.session_state:
    st.session_state["models_loaded"] = False

st.header("minDALL-E")
st.write(
    "Generate images from text: Interactive demo for "
    "[minDALL-E](https://github.com/kakaobrain/minDALL-E)"
)

# Only probe / start the server until it has succeeded once in this session.
if not st.session_state["models_loaded"]:
    load_models()

prompt = st.text_input("What do you want to see?")
DEBUG = False
# UI code taken from https://huggingface.co/spaces/flax-community/dalle-mini/blob/main/app/streamlit/app.py
if prompt != "":
    # Single placeholder: first shows a loading banner, later replaced by the
    # prompt caption (or by an error message if generation fails).
    container = st.empty()
    # The prompt is user input injected into raw HTML (unsafe_allow_html=True),
    # so it must be escaped; otherwise "<", "&", quotes etc. break or inject markup.
    escaped_prompt = html.escape(prompt)
    container.markdown(
        f"""
        <style> p {{ margin:0 }} div {{ margin:0 }} </style>
        <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
        <div class="stAlert">
        <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
        <div class="st-b7">
        <div class="css-whx05o e13vu3m50">
        <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
        <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
        Generating predictions for: <b>{escaped_prompt}</b>
        </div>
        </div>
        </div>
        </div>
        </div>
        </div>
        """,
        unsafe_allow_html=True,
    )

    print(f"Getting selections: {prompt}")
    try:
        selected = generate(prompt)
    except (requests.RequestException, KeyError, ValueError, OSError) as err:
        # Replace the loading banner so a failed request doesn't leave it up forever.
        container.error(f"Image generation failed: {err}")
    else:
        margin = 0.1  # for better position of zoom in arrow
        n_columns = 3
        # Image columns interleaved with narrow spacers:
        # [img, spacer, img, spacer, img], so images go at even indices.
        cols = st.columns([1] + [margin, 1] * (n_columns - 1))
        for i, img in enumerate(selected):
            cols[(i % n_columns) * 2].image(img)
        container.markdown(f"**{prompt}**")
    st.button("Again!", key="again_button")

st.write(
    "<b><i>UI credits: <a href='https://huggingface.co/spaces/flax-community/dalle-mini'>"
    "DALL-E mini Space</a></i></b>",
    unsafe_allow_html=True,
)
|