# Scraped from a Hugging Face Spaces page (Space status at capture time: Paused)
import gradio as gr | |
import numpy as np | |
import time | |
import requests | |
import json | |
import os | |
import random | |
import tempfile | |
import logging | |
import threading | |
import asyncio | |
from PIL import Image | |
from io import BytesIO | |
from requests.adapters import HTTPAdapter | |
from urllib3.util import Retry | |
# Module-wide logging: records at INFO and above are emitted via the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service endpoints and credentials come from the environment; each name is
# None when the corresponding variable is not set.
odnapi = os.environ.get("odnapi_url")     # base URL polled at /message/<id>
fetapi = os.environ.get("fetapi_url")     # endpoint that receives the prompt POST
auth_token = os.environ.get("auth_token")  # sent as the Authorization header when polling
def split_image(img):
    """Return the four quadrant crops of *img*.

    Order is top-left, top-right, bottom-left, bottom-right. The split point
    is ``size // 2`` on each axis, so for odd dimensions the extra column/row
    ends up in the right/bottom pieces.
    """
    w, h = img.size
    mid_x, mid_y = w // 2, h // 2
    row_spans = ((0, mid_y), (mid_y, h))
    col_spans = ((0, mid_x), (mid_x, w))
    # Outer loop over rows, inner over columns -> TL, TR, BL, BR.
    return [
        img.crop((left, top, right, bottom))
        for top, bottom in row_spans
        for left, right in col_spans
    ]
def save_image(img, suffix='.png'):
    """Write *img* into a new temporary file and return the file's path.

    The file is created with ``delete=False`` so it survives after being
    closed. Its name ends with *suffix*, but the content is always encoded
    as PNG regardless of the suffix.
    """
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    with handle:
        img.save(handle, 'PNG')
    return handle.name
