import base64
import io
import json
import os
import random
import time
from io import BytesIO

import socketio
import requests
import spaces
import PIL
from PIL import Image
import gradio as gr
from requests_toolbelt.multipart.encoder import MultipartEncoder

# BASE_URL and DEFAULT are used below but not defined in this file;
# presumably they come from this star import — TODO confirm.
from constant import *


@spaces.GPU
def foo():
    """No-op function decorated with ``@spaces.GPU``.

    NOTE(review): presumably present only so the Hugging Face ZeroGPU runtime
    sees at least one GPU-decorated function — confirm before removing.
    """


def login(email, password):
    """Log in to the backend and return ``(user_uuid, token)``.

    Args:
        email: Account e-mail; left out of the payload when falsy.
        password: Account password.

    Raises:
        ValueError: If the response body is not valid JSON (this covers
            ``json.JSONDecodeError`` and requests' own JSON decode error).
        Exception: If the response JSON carries a truthy ``error`` field.
        KeyError: If ``user_uuid`` or ``token`` is missing from the response.
    """
    payload = {'password': password}
    if email:
        payload['email'] = email
    response = requests.post(f"{BASE_URL}/user/login", json=payload)
    try:
        response_data = response.json()
    except ValueError:
        log("ERROR", f"Error in login: {response}")
        raise
    if response_data.get('error'):
        raise Exception(response_data['error'])
    log("INFO", "Logged successfully")
    return response_data['user_uuid'], response_data['token']


def rodin_history(task_uuid, token):
    """Fetch the history of a Rodin task and return the decoded JSON body.

    Raises:
        ValueError: If the response body is not valid JSON.
    """
    headers = {'Authorization': f'Bearer {token}'}
    response = requests.post(f"{BASE_URL}/task/rodin_history", data={"uuid": task_uuid}, headers=headers)
    return response.json()


def rodin_preprocess_image(generate_prompt, image, name, token):
    """Upload one image to the Rodin image-preprocessing endpoint.

    Args:
        generate_prompt: When truthy, ask the server to also generate a prompt.
        image: Image payload (bytes or a file-like object) sent as a multipart file.
        name: File name reported for the upload.
        token: Bearer token.

    Returns:
        The raw ``requests.Response``; callers decode it themselves.
    """
    encoder = MultipartEncoder(
        fields={
            'generate_prompt': "true" if generate_prompt else "false",
            # NOTE(review): load_image() produces PNG bytes, yet the part is
            # labelled image/jpeg — confirm the server ignores the declared type.
            'images': (name, image, 'image/jpeg'),
        }
    )
    headers = {
        'Content-Type': encoder.content_type,
        'Authorization': f'Bearer {token}',
    }
    return requests.post(f"{BASE_URL}/task/rodin_mesh_image_process", data=encoder, headers=headers)


def crop_image(image, type, source_width=11520, stride=720, tile_size=360):
    """Build a horizontal strip of square tiles cut out of ``image``.

    One ``tile_size`` x ``tile_size`` square is cropped every ``stride`` pixels
    across the first ``source_width`` pixels. The squares are pasted side by
    side into a new RGB image of size
    ``(tile_size * (source_width // stride), tile_size)``.
    The defaults reproduce the original hard-coded layout: crop every 720 px,
    each tile 360 px wide and 360 px tall.

    Args:
        image: Source ``PIL.Image``.
        type: Offset pair. ``type[0]`` is the top edge of every tile and
            ``type[1]`` is added to each tile's left edge. The name is kept
            for backward compatibility even though it shadows the builtin.
        source_width: Width of the region the tiles are taken from.
        stride: Horizontal distance between the starts of consecutive tiles.
        tile_size: Side length of each square tile.

    Raises:
        gr.Error: If ``image`` is None (nothing has been generated yet).
    """
    if image is None:
        raise gr.Error("Please generate the object first")
    y_offset, x_offset = type[0], type[1]
    tile_count = source_width // stride
    strip = Image.new('RGB', (tile_size * tile_count, tile_size))
    for i in range(tile_count):
        left = i * stride + x_offset
        tile = image.crop((left, y_offset, left + tile_size, y_offset + tile_size))
        strip.paste(tile, (i * tile_size, 0))
    return strip


def rodin_mesh(prompt, group_uuid, settings, images, name, token):
    """Submit a Rodin mesh-generation task.

    Args:
        prompt: Text prompt.
        group_uuid: Group identifier sent as-is (this file passes None).
        settings: Dict serialised to JSON for the ``settings`` field.
        images: Base64 strings or ``data:...;base64,`` URLs. Each is decoded
            with ``convert_base64_to_binary`` and sent as its own ``images`` part.
        name: File name reported for every uploaded image.
        token: Bearer token.

    Returns:
        The raw ``requests.Response``.
    """
    fields = [
        ('prompt', prompt),
        ('group_uuid', group_uuid),
        ('settings', json.dumps(settings)),
    ]
    # A list of tuples allows repeated ``images`` parts. The previous dict
    # comprehension used a constant key, so all but the last image were dropped.
    fields.extend(
        ('images', (name, convert_base64_to_binary(img), 'image/jpeg'))
        for img in images
    )
    encoder = MultipartEncoder(fields=fields)
    headers = {
        'Content-Type': encoder.content_type,
        'Authorization': f'Bearer {token}',
    }
    return requests.post(f"{BASE_URL}/task/rodin_mesh", data=encoder, headers=headers)


