# niji-playground / app.py
# Author: nyanko7 — commit db94543 ("Update app.py")
import gradio as gr
import numpy as np
import time
import requests
import json
import os
import random
import tempfile
import logging
import threading
import asyncio
from PIL import Image
from io import BytesIO
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
# Backend endpoints and the auth token come from environment variables;
# each name is None when its variable is unset.
odnapi = os.environ.get("odnapi_url")
fetapi = os.environ.get("fetapi_url")
auth_token = os.environ.get("auth_token")

# Module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def split_image(img):
    """Cut *img* into its four quadrants.

    Returns the crops in reading order: top-left, top-right, bottom-left,
    bottom-right. With odd dimensions the extra pixel column/row goes to the
    right/bottom quadrants (the split point is ``size // 2``).
    """
    w, h = img.size
    mid_x, mid_y = w // 2, h // 2
    # Outer loop walks rows (top, then bottom); inner loop walks columns.
    boxes = [
        (left, top, right, bottom)
        for top, bottom in ((0, mid_y), (mid_y, h))
        for left, right in ((0, mid_x), (mid_x, w))
    ]
    return [img.crop(box) for box in boxes]
def save_image(img, suffix='.png'):
    """Write *img* to a fresh temporary file and return that file's path.

    The file is created with ``delete=False`` so it outlives this call; the
    caller owns cleanup. The data is always PNG-encoded; *suffix* only
    affects the file name.
    """
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    with handle:
        img.save(handle, 'PNG')
    return handle.name
# Give up after this many consecutive failed polls of the message endpoint.
# Each individual poll already gets urllib3-level retries from the session adapter.
_MAX_POLL_FAILURES = 10


