"""Helpers for the Rodin task endpoints (login, image preprocessing, mesh
generation/update, history) and for waiting on scheduler jobs over Socket.IO."""

import base64
import io
import json
import random
import time
from io import BytesIO

import requests
import socketio
from PIL import Image
from requests_toolbelt.multipart.encoder import MultipartEncoder

# Project-local constants; BASE_URL and DEFAULT are used below.
from constant import *


def _auth_headers(token):
    """Return the bearer-token Authorization header sent to the task endpoints."""
    return {'Authorization': f'Bearer {token}'}


def _raise_on_error(response_data, context):
    """Raise RuntimeError when a parsed API response carries a truthy 'error' field."""
    error = response_data.get('error') if isinstance(response_data, dict) else None
    if error:
        raise RuntimeError(f"{context}: {error}")


def login(email, password):
    """Log in and return ``(user_uuid, token)``.

    Args:
        email: Optional; only added to the payload when truthy.
        password: Account password.

    Raises:
        Exception: if the response JSON contains a truthy ``error`` field.
    """
    payload = {'password': password}
    if email:
        payload['email'] = email
    response = requests.post(f"{BASE_URL}/user/login", json=payload)
    response_data = response.json()
    if response_data.get('error'):
        raise Exception(response_data['error'])
    print("Login successful")
    return response_data['user_uuid'], response_data['token']


def rodin_history(task_uuid, token):
    """Fetch the history of a task and return the parsed JSON response."""
    response = requests.post(
        f"{BASE_URL}/task/rodin_history",
        data={"uuid": task_uuid},
        headers=_auth_headers(token),
    )
    return response.json()


def rodin_preprocess_image(generate_prompt, image, name, token):
    """Upload one image to the preprocessing endpoint and return the parsed JSON.

    Args:
        generate_prompt: Sent as "true"/"false"; when true, Generator.preprocess
            reads a generated prompt back from the response's 'prompt' key.
        image: Image bytes (or a file-like object) to upload.
        name: Filename reported for the multipart part.
        token: Bearer token returned by ``login``.
    """
    # NOTE(review): load_image produces PNG bytes, but the part is labelled
    # image/jpeg here — confirm whether the server relies on this label.
    encoder = MultipartEncoder(
        fields={
            'generate_prompt': "true" if generate_prompt else "false",
            'images': (name, image, 'image/jpeg'),
        }
    )
    headers = {'Content-Type': encoder.content_type, **_auth_headers(token)}
    response = requests.post(
        f"{BASE_URL}/task/rodin_mesh_image_process", data=encoder, headers=headers
    )
    return response.json()


def crop_image(image, type, *, source_width=11520, stride=720, tile_size=360):
    """Cut square tiles out of ``image`` and lay them side by side in one strip.

    One ``tile_size`` x ``tile_size`` tile is taken every ``stride`` pixels
    across ``source_width``; with the defaults that is 16 tiles and a
    5760 x 360 RGB result.

    Args:
        image: PIL image to crop.
        type: ``(top, x_offset)`` pair; ``type[0]`` is the top edge of every
            tile and ``type[1]`` is added to each tile's left edge. (The name
            shadows the builtin but is kept for caller compatibility.)
        source_width, stride, tile_size: Layout of the source image.

    Returns:
        A new RGB ``PIL.Image``.
    """
    upper, x_offset = type[0], type[1]
    tile_count = source_width // stride
    strip = Image.new('RGB', (tile_size * tile_count, tile_size))
    for i in range(tile_count):
        left = i * stride + x_offset
        tile = image.crop((left, upper, left + tile_size, upper + tile_size))
        strip.paste(tile, (i * tile_size, 0))
    return strip