async def _run_blocking(func, *args, **kwargs):
    """Run blocking *func* in the default executor and await its result.

    ``requests`` and PIL/file I/O are synchronous; calling them directly in a
    coroutine would stall the event loop for every connected user.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


def _make_session(max_retries, backoff_factor):
    """Build a requests Session whose adapters retry transient HTTP failures."""
    session = requests.Session()  # reuses the underlying TCP connection
    retries = Retry(
        total=max_retries,
        backoff_factor=backoff_factor,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    adapter = HTTPAdapter(max_retries=retries)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session


def _fetch_image(session, url):
    """Download *url* and open it with PIL; return None (logged) on failure."""
    try:
        response = session.get(url, timeout=5.0)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except (requests.exceptions.RequestException, OSError) as exc:
        # OSError also covers PIL's UnidentifiedImageError for undecodable data.
        logger.error("Failed to fetch image %s: %s", url, exc)
        return None


def _download_and_split_image(session, url):
    """Fetch *url*, cut it into quadrants and save each to a temporary PNG.

    Returns the list of saved file paths, or an empty list when anything
    fails (the failure is logged) so callers can always iterate the result.
    """
    img = _fetch_image(session, url)
    if img is None:
        return []
    try:
        return [save_image(part) for part in split_image(img)]
    except Exception:
        logger.exception("Failed to split/save image from %s", url)
        return []


async def niji_api(prompt, progress=gr.Progress(), max_retries=5, backoff_factor=0.1,
                   max_poll_failures=20):
    """Submit *prompt* to the generation service and stream images to a Gallery.

    Async generator used as a Gradio event handler. It POSTs
    ``{"msg": prompt}`` to ``fetapi``, then polls
    ``{odnapi}/message/<messageId>?expireMins=2`` every 1-2 seconds until the
    reported ``progress`` reaches 100.

    Args:
        prompt: Text prompt sent to the service.
        progress: Gradio progress tracker (``gr.Progress`` default).
        max_retries, backoff_factor: urllib3 ``Retry`` settings applied to
            every request made through the session.
        max_poll_failures: Consecutive failed status polls tolerated before
            giving up (each poll already includes the session-level retries).

    Yields:
        Lists of ``(file_path, caption)`` tuples: the four quadrants of each
        intermediate ``progressImageUrl`` (captioned ``"<n>% done"``), and
        finally the four quadrants of ``response.imageUrl``
        (captioned ``"image k/4"``).

    Raises:
        ValueError: if submission fails or returns no ``messageId``, polling
            fails ``max_poll_failures`` times in a row, or the final image
            URL is missing / cannot be downloaded.
    """
    iters = 1
    progress(iters / 32, desc="Sending request to MidJourney Server")
    session = _make_session(max_retries, backoff_factor)
    try:
        try:
            response = await _run_blocking(
                session.post,
                fetapi,
                headers={'Content-Type': 'application/json'},
                data=json.dumps({'msg': prompt}),
                timeout=5.0,
            )
            response.raise_for_status()  # Check for HTTP errors.
            message_id = response.json()['messageId']
        except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as exc:
            logger.error("Failed to make POST request: %s", exc)
            raise ValueError("Invalid Response") from exc

        iters += 5
        progress(iters / 48, desc="Waiting in the generate queue")
        desc = "Waiting in the generate queue"
        prog = 0
        failures = 0
        data = {}
        while prog < 100:
            try:
                response = await _run_blocking(
                    session.get,
                    f'{odnapi}/message/{message_id}?expireMins=2',
                    headers={'Authorization': auth_token},
                    timeout=5.0,
                )
                response.raise_for_status()
                data = response.json()
            except (requests.exceptions.RequestException, ValueError) as exc:
                failures += 1
                logger.warning(
                    "Failure in getting message response (%d/%d): %s",
                    failures, max_poll_failures, exc,
                )
                if failures >= max_poll_failures:
                    raise ValueError("Invalid Response") from exc
                # Pause before re-polling instead of hammering the server.
                await asyncio.sleep(random.uniform(1, 2))
                continue
            failures = 0
            # `or 0` also covers an explicit null progress value.
            prog = data.get('progress') or 0

            if progress_image_url := data.get('progressImageUrl'):
                # A far-negative iters suppresses the progress() calls below
                # (they only fire while iters > 0) while previews are shown.
                iters = -100
                paths = await _run_blocking(_download_and_split_image, session, progress_image_url)
                if paths:
                    yield [(path, f"{prog}% done") for path in paths]

            await asyncio.sleep(random.uniform(1, 2))
            r = iters / 48
            if r < 0.4:
                desc = "Waiting in the generate queue"
            elif r < 0.6:
                desc = "Still queueing"
            elif r < 0.8:
                desc = "Almost done"
            if iters > 0:
                # iters keeps growing while queued; never report >= 100%.
                progress(min(r, 0.99), desc=desc)
            iters += random.uniform(1, 2)

        # Process the final image url
        image_url = (data.get('response') or {}).get('imageUrl')
        if not image_url:
            logger.error("Final response has no imageUrl: %s", data)
            raise ValueError("Invalid Response")
        paths = await _run_blocking(_download_and_split_image, session, image_url)
        if not paths:
            raise ValueError("Failed to download the generated image")
        yield [(path, f"image {idx + 1}/4") for idx, path in enumerate(paths)]
    finally:
        session.close()
# Gradio UI: a header banner, a prompt box with a generate button, and a
# gallery that receives every list niji_api yields.
with gr.Blocks() as demo:
    gr.HTML('''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
  <div style="
    display: inline-flex;
    gap: 0.8rem;
    font-size: 1.75rem;
    justify-content: center;
    margin-bottom: 10px;
  ">
    <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
      MidJourney / NijiJourney Playground 🎨
    </h1>
  </div>
  <div>
    <p style="align-items: center; margin-bottom: 7px;">
      Demo for the <a href="https://MidJourney.com/" target="_blank">MidJourney</a>, draw with heart and love.
    </p>
  </div>
</div>
''')
    with gr.Column(variant="panel"):
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=3,
                container=False,
            )
            btn = gr.Button("Generate image", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height="4096")
    btn.click(niji_api, text, gallery)

# At most two generation requests are processed concurrently.
demo.queue(concurrency_count=2)

# Only start the server when run as a script (`python app.py`), so importing
# this module (e.g. by Gradio's reload mode) does not launch a second server.
if __name__ == "__main__":
    demo.launch(debug=True)