def convert_base64_to_binary(base64_string):
    """Decode a base64 string into an in-memory binary buffer.

    The result of ``rodin_preprocess_image`` is base64-encoded. A leading
    ``data:...;base64,`` prefix is stripped by keeping the text after the
    first comma.

    Returns:
        An ``io.BytesIO`` holding the decoded bytes.
    """
    if ',' in base64_string:
        base64_string = base64_string.split(',')[1]
    return io.BytesIO(base64.b64decode(base64_string))


def rodin_update(prompt, task_uuid, token, settings):
    """Request a regeneration of an existing Rodin task.

    Args:
        prompt: Text prompt for the update.
        task_uuid: UUID of the task to update.
        token: Bearer token.
        settings: Dict (JSON-encoded before sending) or an already-encoded string.

    Returns:
        The raw ``requests.Response``.
    """
    headers = {'Authorization': f'Bearer {token}'}
    if not isinstance(settings, str):
        # requests form-encodes an iterable value by iterating it, so a raw dict
        # would send only its keys. Serialise it to JSON as rodin_mesh does.
        settings = json.dumps(settings)
    return requests.post(
        f"{BASE_URL}/task/rodin_update",
        data={"uuid": task_uuid, "prompt": prompt, "settings": settings},
        headers=headers,
    )


def load_image(img_path, max_side=512):
    """Load an image, scale its longer side to ``max_side`` px, and return PNG bytes.

    The aspect ratio is preserved. Smaller images are scaled up as well.

    Raises:
        gr.Error: If PIL cannot identify the image format.
    """
    try:
        image = Image.open(img_path)
    except PIL.UnidentifiedImageError as err:
        raise gr.Error("Unsupported Image Format") from err
    with image:  # close the underlying file once the resized copy exists
        width, height = image.size
        scale = max_side / max(width, height)
        resized_image = image.resize((int(width * scale), int(height * scale)))
    # Convert the PIL.Image into a PNG byte string.
    byte_io = BytesIO()
    resized_image.save(byte_io, format='PNG')
    return byte_io.getvalue()


def log(level, info_text):
    """Print a timestamped log line tagged with ``level``."""
    print(f"[ {level} ] - {time.strftime('%Y%m%d_%H:%M:%S', time.localtime())} - {info_text}")