def _make_session(max_retries, backoff_factor):
    """Return a ``requests.Session`` that retries transient HTTP failures.

    Responses with status 429/500/502/503/504 are retried up to *max_retries*
    times, using urllib3's backoff scaled by *backoff_factor*.
    """
    session = requests.Session()  # reuse the underlying TCP connection
    # NOTE: with urllib3's default ``allowed_methods``, status-based retries
    # apply to idempotent verbs (e.g. GET) but not to POST.
    retries = Retry(
        total=max_retries,
        backoff_factor=backoff_factor,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    adapter = HTTPAdapter(max_retries=retries)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session


def _download_and_split_image(session, url):
    """Download *url*, split it into quadrants and save each as a temp PNG.

    Returns the list of temp-file paths, or an empty list when the download
    or decoding fails. The failure is logged rather than raised, because a
    broken preview image must not abort the whole generation.
    """
    try:
        response = session.get(url, timeout=5.0)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        return [save_image(part) for part in split_image(img)]
    except Exception:  # best-effort boundary: network, HTTP and decode errors
        logger.exception("Failed to download or split image %s", url)
        return []


async def _run_blocking(func, *args, **kwargs):
    """Run a blocking callable in the default thread pool and await its result.

    ``requests`` and PIL are synchronous. Calling them directly inside an
    ``async def`` would stall the event loop, and with it every other request.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


async def niji_api(prompt, progress=gr.Progress(), max_retries=5, backoff_factor=0.1):
    """Submit *prompt* to the generation backend and stream images to the UI.

    Async generator used as a Gradio event handler:

    1. POSTs ``{"msg": prompt}`` to ``fetapi`` and reads ``messageId``.
    2. Polls ``{odnapi}/message/{messageId}`` until its ``progress`` reaches
       100. Whenever a poll carries ``progressImageUrl``, the preview is split
       into quadrants and yielded as ``[(path, "<prog>% done"), ...]``.
    3. Yields the quadrants of ``response.imageUrl`` as
       ``[(path, "image i/4"), ...]``.

    Args:
        prompt: text sent as the ``msg`` field of the POST body.
        progress: Gradio progress tracker (injected by Gradio).
        max_retries: urllib3 ``Retry`` total for every HTTP call.
        backoff_factor: urllib3 ``Retry`` backoff factor.

    Raises:
        ValueError: the initial POST fails, polling fails
            ``_MAX_POLL_FAILURES`` times in a row, or the final image cannot
            be downloaded.
    """
    iters = 1
    progress(iters / 32, desc="Sending request to MidJourney Server")
    session = _make_session(max_retries, backoff_factor)
    try:
        try:
            response = await _run_blocking(
                session.post,
                fetapi,
                headers={'Content-Type': 'application/json'},
                data=json.dumps({'msg': prompt}),
                timeout=5.0,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            logger.error("Failed to make POST request: %s", e)
            raise ValueError("Invalid Response") from e
        message_id = response.json()['messageId']

        prog = 0
        failures = 0
        iters += 5
        progress(iters / 48, desc="Waiting in the generate queue")
        while prog < 100:
            try:
                response = await _run_blocking(
                    session.get,
                    f'{odnapi}/message/{message_id}?expireMins=2',
                    headers={'Authorization': auth_token},
                    timeout=5.0,
                )
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                failures += 1
                logger.warning(
                    "Failure in getting message response (%d/%d): %s",
                    failures, _MAX_POLL_FAILURES, e,
                )
                if failures >= _MAX_POLL_FAILURES:
                    raise ValueError("Invalid Response") from e
                # Back off before retrying instead of spinning in a tight loop.
                await asyncio.sleep(random.uniform(1, 2))
                continue
            failures = 0

            data = response.json()
            prog = data.get('progress', 0)
            if progress_image_url := data.get('progressImageUrl'):
                # Once previews arrive, push iters negative to hide the
                # synthetic queue-progress bar for a while.
                iters = -100
                paths = await _run_blocking(_download_and_split_image, session, progress_image_url)
                if paths:  # skip a failed preview instead of clearing the gallery
                    yield [(path, f"{prog}% done") for path in paths]

            await asyncio.sleep(random.uniform(1, 2))
            r = iters / 48
            if r < 0.4:
                desc = "Waiting in the generate queue"
            elif r < 0.6:
                desc = "Still queueing"
            else:
                desc = "Almost done"
            if iters > 0:
                # iters keeps growing while we wait, so clamp to a valid fraction.
                progress(min(r, 1.0), desc=desc)
            iters += random.uniform(1, 2)

        # `data` is the last successful poll, whose progress reached 100.
        image_url = data['response']['imageUrl']
        paths = await _run_blocking(_download_and_split_image, session, image_url)
        if not paths:
            raise ValueError("Failed to download the generated image")
        yield [(path, f"image {idx + 1}/4") for idx, path in enumerate(paths)]
    finally:
        session.close()
# Gradio UI: a prompt textbox and a button wired to `niji_api`, whose
# yielded (path, caption) lists are streamed into the gallery.
with gr.Blocks() as demo:
    gr.HTML('''
        <div style="text-align: center; max-width: 650px; margin: 0 auto;">
          <div style="
            display: inline-flex;
            gap: 0.8rem;
            font-size: 1.75rem;
            justify-content: center;
            margin-bottom: 10px;
          ">
            <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
              MidJourney / NijiJourney Playground 🎨
            </h1>
          </div>
          <div>
            <p style="align-items: center; margin-bottom: 7px;">
              Demo for the <a href="https://MidJourney.com/" target="_blank">MidJourney</a>, draw with heart and love.
            </p>
          </div>
        </div>
    ''')
    with gr.Column(variant="panel"):
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=3,
                container=False,
            )
            btn = gr.Button("Generate image", scale=0)
        # NOTE(review): height is passed as the string "4096" (no CSS unit);
        # confirm the intended value/units for the installed Gradio version.
        gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height="4096")

    btn.click(niji_api, text, gallery)

# The queue lets the generator handler stream partial results.
# NOTE(review): `concurrency_count` is a Gradio 3.x argument (removed in 4.x);
# confirm it matches the pinned Gradio version.
demo.queue(concurrency_count=2)
demo.launch(debug=True)