def rodin_mesh(prompt, group_uuid, settings, images, name, token):
    """Start a Rodin mesh task and return the parsed JSON response.

    Args:
        prompt: Text prompt for the task.
        group_uuid: Sent as-is (``None`` becomes an empty field).
        settings: Dict serialized to JSON for the 'settings' field.
        images: Base64 strings (optionally ``data:...;base64,`` URLs), e.g.
            the processed image returned by preprocessing.
        name: Filename used for every uploaded image part.
        token: Bearer token returned by ``login``.
    """
    image_buffers = [convert_base64_to_binary(img) for img in images]
    # A list of (field, value) tuples lets every image go out as its own
    # 'images' part; a dict keyed on 'images' would keep only the last one.
    fields = [
        ('prompt', prompt),
        ('group_uuid', group_uuid),
        ('settings', json.dumps(settings)),
    ]
    fields.extend(('images', (name, buffer, 'image/jpeg')) for buffer in image_buffers)
    encoder = MultipartEncoder(fields=fields)
    headers = {'Content-Type': encoder.content_type, **_auth_headers(token)}
    response = requests.post(f"{BASE_URL}/task/rodin_mesh", data=encoder, headers=headers)
    return response.json()


def convert_base64_to_binary(base64_string):
    """Decode a base64 string (or a ``data:`` URL) into an in-memory byte buffer.

    Needed because ``rodin_preprocess_image`` returns the processed image as
    base64. Anything up to the first comma (a data-URL header) is stripped.
    """
    if ',' in base64_string:
        base64_string = base64_string.split(',')[1]
    return io.BytesIO(base64.b64decode(base64_string))


def rodin_update(prompt, task_uuid, token, settings):
    """Re-run an existing task with a new prompt/settings and return the parsed JSON.

    ``settings`` may be a dict (serialized to JSON, matching ``rodin_mesh``)
    or an already-serialized string.
    """
    # Form-encoding a dict value directly would send only its keys, so
    # serialize it the same way rodin_mesh does.
    if not isinstance(settings, str):
        settings = json.dumps(settings)
    response = requests.post(
        f"{BASE_URL}/task/rodin_update",
        data={"uuid": task_uuid, "prompt": prompt, "settings": settings},
        headers=_auth_headers(token),
    )
    return response.json()


def load_image(img_path, *, max_side=1024):
    """Load an image, rescale it so its longer side is ``max_side`` px, return PNG bytes.

    The aspect ratio is preserved. Images smaller than ``max_side`` are
    scaled up, not just down.
    """
    with Image.open(img_path) as image:
        width, height = image.size
        scale = max_side / max(width, height)
        resized_image = image.resize((int(width * scale), int(height * scale)))
    # Serialize the resized PIL image to an in-memory PNG byte string.
    byte_io = BytesIO()
    resized_image.save(byte_io, format='PNG')
    return byte_io.getvalue()


