File size: 5,044 Bytes
7f467ad
 
 
 
 
 
afda65a
7f467ad
 
e62c6f3
75ad521
7f467ad
 
acdf7f3
 
7f467ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acdf7f3
 
 
afbf111
 
 
 
 
 
1c6e281
afbf111
 
 
 
 
 
1c6e281
 
 
 
 
 
acdf7f3
 
 
 
 
7f467ad
afbf111
acdf7f3
 
 
6a9c613
acdf7f3
 
436f541
 
 
 
 
db94543
acdf7f3
 
 
afbf111
 
 
 
 
1c6e281
7f467ad
1c6e281
 
7f467ad
acdf7f3
 
 
436f541
acdf7f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c6e281
 
 
7f467ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec63780
7f467ad
 
 
 
 
 
 
 
1bbf5ca
7f467ad
 
 
 
 
 
962c36e
108a5ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import gradio as gr
import numpy as np
import time
import requests
import json
import os
import random
import tempfile
import logging
import threading
import asyncio
from PIL import Image
from io import BytesIO
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

# Backend endpoints and the auth token come from the environment; each name is
# None when its variable is unset.
odnapi, fetapi, auth_token = (
    os.getenv(var_name) for var_name in ("odnapi_url", "fetapi_url", "auth_token")
)

# Configure root logging at INFO and create this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def split_image(img):
    """Cut *img* into its four quadrants.

    The midpoints are found with floor division, so for odd dimensions the
    extra column/row ends up in the right/bottom quadrants.

    Returns:
        A list of four crops in reading order: top-left, top-right,
        bottom-left, bottom-right.
    """
    full_w, full_h = img.size
    x_edges = (0, full_w // 2, full_w)
    y_edges = (0, full_h // 2, full_h)
    # Row-major walk over the 2x2 grid keeps the reading order above.
    return [
        img.crop((x_edges[col], y_edges[row], x_edges[col + 1], y_edges[row + 1]))
        for row in range(2)
        for col in range(2)
    ]

def save_image(img, suffix='.png', image_format='PNG'):
    """Write *img* to a new temporary file and return the file's path.

    Args:
        img: Object with a ``save(fp, format)`` method (a PIL image in this app).
        suffix: Filename extension for the temporary file.
        image_format: Encoder name passed to ``img.save``. Defaults to 'PNG'
            (the format this function always used); pass a value that matches
            *suffix* when changing it so the extension agrees with the bytes.

    Returns:
        Path of the written file. It is created with ``delete=False`` so it
        outlives this call; nothing here removes it afterwards.

    Raises:
        Whatever ``img.save`` raises; the partially written file is deleted
        first.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        try:
            img.save(tmp, image_format)
        except Exception:
            # delete=False means nothing else would clean up a broken file.
            tmp.close()
            os.unlink(tmp.name)
            raise
    return tmp.name
    
async def niji_api(prompt, progress=gr.Progress(), max_retries=5, backoff_factor=0.1):
    """Submit *prompt* to the generation backend and stream image results.

    Async generator used as a Gradio event handler. It POSTs the prompt to
    ``fetapi``, then polls ``{odnapi}/message/{messageId}`` until the reported
    ``progress`` reaches 100. Whenever a poll response carries a
    ``progressImageUrl`` the preview is downloaded, split into four quadrants,
    saved to temp files and yielded; the final ``response.imageUrl`` is
    handled the same way.

    Args:
        prompt: Prompt text, sent as ``{"msg": prompt}``.
        progress: Gradio progress tracker (injected by Gradio).
        max_retries: urllib3 retry budget for each HTTP request.
        backoff_factor: urllib3 backoff factor between those retries.

    Yields:
        Lists of ``(file_path, caption)`` tuples for ``gr.Gallery``.

    Raises:
        ValueError: if the job cannot be submitted or its response is
            unusable, or if the final image cannot be downloaded.
    """
    # The bar is synthetic: a tick counter that advances by 1-2 per poll is
    # divided by this scale to get a 0..1 fraction.
    progress_scale = 48
    iters = 1
    progress(iters / progress_scale, desc="Sending request to MidJourney Server")

    loop = asyncio.get_running_loop()

    async def run_blocking(func, *args, **kwargs):
        # requests is synchronous; running it on the default executor keeps
        # the event loop (shared by every Gradio session) responsive.
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    retries = Retry(
        total=max_retries,
        backoff_factor=backoff_factor,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    # NOTE(review): urllib3's default Retry.allowed_methods excludes POST, so the
    # submit request is presumably not retried on these statuses — confirm.
    adapter = HTTPAdapter(max_retries=retries)

    # One Session per job so the TCP connection is reused across polls; the
    # context manager closes it when the generator finishes or is closed.
    with requests.Session() as session:
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        def fetch_image(url):
            """Download *url* and open it as an image; return None on failure."""
            try:
                response = session.get(url, timeout=5.0)
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            except (requests.exceptions.RequestException, OSError) as exc:
                # PIL signals undecodable data with an OSError subclass.
                logger.error("Failed to fetch image %s: %s", url, exc)
                return None

        def download_and_split_image(url):
            """Fetch *url*, split it into quadrants and save each to a temp file.

            Returns the list of saved paths, or an empty list if anything failed.
            """
            img = fetch_image(url)
            if img is None:
                return []
            try:
                return [save_image(part) for part in split_image(img)]
            except Exception:
                # Best effort: a broken image must not abort the whole job here;
                # callers decide whether an empty result is fatal.
                logger.exception("Failed to split/save image from %s", url)
                return []

        try:
            response = await run_blocking(
                session.post,
                fetapi,
                headers={'Content-Type': 'application/json'},
                data=json.dumps({'msg': prompt}),
                timeout=5.0,  # seconds
            )
            response.raise_for_status()
            message_id = response.json()['messageId']
        except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as exc:
            logger.error("Failed to make POST request: %s", exc)
            raise ValueError("Invalid Response") from exc

        iters += 5
        progress(iters / progress_scale, desc="Waiting in the generate queue")

        poll_url = f'{odnapi}/message/{message_id}?expireMins=2'
        prog = 0
        preview_shown = False
        # NOTE(review): there is no overall deadline; a job that never reports
        # progress >= 100 keeps this loop polling.
        while prog < 100:
            try:
                response = await run_blocking(
                    session.get,
                    poll_url,
                    headers={'Authorization': auth_token},
                    timeout=5.0,
                )
                response.raise_for_status()
                data = response.json()
            except (requests.exceptions.RequestException, ValueError) as exc:
                logger.warning("Failure in getting message response: %s", exc)
                # Wait before retrying instead of hammering the server.
                await asyncio.sleep(random.uniform(1, 2))
                continue

            prog = data.get('progress') or 0
            preview_url = data.get('progressImageUrl')
            if preview_url:
                paths = await run_blocking(download_and_split_image, preview_url)
                if paths:
                    preview_shown = True
                    yield [(path, f"{prog}% done") for path in paths]

            await asyncio.sleep(random.uniform(1, 2))

            # Once a real preview is on screen, stop driving the synthetic bar.
            if not preview_shown:
                ratio = iters / progress_scale
                if ratio < 0.4:
                    desc = "Waiting in the generate queue"
                elif ratio < 0.6:
                    desc = "Still queueing"
                else:
                    desc = "Almost done"
                progress(min(ratio, 0.99), desc=desc)
                iters += random.uniform(1, 2)

        # The loop only exits after a successful poll, so `data` is the final
        # response here.
        try:
            image_url = data['response']['imageUrl']
        except (KeyError, TypeError) as exc:
            logger.error("Final response has no image URL: %s", data)
            raise ValueError("Invalid Response") from exc

        paths = await run_blocking(download_and_split_image, image_url)
        if not paths:
            raise ValueError(f"Failed to download final image: {image_url}")
        yield [(path, f"image {idx+1}/4") for idx, path in enumerate(paths)]

# Gradio UI: a prompt textbox and a button that stream images from niji_api
# into a gallery.
with gr.Blocks() as demo:
    gr.HTML('''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
  <div style="
        display: inline-flex;
        gap: 0.8rem;
        font-size: 1.75rem;
        justify-content: center;
        margin-bottom: 10px;
      ">
    <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
      MidJourney / NijiJourney Playground 🎨
    </h1>
  </div>
  <div>
    <p style="align-items: center; margin-bottom: 7px;">
      Demo for the <a href="https://MidJourney.com/" target="_blank">MidJourney</a>, draw with heart and love.
    </p>
  </div>
</div>
    ''')
    with gr.Column(variant="panel"):
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=3,
                container=False,
            )
            btn = gr.Button("Generate image", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height="4096")
    # niji_api is an async generator, so each yielded list of (path, caption)
    # pairs replaces the gallery contents as the job progresses.
    btn.click(niji_api, text, gallery)

# NOTE(review): `concurrency_count` is the Gradio 3.x queue argument (Gradio 4
# replaced it with `default_concurrency_limit`) — confirm the pinned version.
demo.queue(concurrency_count=2)
demo.launch(debug=True)