class Generator:
    """Stateful client that preprocesses images and generates meshes through Rodin."""

    def __init__(self, user_id, password, token) -> None:
        self.token = token            # bearer token; refreshed by preprocess() on HTTP 401
        self.user_id = user_id        # credentials kept so preprocess() can log in again
        self.password = password
        self.task_uuid = None         # not read anywhere in this file
        self.processed_image = None   # not read anywhere in this file

    def preprocess(self, prompt, image_path, processed_image, task_uuid=""):
        """Preprocess ``image_path`` on the server and return ``(prompt, processed_image)``.

        If a processed image and a prompt are already cached and no task is in
        progress, they are returned unchanged. When refining an existing task
        (both ``prompt`` and ``task_uuid`` set), the caller's prompt is kept.
        Otherwise the server is asked to generate one. On HTTP 401 the client
        logs in again and retries, up to 4 attempts in total.

        Returns:
            ``(prompt, "data:image/png;base64,...")``.

        Raises:
            gr.Error: On a missing image, an unsupported format, a server
                error, or exhausted retries.
        """
        if image_path is None:
            raise gr.Error("Please upload an image first")
        if processed_image and prompt and not task_uuid:
            log("INFO", "Using cached image and prompt...")
            return prompt, processed_image

        log("INFO", "Preprocessing image...")
        keep_prompt = bool(prompt and task_uuid)
        # Loop-invariant: the same bytes and name are uploaded on every attempt.
        image_file = load_image(image_path)
        image_name = os.path.basename(image_path)

        for _ in range(4):
            log("INFO", "Image loaded, processing...")
            res = None  # so the error log never refers to an unbound/stale response
            try:
                res = rodin_preprocess_image(
                    generate_prompt=not keep_prompt,
                    image=image_file,
                    name=image_name,
                    token=self.token,
                )
                preprocess_response = res.json()
                log("INFO", f"Image preprocessed: {preprocess_response.get('statusCode')}")
            except Exception as e:
                log("ERROR", f"Error in image preprocessing: {e} (response: {res})")
                raise gr.Error("Error in image preprocessing, please try again.") from e

            if 'error' in preprocess_response:
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Error in image preprocessing, please try again.")

            status = preprocess_response.get("statusCode")
            if status == 400:
                if "InvalidFile.Content" in (preprocess_response.get("message") or ""):
                    raise gr.Error("Unsupported Image Format")
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Busy connection, please try again later.")
            if status == 401:
                log("WARNING", "Token expired. Logging in again...")
                _, self.token = login(self.user_id, self.password)
                continue

            processed_data = preprocess_response.get('processed_image')
            if not isinstance(processed_data, str):
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Busy connection, please try again later.")
            if not keep_prompt:
                prompt = preprocess_response.get('prompt', None)
            log("INFO", "Image preprocessed successfully")
            return prompt, "data:image/png;base64," + processed_data

        raise gr.Error("Failed to preprocess image")

    def generate_mesh(self, prompt, processed_image, task_uuid=""):
        """Create (or regenerate) a mesh and return its preview.

        Without a ``task_uuid`` a new task is submitted. With one, the
        existing task is updated using a random seed. Either way this blocks
        on ``JobStatusChecker.start()`` until the status socket disconnects,
        then downloads the preview image from the latest history entry.

        Returns:
            ``(preview_image, task_uuid, crop_image(preview_image, DEFAULT))``.

        Raises:
            gr.Error: If submission, job tracking or history lookup fails.
            RuntimeError: If the preview download does not return HTTP 200.
        """
        log("INFO", "Generating mesh...")
        if not task_uuid:
            # One weight per image; with several images use e.g. [0.5, 0.5].
            settings = {'view_weights': [1]}
            # Every image here must already have been preprocessed.
            images = [processed_image]
            res = rodin_mesh(prompt=prompt, group_uuid=None, settings=settings,
                             images=images, name="images.jpeg", token=self.token)
            try:
                mesh_response = res.json()
                JobStatusChecker(BASE_URL, mesh_response['job']['subscription_key']).start()
                # The task uuid stays the same for the whole generation process.
                task_uuid = mesh_response['uuid']
            except Exception as e:
                log("ERROR", f"Error in generating mesh: {e} and response: {res}")
                raise gr.Error("Error in generating mesh, please try again later.") from e
        else:
            settings = {
                "view_weights": [1],
                "seed": random.randint(0, 10000),  # fresh seed for each regeneration
                "escore": 5.5,  # labelled "temperature" in the original comment
            }
            res = rodin_update(prompt, task_uuid, self.token, settings)
            try:
                update_response = res.json()
                JobStatusChecker(BASE_URL, update_response['job']['subscription_key']).start()
            except Exception as e:
                log("ERROR", f"Error in updating mesh: {e}")
                raise gr.Error("Error in generating mesh, please try again later.") from e

        history = None  # stays None in the log if rodin_history itself fails
        try:
            history = rodin_history(task_uuid, self.token)
            # Takes the last entry of the history mapping. Presumably the server
            # returns it in chronological order — TODO confirm.
            preview_url = next(reversed(history.items()))[1]["preview_image"]
        except Exception as e:
            log("ERROR", f"Error in generating mesh: {e} and history: {history}")
            raise gr.Error("Busy connection, please try again later.") from e

        # The context manager releases the connection on every path, including errors.
        with requests.get(preview_url) as response:
            if response.status_code != 200:
                log("ERROR", f"Error in generating mesh: {response}")
                raise RuntimeError(f"Failed to download preview image: HTTP {response.status_code}")
            image = Image.open(BytesIO(response.content))
            image.load()
        return image, task_uuid, crop_image(image, DEFAULT)


class JobStatusChecker:
    """Listen on the scheduler Socket.IO channel until a job reports success.

    ``start()`` blocks until the client disconnects. The message handler
    disconnects itself once it sees ``jobStatus == 'Succeeded'``.
    NOTE(review): no other status triggers a disconnect, so a failed job
    presumably blocks until the server closes the socket — confirm.
    """

    def __init__(self, base_url, subscription_key):
        self.base_url = base_url
        self.subscription_key = subscription_key
        self.sio = socketio.Client(logger=True, engineio_logger=True)

        @self.sio.event
        def connect():
            log("INFO", "Connected to the server.")

        @self.sio.event
        def disconnect():
            log("INFO", "Disconnected from server.")

        @self.sio.on('message', namespace='*')
        def message(*args, **kwargs):
            # NOTE(review): the job payload is read from the third positional
            # argument of this catch-all handler; confirm the argument layout
            # against the server's emit.
            if len(args) > 2:
                data = args[2]
                if isinstance(data, dict) and data.get('jobStatus') == 'Succeeded':
                    log("INFO", "Job Succeeded! Please find the SDF image in history")
                    self.sio.disconnect()
            else:
                log("WARNING", "Received event with insufficient arguments.")

    def start(self):
        """Connect to the scheduler socket and block until disconnected."""
        self.sio.connect(
            f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
            namespaces=['/api/scheduler_socket'],
            transports='websocket',
        )
        self.sio.wait()