class Generator:
    """Authenticated session that preprocesses images and generates/updates meshes."""

    def __init__(self, user_id, password) -> None:
        _, self.token = login(user_id, password)
        self.task_uuid = None

    def preprocess(self, prompt, image_path, cache_image_base64, task_uuid=""):
        """Preprocess an image and return ``(prompt, processed_image_data_url)``.

        If a cached image, a prompt and a ``task_uuid`` other than "" are all
        given, they are returned without contacting the server. When both
        ``prompt`` and ``task_uuid`` are truthy the caller's prompt is kept;
        otherwise the server is asked to generate one.

        Raises:
            RuntimeError: if preprocessing reports an error or returns no image.
        """
        if cache_image_base64 and prompt and task_uuid != "":
            return prompt, cache_image_base64

        print("Preprocessing image...")
        image_file = load_image(image_path)
        keep_prompt = bool(prompt and task_uuid)
        preprocess_response = rodin_preprocess_image(
            generate_prompt=not keep_prompt,
            image=image_file,
            name="images.png",
            token=self.token,
        )
        _raise_on_error(preprocess_response, "Error in image preprocessing")

        processed_image = preprocess_response.get('processed_image')
        if not processed_image:
            raise RuntimeError("Error in image preprocessing: no processed_image returned")
        if not keep_prompt:
            prompt = preprocess_response.get('prompt', 'Default prompt if none returned')
        return prompt, "data:image/png;base64," + processed_image

    def generate_mesh(self, prompt, processed_image, task_uuid=""):
        """Create (or re-run) a mesh task, wait for it, and fetch its preview.

        When ``task_uuid`` is "" a new task is created from ``processed_image``;
        otherwise the existing task is updated with ``prompt`` and a random seed.

        Returns:
            ``(preview_image, task_uuid, cropped_strip)`` where ``cropped_strip``
            is ``crop_image(preview_image, DEFAULT)``.

        Raises:
            RuntimeError: if the API reports an error or the preview cannot be
                downloaded.
        """
        print("Generating mesh...")
        if task_uuid == "":
            # For multiple images use one weight per image, e.g. [0.5, 0.5];
            # every image must be preprocessed first.
            settings = {'view_weights': [1]}
            mesh_response = rodin_mesh(
                prompt=prompt,
                group_uuid=None,
                settings=settings,
                images=[processed_image],
                name="images.jpeg",
                token=self.token,
            )
            _raise_on_error(mesh_response, "Error in generating mesh")
            self._wait_for_job(mesh_response['job']['subscription_key'], "Error in generating mesh")
            # The task uuid stays the same for the rest of the generation process.
            task_uuid = mesh_response['uuid']
        else:
            settings = {
                "view_weights": [1],
                "seed": random.randint(0, 10000),  # Customize the seed here
                "escore": 5.5,  # Temperature
            }
            update_response = rodin_update(prompt, task_uuid, self.token, settings)
            _raise_on_error(update_response, "Error in updating mesh")
            self._wait_for_job(update_response['job']['subscription_key'], "Error in updating mesh")

        image = self._fetch_latest_preview(task_uuid)
        return image, task_uuid, crop_image(image, DEFAULT)

    @staticmethod
    def _wait_for_job(subscription_key, failure_message):
        """Block on the scheduler socket for ``subscription_key``, then pause 5 s."""
        checker = JobStatusChecker(BASE_URL, subscription_key)
        try:
            checker.start()
        except Exception as e:
            # Best-effort: a socket failure is reported but not raised, so the
            # caller still goes on to query the task history.
            print(f"{failure_message}: {e}")
        # Presumably gives the backend time to publish results to history — TODO confirm.
        time.sleep(5)

    def _fetch_latest_preview(self, task_uuid):
        """Download and decode the preview image of the last history entry."""
        history = rodin_history(task_uuid, self.token)
        if not history:
            raise RuntimeError(f"No history returned for task {task_uuid}")
        # Takes the last key of the returned JSON object as the newest entry —
        # TODO confirm the API guarantees this ordering.
        latest_entry = next(reversed(history.items()))[1]
        preview_url = latest_entry["preview_image"]
        with requests.get(preview_url) as response:
            if response.status_code != 200:
                message = f"Can't get the preview image. Status code:{response.status_code}"
                print(message)
                raise RuntimeError(message)
            image = Image.open(BytesIO(response.content))
            # Decode eagerly so a corrupt download fails here, not in crop_image.
            image.load()
        return image


class JobStatusChecker:
    """Listen on the scheduler Socket.IO endpoint until a job reports 'Succeeded'."""

    def __init__(self, base_url, subscription_key):
        self.base_url = base_url
        self.subscription_key = subscription_key
        self.sio = socketio.Client(logger=True, engineio_logger=True)

        @self.sio.event
        def connect():
            print("Connected to the server.")

        @self.sio.event
        def disconnect():
            print("Disconnected from server.")

        @self.sio.on('message', namespace='*')
        def message(*args, **kwargs):
            # The job payload is read from the third positional argument.
            if len(args) > 2:
                data = args[2]
                if isinstance(data, dict) and data.get('jobStatus') == 'Succeeded':
                    print("Job Succeeded! Please find the SDF image in history")
                    self.sio.disconnect()
            else:
                print("Received event with insufficient arguments.")

    def start(self):
        """Connect to the scheduler and block until the client disconnects.

        NOTE(review): only a 'Succeeded' status triggers the disconnect; any
        other terminal status would presumably leave this blocked — confirm
        the status values with the API.
        """
        self.sio.connect(
            f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
            namespaces=['/api/scheduler_socket'],
            transports='websocket',
        )
        self.sio.wait()