Spaces:
Paused
Paused
File size: 5,044 Bytes
7f467ad afda65a 7f467ad e62c6f3 75ad521 7f467ad acdf7f3 7f467ad acdf7f3 afbf111 1c6e281 afbf111 1c6e281 acdf7f3 7f467ad afbf111 acdf7f3 6a9c613 acdf7f3 436f541 db94543 acdf7f3 afbf111 1c6e281 7f467ad 1c6e281 7f467ad acdf7f3 436f541 acdf7f3 1c6e281 7f467ad ec63780 7f467ad 1bbf5ca 7f467ad 962c36e 108a5ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import gradio as gr
import numpy as np
import time
import requests
import json
import os
import random
import tempfile
import logging
import threading
import asyncio
from PIL import Image
from io import BytesIO
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
# Module-wide logger; the root logger is configured at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service endpoints and the auth token are read from the environment.
# Each value is None when its variable is not set.
odnapi = os.environ.get("odnapi_url")
fetapi = os.environ.get("fetapi_url")
auth_token = os.environ.get("auth_token")
def split_image(img):
    """Cut *img* into four quadrants.

    The cut lines sit at half the width and half the height (floor
    division), so with odd dimensions the right and bottom tiles get the
    extra pixel.

    Args:
        img: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method (e.g. a PIL image).

    Returns:
        A list of four crops ordered top-left, top-right, bottom-left,
        bottom-right.
    """
    full_w, full_h = img.size
    mid_x, mid_y = full_w // 2, full_h // 2
    boxes = (
        (0, 0, mid_x, mid_y),
        (mid_x, 0, full_w, mid_y),
        (0, mid_y, mid_x, full_h),
        (mid_x, mid_y, full_w, full_h),
    )
    return [img.crop(box) for box in boxes]
def save_image(img, suffix='.png'):
    """Write *img* to a new temporary file and return the file's path.

    The file is created with ``delete=False``, so it stays on disk after it
    is closed. The image is always written in PNG format; *suffix* only
    affects the file name.

    Args:
        img: Object with a ``save(fp, format)`` method (e.g. a PIL image).
        suffix: File-name suffix for the temporary file.

    Returns:
        Path of the temporary file as a string.
    """
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    with handle:
        img.save(handle, 'PNG')
    return handle.name
async def niji_api(prompt, progress=gr.Progress(), max_retries=5, backoff_factor=0.1):
    """Submit *prompt* to the image service and stream results to a gallery.

    This is an async generator. While the job is running it yields lists of
    ``(file_path, caption)`` tuples built from the four quadrants of the
    latest progress image; once the reported progress reaches 100 it yields
    the quadrants of the final image.

    The submit URL, polling base URL and authorization header are taken
    from the module-level ``fetapi``, ``odnapi`` and ``auth_token``.

    Args:
        prompt: Text sent as the ``msg`` field of the submit request.
        progress: Gradio progress tracker updated while the job is queued.
        max_retries: Total number of retries per HTTP request.
        backoff_factor: Backoff factor handed to urllib3's ``Retry``.

    Raises:
        ValueError: If the submit request fails or returns no message id,
            if the final response carries no image URL, or if the final
            image cannot be downloaded.
    """
    iters = 1
    progress(iters / 32, desc="Sending request to MidJourney Server")

    loop = asyncio.get_running_loop()

    async def in_thread(func, *args, **kwargs):
        # requests and PIL are blocking; run them in the default executor so
        # the event loop stays free for other queued jobs.
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    retries = Retry(
        total=max_retries,
        backoff_factor=backoff_factor,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    adapter = HTTPAdapter(max_retries=retries)

    # One Session reuses the underlying TCP connection; the context manager
    # closes it even if the generator is abandoned part-way through.
    with requests.Session() as session:
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        def fetch_image(url):
            """Download *url* and open it as an image, or return None."""
            try:
                response = session.get(url, timeout=5.0)
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            except (requests.exceptions.RequestException, OSError):
                logger.exception("Failed to fetch image")
                return None

        def download_and_split_image(url):
            """Return temp-file paths of the four quadrants, or [] on failure."""
            img = fetch_image(url)
            if img is None:
                return []
            try:
                return [save_image(tile) for tile in split_image(img)]
            except Exception:
                # Best effort: a broken preview must not abort the whole job.
                logger.exception("Failed to split or save image")
                return []

        try:
            response = await in_thread(
                session.post,
                fetapi,
                headers={'Content-Type': 'application/json'},
                data=json.dumps({'msg': prompt}),
                timeout=5.0,
            )
            response.raise_for_status()  # Check for HTTP errors.
            message_id = response.json()['messageId']
        except (requests.exceptions.RequestException, ValueError, KeyError) as e:
            logger.exception("Failed to make POST request")
            raise ValueError("Invalid Response") from e

        prog = 0
        iters += 5
        progress(iters / 48, desc="Waiting in the generate queue")

        data = {}
        while prog < 100:
            try:
                response = await in_thread(
                    session.get,
                    f'{odnapi}/message/{message_id}?expireMins=2',
                    headers={'Authorization': auth_token},
                    timeout=5.0,
                )
                response.raise_for_status()
                data = response.json()
            except (requests.exceptions.RequestException, ValueError):
                logger.warning("Failure in getting message response", exc_info=True)
                # Pause before polling again instead of retrying immediately.
                await asyncio.sleep(random.uniform(1, 2))
                continue

            prog = data.get('progress', 0)
            if progress_image_url := data.get('progressImageUrl'):
                # A negative counter suppresses the queue progress bar below
                # (it only updates while iters > 0).
                iters = -100
                tiles = await in_thread(download_and_split_image, progress_image_url)
                if tiles:
                    yield [(img, f"{prog}% done") for img in tiles]

            await asyncio.sleep(random.uniform(1, 2))

            r = iters / 48
            if r < 0.4:
                desc = "Waiting in the generate queue"
            elif r < 0.6:
                desc = "Still queueing"
            else:
                # Covers every r >= 0.6, so desc is always bound.
                desc = "Almost done"
            if iters > 0:
                # Clamp: the counter keeps growing while the job stays queued.
                progress(min(r, 1.0), desc=desc)
            iters += random.uniform(1, 2)

        # Process the final image urls
        image_url = (data.get('response') or {}).get('imageUrl')
        if not image_url:
            raise ValueError("Final image URL missing from response")
        tiles = await in_thread(download_and_split_image, image_url)
        if not tiles:
            raise ValueError("Failed to download final image")
        yield [(img, f"image {idx+1}/4") for idx, img in enumerate(tiles)]
# Build the UI: a header, a prompt box with a button, and a gallery that
# receives the image lists yielded by niji_api.
with gr.Blocks() as demo:
    gr.HTML('''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div style="
display: inline-flex;
gap: 0.8rem;
font-size: 1.75rem;
justify-content: center;
margin-bottom: 10px;
">
<h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
MidJourney / NijiJourney Playground 🎨
</h1>
</div>
<div>
<p style="align-items: center; margin-bottom: 7px;">
Demo for the <a href="https://MidJourney.com/" target="_blank">MidJourney</a>, draw with heart and love.
</p>
</div>
</div>
''')
    with gr.Column(variant="panel"):
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=3,
                container=False,
            )
            btn = gr.Button("Generate image", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height="4096")
    # Clicking the button runs niji_api on the prompt text and streams each
    # yielded list into the gallery.
    btn.click(niji_api, text, gallery)

# Queueing is enabled before launch; the app then starts in debug mode.
demo.queue(concurrency_count=2)
demo.launch(debug=True)