minDALLE / app.py
valhalla's picture
Update app.py
0b22a40
raw
history blame
3.16 kB
import base64
import html
import os
import time
from io import BytesIO
from multiprocessing import Process

import requests
import streamlit as st
from PIL import Image
def start_server():
    """Run the uvicorn model server (``server:app``) on port 8080.

    ``os.system`` waits for the command to finish, so this call blocks for as
    long as uvicorn is running; ``load_models`` therefore runs it in a daemon
    ``Process``.
    """
    command = " ".join(
        [
            "uvicorn",
            "server:app",
            "--port", "8080",
            "--host", "0.0.0.0",
            "--workers", "2",
        ]
    )
    os.system(command)
def load_models():
    """Make sure the model server is listening on port 8080, starting it if needed.

    If nothing is listening yet, ``start_server`` is launched in a daemon
    ``Process`` and this function polls the port once per second until it
    accepts connections. On success the session is marked as loaded
    (``st.session_state["models_loaded"] = True``) so later reruns skip this.

    If the server process exits before the port opens (uvicorn failed to
    start), an error is shown and the script run is halted with ``st.stop()``
    instead of polling forever; the session flag is left unset so the next
    rerun tries again.
    """
    if is_port_in_use(8080):
        st.success("Model server already running...")
    else:
        with st.spinner(text="Loading models, please wait..."):
            proc = Process(target=start_server, args=(), daemon=True)
            proc.start()
            while not is_port_in_use(8080):
                # start_server() only returns once uvicorn exits, so a dead
                # child while the port is still closed means startup failed.
                if not proc.is_alive():
                    st.error("Model server exited before listening on port 8080.")
                    st.stop()
                time.sleep(1)
        st.success("Model server started.")
    st.session_state["models_loaded"] = True
def is_port_in_use(port, host="127.0.0.1", timeout=1.0):
    """Return True if a TCP server accepts connections on ``host:port``.

    Args:
        port: TCP port number to probe.
        host: Address to connect to. Defaults to the IPv4 loopback address.
            A server bound to ``0.0.0.0`` also listens there. Using the
            wildcard ``0.0.0.0`` as a *destination* is not portable (Windows
            rejects it).
        timeout: Seconds to allow for the connection attempt, so an
            unresponsive host cannot block the caller indefinitely.

    Returns:
        True when the connection succeeds, False on any connect error.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(timeout)
        return s.connect_ex((host, port)) == 0
def generate(prompt):
    """Ask the local model server for images matching *prompt* and decode them.

    Sends a GET to the server's ``/correct`` endpoint on port 8080. The JSON
    response is expected to have an ``"images"`` list of base64-encoded image
    files.

    Args:
        prompt: Free-text description, sent as the ``prompt`` query parameter.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per entry in ``"images"``.

    Raises:
        requests.HTTPError: If the server replies with a 4xx/5xx status.
    """
    # Let requests percent-encode the prompt; hand-built query strings break
    # on "&", "#", "?", "+" and similar characters in user text.
    response = requests.get(
        "http://127.0.0.1:8080/correct", params={"prompt": prompt}
    )
    # Fail with a clear HTTP error instead of a KeyError/JSON error on bad replies.
    response.raise_for_status()
    encoded_images = response.json()["images"]
    return [Image.open(BytesIO(base64.b64decode(img))) for img in encoded_images]
# Streamlit re-executes this script on each interaction; the session-state
# flag records whether load_models() has already succeeded for this session.
if "models_loaded" not in st.session_state:
    st.session_state["models_loaded"] = False

st.header("minDALL-E")
st.write(
    "Generate images from text: Interactive demo for "
    "[minDALL-E](https://github.com/kakaobrain/minDALL-E)"
)

if not st.session_state["models_loaded"]:
    load_models()

prompt = st.text_input("What do you want to see?")
DEBUG = False  # NOTE(review): not referenced in the visible code
# UI code taken from https://huggingface.co/spaces/flax-community/dalle-mini/blob/main/app/streamlit/app.py
# Skip empty and whitespace-only prompts.
if prompt.strip():
    # The prompt is user input embedded in HTML rendered with
    # unsafe_allow_html=True, so escape "<", ">", "&" and quotes first.
    safe_prompt = html.escape(prompt)
    # Single-element placeholder: shows a loading notice now and is later
    # replaced by the prompt caption.
    container = st.empty()
    container.markdown(
        f"""
<style> p {{ margin:0 }} div {{ margin:0 }} </style>
<div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
<div class="stAlert">
<div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
<div class="st-b7">
<div class="css-whx05o e13vu3m50">
<div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
Generating predictions for: <b>{safe_prompt}</b>
</div>
</div>
</div>
</div>
</div>
</div>
""",
        unsafe_allow_html=True,
    )
    print(f"Getting selections: {prompt}")
    try:
        selected = generate(prompt)
    except requests.RequestException as err:
        # Show the failure in place of the loading notice rather than a traceback.
        container.error(f"Image generation failed: {err}")
        st.stop()
    margin = 0.1  # for better position of zoom in arrow
    n_columns = 3
    # Image columns interleaved with narrow spacer columns: img, gap, img, gap, img.
    cols = st.columns([1] + [margin, 1] * (n_columns - 1))
    for i, img in enumerate(selected):
        cols[(i % n_columns) * 2].image(img)
    container.markdown(f"**{prompt}**")
    st.button("Again!", key="again_button")
    # Render credits as a separate element: writing them into `container`
    # (which holds one element) would overwrite the prompt caption above.
    st.markdown(
        "<b><i>UI credits: <a href='https://huggingface.co/spaces/flax-community/dalle-mini'>DALL-E mini Space</a></i></b>",
        unsafe_allow_html=True,
    )