import gradio as gr
import numpy as np
import time
import requests
import json
import os
import random
import tempfile
import logging
import threading
import asyncio
from PIL import Image
from io import BytesIO
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

# Base URL polled for job status (``{odnapi}/message/<id>``).
odnapi = os.getenv("odnapi_url")
# Endpoint the prompt is POSTed to; its JSON reply carries ``messageId``.
fetapi = os.getenv("fetapi_url")
# Sent verbatim as the ``Authorization`` header when polling ``odnapi``.
auth_token = os.getenv("auth_token")

# Setup a logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def split_image(img):
    """Split *img* into its four quadrants.

    Returns a list ordered top-left, top-right, bottom-left, bottom-right.
    For odd dimensions the right/bottom tiles receive the extra pixel.
    """
    width, height = img.size
    width_cut = width // 2
    height_cut = height // 2
    return [
        img.crop((0, 0, width_cut, height_cut)),
        img.crop((width_cut, 0, width, height_cut)),
        img.crop((0, height_cut, width_cut, height)),
        img.crop((width_cut, height_cut, width, height)),
    ]


def save_image(img, suffix='.png'):
    """Write *img* as PNG to a new temporary file and return its path.

    The file is created with ``delete=False`` so it outlives this call (the
    gallery is given the path); nothing in this file removes it afterwards.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        img.save(tmp, 'PNG')
    return tmp.name


async def niji_api(prompt, progress=gr.Progress(), max_retries=5, backoff_factor=0.1,
                   max_poll_failures=30):
    """Submit *prompt* to the MidJourney relay and stream gallery updates.

    Async generator used as a Gradio event handler. Each yielded value is a
    list of ``(image_path, caption)`` tuples: preview tiles while the job is
    running, then the four tiles of the final image.

    Args:
        prompt: Prompt text, sent as ``{"msg": prompt}`` to ``fetapi``.
        progress: Gradio progress tracker (supplied by Gradio).
        max_retries: ``total`` retries for the urllib3 ``Retry`` policy mounted
            on the HTTP session (status 429/500/502/503/504).
        backoff_factor: urllib3 backoff factor for those retries.
        max_poll_failures: Give up after this many *consecutive* failed status
            polls; ``None`` retries forever.

    Raises:
        ValueError: the submit request fails or its reply lacks ``messageId``;
            polling fails ``max_poll_failures`` times in a row; the finished
            job has no ``response.imageUrl``; or the final image cannot be
            downloaded.
    """
    loop = asyncio.get_running_loop()
    iters = 1
    progress(iters / 32, desc="Sending request to MidJourney Server")

    retries = Retry(total=max_retries, backoff_factor=backoff_factor,
                    status_forcelist=[429, 500, 502, 503, 504])
    adapter = HTTPAdapter(max_retries=retries)

    # Using Session to reuse the underlying TCP connection. The ``with`` closes
    # its connection pool when the generator finishes, fails or is closed.
    with requests.Session() as session:
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        # All blocking HTTP/PIL work below runs via run_in_executor so the
        # event loop (shared by every Gradio request) is not stalled.

        def submit():
            """POST the prompt and return the job's ``messageId``."""
            response = session.post(
                fetapi,
                headers={'Content-Type': 'application/json'},
                data=json.dumps({'msg': prompt}),
                timeout=5.0,  # Here, the timeout duration is set to 5 seconds
            )
            response.raise_for_status()  # Check for HTTP errors.
            return response.json()['messageId']

        try:
            message_id = await loop.run_in_executor(None, submit)
        except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as e:
            logger.exception("Failed to make POST request")
            raise ValueError("Invalid Response") from e

        def poll():
            """Fetch the job's current status payload (must be a JSON object)."""
            response = session.get(
                f'{odnapi}/message/{message_id}?expireMins=2',
                headers={'Authorization': auth_token},
                timeout=5.0,
            )
            response.raise_for_status()
            payload = response.json()
            if not isinstance(payload, dict):
                raise ValueError(f"unexpected status payload: {payload!r}")
            return payload

        def fetch_image(url):
            """Download and open the image at *url*; return None on failure."""
            try:
                response = session.get(url, timeout=5.0)
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            except (requests.exceptions.RequestException, OSError) as e:
                # OSError covers PIL's UnidentifiedImageError for non-image bodies.
                logger.error("Failed to fetch image: %s", e)
                return None

        def download_and_split_image(url):
            """Return temp-file paths of the four quadrants of the image at *url*.

            Best-effort: returns an empty list (never None) on any failure, so
            callers can always iterate the result.
            """
            img = fetch_image(url)
            if img is None:
                return []
            try:
                return [save_image(tile) for tile in split_image(img)]
            except Exception:
                logger.exception("Failed to split/save image from %s", url)
                return []

        prog = 0
        iters += 5
        desc = "Waiting in the generate queue"
        progress(iters / 48, desc=desc)

        failures = 0
        while prog < 100:
            try:
                data = await loop.run_in_executor(None, poll)
            except (requests.exceptions.RequestException, ValueError) as e:
                failures += 1
                logger.warning("Failure in getting message response: %s", e)
                if max_poll_failures is not None and failures >= max_poll_failures:
                    raise ValueError("Invalid Response") from e
                # Back off before retrying instead of hammering a failing server.
                await asyncio.sleep(random.uniform(1, 2))
                continue
            failures = 0

            prog = data.get('progress', 0)
            if progress_image_url := data.get('progressImageUrl'):
                # Once previews arrive, keep iters negative so progress() below
                # (only called while iters > 0) stops updating the estimate.
                iters = -100
                previews = await loop.run_in_executor(
                    None, download_and_split_image, progress_image_url)
                if previews:  # keep the previous preview if this download failed
                    yield [(img, f"{prog}% done") for img in previews]

            await asyncio.sleep(random.uniform(1, 2))

            # Heuristic queue estimate; it is not reported by the server.
            r = iters / 48
            if r < 0.4:
                desc = "Waiting in the generate queue"
            elif r < 0.6:
                desc = "Still queueing"
            elif r < 0.8:
                desc = "Almost done"
            if iters > 0:
                # The estimate keeps growing while queued; never report >= 100%.
                progress(min(r, 0.99), desc=desc)
            iters += random.uniform(1, 2)

        # Process the final image urls
        try:
            image_url = data['response']['imageUrl']
        except (KeyError, TypeError) as e:
            logger.error("Finished job has no response.imageUrl: %r", data)
            raise ValueError("Invalid Response") from e

        final_tiles = await loop.run_in_executor(None, download_and_split_image, image_url)
        if not final_tiles:
            raise ValueError("Failed to download the generated image")
        yield [(img, f"image {idx+1}/4") for idx, img in enumerate(final_tiles)]


with gr.Blocks() as demo:
    gr.HTML('''

MidJourney / NijiJourney Playground 🎨

Demo for the MidJourney, draw with heart and love.

''')
    with gr.Column(variant="panel"):
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=3,
                container=False,
            )
            btn = gr.Button("Generate image", scale=0)

        gallery = gr.Gallery(label="Generated images", show_label=False,
                             elem_id="gallery", height="4096")

    # niji_api is an async generator: each yield replaces the gallery contents.
    btn.click(niji_api, text, gallery)

demo.queue(concurrency_count=2)
demo.launch(debug=True)