skkk committed
Commit b51ed7d
1 Parent(s): 9b1e7f2

Revert "test 600m"

This reverts commit 893330bad43179b8a93788b03a11e16e5e39f1ed.
Rodin.py CHANGED
@@ -13,27 +13,30 @@ import gradio as gr
 from requests_toolbelt.multipart.encoder import MultipartEncoder
 from constant import *
 
+@spaces.GPU
+def foo():
+    pass
+
 def login(email, password):
     payload = {'password': password}
     if email:
         payload['email'] = email
-
+
     response = requests.post(f"{BASE_URL}/user/login", json=payload)
     try:
        response_data = response.json()
     except json.JSONDecodeError as e:
         log("ERROR", f"Error in login: {response}")
         raise e
-
+
     if 'error' in response_data and response_data['error']:
         raise Exception(response_data['error'])
     log("INFO", f"Logged successfully")
     user_uuid = response_data['user_uuid']
     token = response_data['token']
-
+
     return user_uuid, token
 
-
 def rodin_history(task_uuid, token):
     headers = {
         'Authorization': f'Bearer {token}'
@@ -41,7 +44,6 @@ def rodin_history(task_uuid, token):
     response = requests.post(f"{BASE_URL}/task/rodin_history", data={"uuid": task_uuid}, headers=headers)
     return response.json()
 
-
 def rodin_preprocess_image(generate_prompt, image, name, token):
     m = MultipartEncoder(
         fields={
@@ -56,7 +58,6 @@ def rodin_preprocess_image(generate_prompt, image, name, token):
     response = requests.post(f"{BASE_URL}/task/rodin_mesh_image_process", data=m, headers=headers)
     return response
 
-
 def crop_image(image, type):
     if image == None:
         raise gr.Error("Please generate the object first")
@@ -77,7 +78,7 @@ def crop_image(image, type):
 # Perform Rodin mesh operation
 def rodin_mesh(prompt, group_uuid, settings, images, name, token):
     images = [convert_base64_to_binary(img) for img in images]
-
+
     m = MultipartEncoder(
         fields={
             'prompt': prompt,
@@ -99,13 +100,12 @@ def rodin_mesh(prompt, group_uuid, settings, images, name, token):
 def convert_base64_to_binary(base64_string):
     if ',' in base64_string:
         base64_string = base64_string.split(',')[1]
-
+
     image_data = base64.b64decode(base64_string)
     image_buffer = io.BytesIO(image_data)
-
+
     return image_buffer
 
-
 def rodin_update(prompt, task_uuid, token, settings):
     headers = {
         'Authorization': f'Bearer {token}'
@@ -113,7 +113,6 @@ def rodin_update(prompt, task_uuid, token, settings):
     response = requests.post(f"{BASE_URL}/task/rodin_update", data={"uuid": task_uuid, "prompt": prompt, "settings": settings}, headers=headers)
     return response
 
-
 def load_image(img_path):
     try:
         image = Image.open(img_path)
@@ -136,11 +135,9 @@ def load_image(img_path):
     image_bytes = byte_io.getvalue()
     return image_bytes
 
-
 def log(level, info_text):
     print(f"[ {level} ] - {time.strftime('%Y%m%d_%H:%M:%S', time.localtime())} - {info_text}")
 
-
 class Generator:
     def __init__(self, user_id, password, token) -> None:
         # _, self.token = login(user_id, password)
@@ -149,11 +146,11 @@ class Generator:
         self.password = password
         self.task_uuid = None
         self.processed_image = None
-
-    def preprocess(self, prompt, image_path, processed_image, task_uuid=""):
-        if image_path is None:
+
+    def preprocess(self, prompt, image_path, processed_image , task_uuid=""):
+        if image_path == None:
             raise gr.Error("Please upload an image first")
-
+
         if processed_image and prompt and (not task_uuid):
             log("INFO", "Using cached image and prompt...")
             return prompt, processed_image
@@ -163,10 +160,10 @@ class Generator:
         while not success:
             if try_times > 3:
                 raise gr.Error("Failed to preprocess image")
-            try_times += 1
+            try_times += 1
             image_file = load_image(image_path)
             log("INFO", "Image loaded, processing...")
-
+
             try:
                 if prompt and task_uuid:
                     res = rodin_preprocess_image(generate_prompt=False, image=image_file, name=os.path.basename(image_path), token=self.token)
@@ -203,13 +200,13 @@ class Generator:
 
         log("INFO", "Image preprocessed successfully")
         return prompt, processed_image
-
+
     def generate_mesh(self, prompt, processed_image, task_uuid=""):
         log("INFO", "Generating mesh...")
         if task_uuid == "":
             settings = {'view_weights': [1]}  # Define weights as per your requirements; for multiple images use multiple values, e.g. [0.5, 0.5]
             images = [processed_image]  # List of images; all the images should be processed first
-
+
             res = rodin_mesh(prompt=prompt, group_uuid=None, settings=settings, images=images, name="images.jpeg", token=self.token)
             try:
                 mesh_response = res.json()
@@ -218,14 +215,14 @@ class Generator:
             except Exception as e:
                 log("ERROR", f"Error in generating mesh: {e} and response: {res}")
                 raise gr.Error("Error in generating mesh, please try again later.")
-
-            task_uuid = mesh_response['uuid']  # The task_uuid should stay the same during the whole generation process
+
+            task_uuid = mesh_response['uuid']  # The task_uuid should stay the same during the whole generation process
         else:
             new_prompt = prompt
             settings = {
                 "view_weights": [1],
-                "seed": random.randint(0, 10000),  # Customize your seed here
-                "escore": 5.5,  # Temperature
+                "seed": random.randint(0, 10000),  # Customize your seed here
+                "escore": 5.5,  # Temperature
             }
             res = rodin_update(new_prompt, task_uuid, self.token, settings)
             try:
@@ -243,7 +240,7 @@ class Generator:
             except Exception as e:
                 log("ERROR", f"Error in generating mesh: {history}")
                 raise gr.Error("Busy connection, please try again later.")
-
+
             response = requests.get(preview_image, stream=True)
             if response.status_code == 200:
                 image = Image.open(response.raw)
@@ -260,32 +257,25 @@ class JobStatusChecker:
         self.subscription_key = subscription_key
         self.sio = socketio.Client(logger=True, engineio_logger=True)
 
-        @self.sio.on('connect', namespace='*')
-        def connect(*args):
-            print("[ JobStatusChecker.connect ] Connected to the server.")
+        @self.sio.event
+        def connect():
+            print("Connected to the server.")
 
-        @self.sio.on('disconnect', namespace='*')
-        def disconnect(*args):
-            print("[ JobStatusChecker.disconnect ] Disconnected from server.")
+        @self.sio.event
+        def disconnect():
+            print("Disconnected from server.")
 
         @self.sio.on('message', namespace='*')
         def message(*args, **kwargs):
-            print(f"""[ JobStatusChecker.message ] args = {args}""")
-            safe_to_disconnect = False
             if len(args) > 2:
                 data = args[2]
                 if data.get('jobStatus') == 'Succeeded':
-                    safe_to_disconnect = True
-                if args[1] == "SAFE_TO_DISCONNECT":
-                    safe_to_disconnect = True
-
-                if safe_to_disconnect:
-                    print("[ JobStatusChecker.message ] Job Succeeded! Please find the SDF image in history")
-                    self.sio.disconnect()
+                    print("Job Succeeded! Please find the SDF image in history")
+                    self.sio.disconnect()
             else:
-                print(f"[ JobStatusChecker.message ] Received event with insufficient arguments. {args}")
+                print("Received event with insufficient arguments.")
 
     def start(self):
-        self.sio.connect(f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
+        self.sio.connect(f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
                          namespaces=['/api/scheduler_socket'], transports='websocket')
-        self.sio.wait()
+        self.sio.wait()
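
Note on the restored JobStatusChecker: the revert swaps the namespace-wildcard connect/disconnect handlers back to python-socketio's plain @sio.event decorators. A minimal standalone sketch of that decorator pattern follows; the URL is a placeholder, not the real Rodin endpoint, and only the calls that appear in the code above are used:

import socketio

sio = socketio.Client()

@sio.event
def connect():
    # fires once the connection handshake completes
    print("Connected to the server.")

@sio.event
def disconnect():
    print("Disconnected from server.")

@sio.on('message', namespace='*')  # catch-all namespace, as in JobStatusChecker
def message(*args):
    print("message event:", args)

# hypothetical endpoint; JobStatusChecker builds this from base_url + subscription key
sio.connect("https://example.com/scheduler_socket?subscription=KEY",
            namespaces=['/api/scheduler_socket'], transports='websocket')
sio.wait()  # block until disconnect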
app.py CHANGED
@@ -1,21 +1,8 @@
 import os
-os.system('pip3 uninstall -y gradio_fake3d')
-os.system('pip3 install gradio_fake3d-0.0.3-py3-none-any.whl')
-os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-os.environ["TORCH_LINEAR_FLATTEN_3D"] = "1"
-
-import cv2
-import time
-import numpy as np
-import torch
-
-from openclay.pipeline_openclay import OpenClayPipeline
-from openclay.models import ClayVAE, ClayLDM, ClayConditionNet
-from openclay.utils import process_image_square
-from transformers import Dinov2Model, BitImageProcessor, CLIPTextModel, CLIPTokenizer
+os.system('pip uninstall -y gradio_fake3d')
+os.system('pip install gradio_fake3d-0.0.3-py3-none-any.whl')
 
 import gradio as gr
-import spaces
 import re
 from gradio_fake3d import Fake3D
 from PIL import Image
@@ -23,57 +10,6 @@ from Rodin import Generator, crop_image, log, convert_base64_to_binary
 from constant import *
 
 generator = Generator(USER, PASSWORD, TOKEN)
-os.makedirs(FOLDER_TEMP_MESH, exist_ok=True)
-
-device = torch.device("cuda")
-
-tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-
-image_processor = BitImageProcessor.from_pretrained("facebook/dinov2-giant", torch_dtype=torch.float16)
-image_encoder = Dinov2Model.from_pretrained("facebook/dinov2-giant", torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-
-vae = ClayVAE.from_pretrained("DEEMOSTECH/CLAYV1_VAE", token=ACCESS_TOKEN, torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-ldm = ClayLDM.from_pretrained("DEEMOSTECH/CLAYV1_LDM_MEDIUM", token=ACCESS_TOKEN, torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-condition_net_image = ClayConditionNet.from_pretrained("DEEMOSTECH/CLAYV1_LDM_MEDIUM_CONDITION_IMAGE", token=ACCESS_TOKEN, torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-ldm.register_condition_net([condition_net_image])
-
-pipe = OpenClayPipeline(
-    vae=vae,
-    text_encoder=text_encoder,
-    tokenizer=tokenizer,
-    ldm=ldm,
-)
-
-
-@spaces.GPU
-def read_image(image, image_processor, image_encoder, size=224):
-    render = process_image_square(image)
-    render = cv2.resize(render, (size, size))
-    image_pixel_values = image_processor(render, return_tensors="pt", do_rescale=False,
-                                         do_resize=False, do_center_crop=False)["pixel_values"][0]
-    image_embeds_patch = image_encoder(image_pixel_values.half().to(image_encoder.device)[None])["last_hidden_state"]
-    return render, image_embeds_patch
-
-@spaces.GPU
-def local_inference(block_prompt, image_pil):
-    image = np.array(image_pil)
-    _, image_embeds_patch = read_image(image, image_processor, image_encoder)
-    mesh = pipe(
-        prompt=block_prompt, negative_prompt='fragmentation.',
-        res=256,
-        num_inference_steps=50,
-        mini_batch=65**3,
-        seed=42, num=1,
-        condition_seq=[(image_embeds_patch, [1, 0])],
-    )
-
-    mesh_path = f"{FOLDER_TEMP_MESH}/{int(time.time())}.glb"
-    os.makedirs(os.path.dirname(mesh_path), exist_ok=True)
-    mesh.export(mesh_path)
-
-    return mesh_path
-
 
 change_button_name = """
 function updateButton(input) {
@@ -83,14 +19,6 @@ function updateButton(input) {
 }
 """
 
-change_button_name_600 = """
-function updateButton(input) {
-    var buttonGenerate = document.getElementById('button_generate_600');
-    buttonGenerate.innerText = 'Redo';
-    return '';
-}
-"""
-
 change_button_name_to_generating = """
 function updateButton(input) {
     var buttonGenerate = document.getElementById('button_generate');
@@ -99,15 +27,6 @@ function updateButton(input) {
 }
 """
 
-change_button_name_to_generating_600 = """
-function updateButton(input) {
-    var buttonGenerate = document.getElementById('button_generate_600');
-    buttonGenerate.innerText = 'Generating...';
-    return '';
-}
-"""
-
-
 reset_button_name = """
 function updateButton(input) {
     var buttonGenerate = document.getElementById('button_generate');
@@ -116,15 +35,6 @@ function updateButton(input) {
 }
 """
 
-reset_button_name_600 = """
-function updateButton(input) {
-    var buttonGenerate = document.getElementById('button_generate_600');
-    buttonGenerate.innerText = 'Generate';
-    return '';
-}
-"""
-
-
 jump_to_rodin = """
 function redirectToGithub(input) {
     if (input.includes('OpenClay')) {
@@ -188,23 +98,18 @@ example = [
     ["assets/46.png"]
 ]
 
-
 def do_nothing(text):
     return ""
 
-
 def handle_selection(selection):
     return "Rodin Gen-1(0525)"
 
-
 def hint_in_prompt(hint, prompt):
     return re.search(fr"{hint[:-1]}", prompt) is not None
 
-
 def prompt_remove_hint(prompt, hint):
     return re.sub(fr"\s*{hint[:-1]}[\.,]*", "", prompt)
 
-
 def handle_hint_change(prompt: str, prompt_hint):
     prompt = prompt.strip()
     if prompt != "" and not prompt.endswith("."):
@@ -218,7 +123,6 @@ def handle_hint_change(prompt: str, prompt_hint):
     prompt = prompt.strip()
     return prompt
 
-
 def handle_prompt_change(prompt):
     hint_list = []
     for _, hint in PROMPT_HINT_LIST:
@@ -227,15 +131,6 @@ def handle_prompt_change(prompt):
 
     return hint_list
 
-def preprocessing(prompt, image_path, processed_image, task_uuid=""):
-    prompt, image_base64 = generator.preprocess(prompt, image_path, processed_image, task_uuid)
-    image_rgb = convert_base64_to_binary(image_base64)
-    image_rgb = cv2.imdecode(np.frombuffer(image_rgb.getvalue(), np.uint8), -1)[...,[2,1,0,3]]
-    image_rgb = Image.fromarray(image_rgb, 'RGBA')
-    # image_rgb = cv2.resize(image_rgb, (256, 256), cv2.INTER_AREA)
-
-    return prompt, image_base64, image_rgb
-
 def clear_task(task_input=None):
     """_summary_
     [cache_task_uuid, block_prompt, block_prompt_hint, fake3d, block_3d]
@@ -243,33 +138,26 @@ def clear_task(task_input=None):
     log("INFO", "Clearing task...")
     return "", "", "", [], "assets/white_image.png"
 
-
 def clear_task_id():
     return ""
 
-
 def return_render(image):
     image = Image.fromarray(image)
     return image, crop_image(image, DEFAULT)
 
-
 def crop_image_default(image):
     return crop_image(image, DEFAULT)
 
-
 def crop_image_metal(image):
     return crop_image(image, METAL)
 
-
 def crop_image_contrast(image):
     return crop_image(image, CONTRAST)
 
-
 def crop_image_normal(image):
     return crop_image(image, NORMAL)
 
-
-with gr.Blocks(css=css) as demo:
+with gr.Blocks() as demo:
     gr.HTML(html_content)
 
     cache_task_uuid = gr.Text(value="", visible=False)
@@ -279,86 +167,66 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Row():
         with gr.Column():
+            block_image = gr.Image(height=256, image_mode="RGB", sources="upload", elem_classes="elem_imageupload", type="filepath")
+            block_model_card = gr.Dropdown(choices=options, label="Model Card", value="Rodin Gen-1(0525)", interactive=True)
             with gr.Group():
-                with gr.Row():
-                    block_image = gr.Image(
-                        label='Input',
-                        height=max_height,
-                        image_mode="RGBA",
-                        sources="upload",
-                        elem_id="elem_block_image",
-                        elem_classes="elem_imagebox",
-                        type="filepath"
-                    )
-                    block_image_masked = gr.Image(
-                        label='Preprocessed',
-                        height=max_height,
-                        elem_id="elem_block_image_crop",
-                        elem_classes="elem_imagebox",
-                        interactive=False,
-                    )
-
-
-                block_prompt = gr.Textbox(
-                    value="",
-                    placeholder="Auto generated description of Image",
-                    lines=1,
-                    show_label=True,
-                    label="Prompt",
-                )
-
-                block_prompt_hint = gr.CheckboxGroup(value="Labels", choices=PROMPT_HINT_LIST, show_label=False)
-
-        with gr.Column(elem_id="right_col"):
-            with gr.Group(elem_id="right_col_group"):
-                with gr.Row(elem_id="right_col_group_row"):
-                    with gr.Group(elem_id="right_col_group_row_gleft"):
-                        block_3d = gr.Model3D(
-                            value='./empty.obj',
-                            height=320,
-                            camera_position=(90 + 30, 90 - 15, 3),
-                            zoom_speed=0.2,
-                            pan_speed=0.3,
-                            label="3D Preview (OpenCLAY(600M))",
-                            elem_id="block_3d"
-                        )
-
-                        button_generate_600 = gr.Button(value="Generate", variant="primary", elem_id="button_generate_600")
-
-                    with gr.Group(elem_id="right_col_group_row_gright"):
-                        fake3d = Fake3D(interactive=False,
-                                        # height=320,
-                                        # width=320,
-                                        label="3D Preview (Rodin Gen-1(0525))",
-                                        elem_id="fake3d"
-                                        )
-
-        with gr.Row():
-            button_generate = gr.Button(value="Generate", variant="primary", elem_id="button_generate")
-            button_more = gr.Button(value="Download", variant="primary", link=rodin_url)
+                block_prompt = gr.Textbox(
+                    value="",
+                    placeholder="Auto generated description of Image",
+                    lines=1,
+                    show_label=True,
+                    label="Prompt",
+                )
+                block_prompt_hint = gr.CheckboxGroup(value="Labels", choices=PROMPT_HINT_LIST, show_label=False)
 
+        with gr.Column():
+            with gr.Group():
+                fake3d = Fake3D(interactive=False, label="3D Preview")
+                with gr.Row():
+                    button_generate = gr.Button(value="Generate", variant="primary", elem_id="button_generate")
 
-    block_example = gr.Examples(
-        examples=example,
-        fn=clear_task,
-        inputs=[block_image],
-        outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint],
-        run_on_click=True,
-        cache_examples=True,
-        label="Examples"
-    )
-
+            with gr.Column(min_width=200, scale=20):
+                with gr.Row():
+                    block_default = gr.Button("Default", min_width=0)
+                    block_metal = gr.Button("Metal", min_width=0)
+                with gr.Row():
+                    block_contrast = gr.Button("Contrast", min_width=0)
+                    block_normal = gr.Button("Normal", min_width=0)
+
+            button_more = gr.Button(value="Download from Rodin", variant="primary", link=rodin_url)
+            gr.Markdown("""
+            **TIPS**:
+            1. Upload an image to generate 3D geometry.
+            2. Click Redo to regenerate the model.
+            3. 4 buttons to switch the view.
+            4. Swipe to rotate the model.
+            """)
+            cache_task_uuid = gr.Text(value="", visible=False)
+
+
+    cache_raw_image = gr.Image(visible=False, type="pil")
+    cacha_empty = gr.Text(visible=False)
+    cache_image_base64 = gr.Text(visible=False)
+    block_example = gr.Examples(
+        examples=example,
+        fn=clear_task,
+        inputs=[block_image],
+        outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint, fake3d],
+        run_on_click=True,
+        cache_examples=True,
+        label="Examples"
+    )
 
     block_image.upload(
-        fn=do_nothing,
-        js=change_button_name_to_generating,
+        fn=do_nothing,
+        js=change_button_name_to_generating,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
     ).success(
-        fn=preprocessing,
-        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
-        outputs=[block_prompt, cache_image_base64, block_image_masked],
+        fn=generator.preprocess,
+        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
+        outputs=[block_prompt, cache_image_base64],
         show_progress="minimal",
         queue=True
     ).success(
@@ -367,36 +235,36 @@ with gr.Blocks(css=css) as demo:
         outputs=[cache_raw_image, cache_task_uuid, fake3d],
         queue=True
     ).success(
-        fn=do_nothing,
-        js=change_button_name,
+        fn=do_nothing,
+        js=change_button_name,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
     )
-
+
     block_image.clear(
-        fn=do_nothing,
-        js=reset_button_name,
+        fn=do_nothing,
+        js=reset_button_name,
        inputs=[cacha_empty],
        outputs=[cacha_empty],
        queue=False
     ).then(
-        fn=clear_task,
-        outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint, fake3d],
+        fn=clear_task,
+        outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint, fake3d],
         show_progress="hidden",
         queue=False
     )
-
+
     button_generate.click(
-        fn=do_nothing,
-        js=change_button_name_to_generating,
+        fn=do_nothing,
+        js=change_button_name_to_generating,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
     ).success(
-        fn=preprocessing,
-        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
-        outputs=[block_prompt, cache_image_base64, block_image_masked],
+        fn=generator.preprocess,
+        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
+        outputs=[block_prompt, cache_image_base64],
         show_progress="minimal",
         queue=True
     ).success(
@@ -405,46 +273,26 @@ with gr.Blocks(css=css) as demo:
         outputs=[cache_raw_image, cache_task_uuid, fake3d],
         queue=True
     ).then(
-        fn=do_nothing,
-        js=change_button_name,
-        inputs=[cacha_empty],
-        outputs=[cacha_empty],
-        queue=False
-    )
-
-    button_generate_600.click(
-        fn=do_nothing,
-        js=change_button_name_to_generating_600,
-        inputs=[cacha_empty],
-        outputs=[cacha_empty],
-        queue=False
-    ).success(
-        fn=preprocessing,
-        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
-        outputs=[block_prompt, cache_image_base64, block_image_masked],
-        show_progress="minimal",
-        queue=True
-    ).success(
-        fn=local_inference,
-        inputs=[block_prompt, block_image_masked],
-        outputs=[block_3d],
-        queue=True
-    ).then(
-        fn=do_nothing,
-        js=change_button_name_600,
+        fn=do_nothing,
+        js=change_button_name,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
     )
-
+
+    block_default.click(fn=crop_image_default, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
+    block_metal.click(fn=crop_image_metal, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
+    block_contrast.click(fn=crop_image_contrast, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
+    block_normal.click(fn=crop_image_normal, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
+
     button_more.click()
-
+
     block_prompt_hint.input(
         fn=handle_hint_change, inputs=[block_prompt, block_prompt_hint], outputs=[block_prompt],
         show_progress="hidden",
         queue=False,
     )
-
+
     block_prompt.change(
         fn=handle_prompt_change,
         inputs=[block_prompt],
@@ -452,6 +300,8 @@ with gr.Blocks(css=css) as demo:
         trigger_mode="always_last",
         show_progress="hidden",
     )
+
+    block_model_card.change(fn=handle_selection, inputs=[block_model_card], outputs=[block_model_card], show_progress="hidden", js=jump_to_rodin)
 
 
 if __name__ == "__main__":
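
The event wiring restored above relies on Gradio's listener chaining: .click()/.upload() return a listener object whose .success() step runs only if the previous function raised no error (e.g. no gr.Error), while .then() runs unconditionally. A minimal self-contained sketch of the pattern, with hypothetical functions rather than the app's own:

import gradio as gr

def preprocess(text):
    if not text:
        raise gr.Error("Please type something first")  # aborts the .success() chain
    return text.strip()

def generate(text):
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Generate", variant="primary")

    # .success() only fires when the previous step completed without error;
    # .then() would fire regardless of the outcome.
    btn.click(fn=preprocess, inputs=[inp], outputs=[inp], queue=False
    ).success(fn=generate, inputs=[inp], outputs=[out], queue=True)

if __name__ == "__main__":
    demo.launch()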
constant.py CHANGED
@@ -8,9 +8,6 @@ USER = os.getenv("USER")
 PASSWORD = os.getenv("PASSWORD")
 TOKEN = os.getenv("TOKEN")
 
-ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
-FOLDER_TEMP_MESH = './tmp_mesh'
-
 DEFAULT = [0, 0]
 CONTRAST = [360, 0]
 METAL = [0, 360]
openclay/models/__init__.py DELETED
@@ -1,3 +0,0 @@
-from .condition import ClayConditionNet
-from .ldm import ClayLDM
-from .vae import ClayVAE
openclay/models/condition.py DELETED
@@ -1,102 +0,0 @@
-import copy
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from diffusers import ModelMixin, ConfigMixin
-
-from diffusers.configuration_utils import register_to_config
-from ..modules.embedding import PointEmbed
-from ..modules.control_volume import ControlVolume
-from ..utils import get_center_position
-
-class ClayConditionNet(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        kv_dim,
-        ldm_dim, ldm_heads, ldm_depth,
-        on_volume=False, volume_input_dim=1, volume_block_dim=256, aggregation_method=None,
-        on_point=False, point_number=8, point_leading_embed=False, additional_token_length=None,
-        stage="parallel",
-    ):
-        super().__init__()
-        self.multihead_attn_condition_list = nn.ModuleList([copy.deepcopy(
-            nn.MultiheadAttention(ldm_dim, ldm_heads, dropout=0, batch_first=True, kdim=kv_dim, vdim=kv_dim)
-        ) for i in range(ldm_depth)])
-
-        if on_volume:
-            self.condition_volume_point_embed = PointEmbed(dim=kv_dim)
-            self.condition_volume_conv = ControlVolume(volume_dim=volume_input_dim,
-                                                       block_dim=volume_block_dim,
-                                                       condition_dim=kv_dim,
-                                                       time_embed_dim=ldm_dim,
-                                                       downsample_times=1,
-                                                       aggregation_method=aggregation_method)
-
-        if on_point:
-            self.condition_point_point_embed = PointEmbed(dim=kv_dim)
-            if point_number == 0:
-                self.condition_point_token = None
-            else:
-                self.condition_point_token = nn.Parameter(torch.randn(point_number, kv_dim))
-
-        self.on_volume = on_volume
-        self.on_point = on_point
-
-        assert stage in {"parallel", "postfix"}
-        self.stage = stage
-
-    def preprocess_condition(self, condition, additional_dict, time_embed) -> torch.Tensor:
-
-        if self.on_volume:
-            condition_volume = condition
-            Bor1, volume_dim, X, Y, Z = condition_volume.shape
-            assert X == Y == Z == 16
-            condition_volume = self.condition_volume_conv(condition_volume, time_embed)  # [Bor1, condition_dim, X, Y, Z]
-            condition_volume = condition_volume.reshape(-1, condition_volume.shape[1], 8**3).permute(0, 2, 1)  # [Bor1, X*Y*Z, condition_dim]
-
-            center_position = get_center_position(8)[None].to(time_embed).reshape(1, 8**3, 3)  # [1, X*Y*Z, 3]
-            condition_volume = condition_volume + self.condition_volume_point_embed(center_position)  # [Bor1, X*Y*Z, condition_dim]
-            condition = condition_volume
-
-        if self.on_point:
-            point = condition
-            Bor1, M, _ = point.shape
-
-            condition_point = self.condition_point_point_embed(point)
-            if self.condition_point_token is not None:
-                if self.config.point_leading_embed:
-                    condition_point = torch.cat([
-                        condition_point[:, :self.condition_point_token.shape[0]] + self.condition_point_token,
-                        condition_point[:, self.condition_point_token.shape[0]:]
-                    ], dim=1)
-                else:
-                    condition_point = condition_point + self.condition_point_token * additional_dict.get("condition_point_token_scale", 1)
-
-            condition = condition_point
-
-        return condition
-
-    def process(self, index, x, condition, key_padding_mask=None):
-        """
-        x: [B, N, dim]
-        condition: [B, M, condition_dim]
-        """
-
-        if self.stage == "parallel":
-            residual = self.multihead_attn_condition_list[index](x,
-                condition, condition, need_weights=False,
-                key_padding_mask=key_padding_mask
-            )[0]
-        else:
-            x = x + self.multihead_attn_condition_list[index](self.norm2_condition_list[index](x),
-                condition, condition, need_weights=False,
-                key_padding_mask=key_padding_mask
-            )[0]
-
-            residual = self.linear2_condition_list[index](F.gelu(self.linear1_condition_list[index]
-                (self.norm3_condition_list[index](x)))
-            )
-
-        return residual
openclay/models/ldm.py DELETED
@@ -1,83 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from diffusers import ModelMixin, ConfigMixin
-from diffusers.configuration_utils import register_to_config
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps
-
-from ..modules.transformer import ClayTransformerDecoderLayer
-
-class ClayLDM(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        depth=24,
-        dim=512,
-        latent_dim=64,
-        heads=8,
-    ):
-        super().__init__()
-
-        timestep_input_dim = dim // 2
-        time_embed_dim = dim
-        self.time_proj = Timesteps(timestep_input_dim, True, 0)
-        self.time_embedding = TimestepEmbedding(
-            timestep_input_dim,
-            time_embed_dim,
-            act_fn="silu",
-            post_act_fn="silu",
-        )
-        self.time_embed_dim = time_embed_dim
-
-        self.proj_in = nn.Linear(latent_dim, dim, bias=False)
-
-        self.layers = nn.TransformerDecoder(
-            ClayTransformerDecoderLayer(dim, heads, dim_feedforward=dim * 4, dropout=0, activation=F.gelu, batch_first=True, norm_first=True, kdim=768, vdim=768, layer_norm_eps=1e-4),
-            depth
-        )
-        for i in range(depth):
-            self.layers.layers[i].index = i
-
-        self.proj_out = nn.Linear(dim, latent_dim, bias=False)
-
-    def register_condition_net(self, condition_net_list):
-        self.condition_net_list = nn.ModuleList(condition_net_list)
-        for layer in self.layers.layers:
-            layer.register_condition_net(condition_net_list)
-
-    def forward(self, sample, t, condition_text, condition_seq=None):
-        """
-        sample: [B, L, C]
-        t: [N]
-        condition: [B, 77, 768]
-        condition_seq: [(condition_1, condition_1_scale), ...]
-        return: [B, L, C]
-        """
-
-        B, L, C = sample.shape
-
-        x = self.proj_in(sample)
-        time_encoding = self.time_proj(t).to(sample)
-        time_embed = self.time_embedding(time_encoding)
-
-        condition_text = condition_text.expand(B, -1, -1)
-
-        if condition_seq is None:
-            condition_seq = []
-        condition_seq = list(condition_seq)
-        for i in range(len(condition_seq)):
-            condition, condition_scale = condition_seq[i][:2]
-            additional_dict = {}
-            if len(condition_seq[i]) > 2:
-                additional_dict = condition_seq[i][2]
-                assert isinstance(additional_dict, dict)
-            condition = self.condition_net_list[i].preprocess_condition(condition, additional_dict, time_embed)
-            condition = condition.expand(B, -1, -1)
-
-            condition_seq[i] = (condition, condition_scale, additional_dict)
-
-        x_aug = torch.cat([x, time_embed[:, None]], dim=1)
-        y = self.layers(x_aug, (condition_text, condition_seq))
-        eps = self.proj_out(y[:, :-1])
-
-        return eps
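
One detail worth noting in the deleted ClayLDM.forward: rather than modulating each layer with the timestep, it appends the timestep embedding as one extra sequence token (x_aug = torch.cat([x, time_embed[:, None]], dim=1)) and strips it again before proj_out. A small sketch of just that mechanism, with made-up shapes:

import torch

B, L, D = 2, 16, 512                    # batch, latent tokens, model width (illustrative)
x = torch.randn(B, L, D)                # projected latent sequence
time_embed = torch.randn(B, D)          # per-sample timestep embedding

x_aug = torch.cat([x, time_embed[:, None]], dim=1)  # [B, L+1, D]: time as one extra token
# ... transformer layers attend over all L+1 tokens ...
eps = x_aug[:, :-1]                     # drop the time token before the output projection
assert eps.shape == (B, L, D)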
openclay/models/vae.py DELETED
@@ -1,124 +0,0 @@
-import tqdm
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from diffusers import ModelMixin, ConfigMixin
-from diffusers.configuration_utils import register_to_config
-# from torch_cluster import fps
-
-from ..modules.embedding import PointEmbed
-from ..modules.attention import CrossAttentionLayer
-from ..modules.drop_path import DropPathWrapper
-from ..modules.diag_gaussian import DiagonalGaussianDistribution
-
-class ClayVAE(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        depth=24,
-        dim=512,
-        latent_dim=64,
-        heads=8,
-        output_dim=1,
-        ratio=0.25,
-    ):
-        super().__init__()
-
-        self.ratio = ratio
-
-        self.point_embed = PointEmbed(dim=dim)
-
-        self.encoder_cross_attention = CrossAttentionLayer(dim, 1, dim_feedforward=dim * 4, dropout=0, activation=F.gelu, batch_first=True, norm_first=True, layer_norm_eps=1e-4)
-        self.encoder_out = nn.Linear(dim, latent_dim * 2)
-
-        self.decoder_in = nn.Linear(latent_dim, dim)
-        self.decoder_layers = nn.TransformerEncoder(
-            DropPathWrapper(
-                nn.TransformerEncoderLayer(dim,
-                                           heads,
-                                           dim_feedforward=dim * 4,
-                                           dropout=0,
-                                           activation=F.gelu,
-                                           batch_first=True,
-                                           norm_first=True,
-                                           layer_norm_eps=1e-4)
-            ),
-            depth)
-
-        self.decoder_cross_attention = CrossAttentionLayer(dim, 1, dim_feedforward=dim * 4, dropout=0, activation=F.gelu, batch_first=True, norm_first=True, layer_norm_eps=1e-4)
-        self.decoder_out = nn.Linear(dim, output_dim)
-
-    def encode(self, pc_data, output_dict=None, attn_mask=None, no_cast=False, **kwargs):
-
-        pc = pc_data[:, :, :3]
-
-        # pc: B x N x 3
-        B, N, D = pc.shape
-
-        pc_flat = pc.reshape(B * N, D)
-
-        batch = torch.arange(B).to(pc.device)
-        batch = torch.repeat_interleave(batch, N)
-
-        ratio = self.ratio
-        idx = fps(pc_flat, batch, ratio=ratio)
-
-        while idx.max() >= pc_flat.shape[0]:
-            idx = fps(pc_flat, batch, ratio=ratio)
-
-        sampled_pc = pc_flat[idx].reshape(B, -1, 3)
-
-        pc_embeddings = self.point_embed(pc)
-        sampled_pc_embeddings = self.point_embed(sampled_pc)
-
-        x, attn_output_weights = self.encoder_cross_attention(sampled_pc_embeddings, pc_embeddings, attn_mask, need_weights=output_dict is not None, no_cast=no_cast)
-        mean, logvar = self.encoder_out(x).chunk(2, dim=-1)
-
-        posterior = DiagonalGaussianDistribution(mean, logvar)
-        x = posterior.sample()
-        kl = posterior.kl()
-
-        if output_dict is not None:
-            output_dict["fps_idx"] = idx
-            output_dict["mean"] = mean
-            output_dict["logvar"] = logvar
-            output_dict["x"] = x
-            output_dict["attn_output_weights"] = attn_output_weights
-
-        return kl, x
-
-    def decode(self, x, pc, mini_batch=None, no_cast=False, show_progress=True, cpu=True, **kwargs):
-
-        x = self.decode_first(x)
-
-        if mini_batch is None:
-            y = self.decode_second(x, pc, no_cast)
-
-        else:
-            ys = []
-            for mini_batch_start in tqdm.tqdm(range(0, pc.shape[1], mini_batch), "[ ClayVAE.decode ]", disable=not (show_progress and pc.shape[1] > mini_batch)):
-                mini_pc = pc[:, mini_batch_start:mini_batch_start + mini_batch].to(x.device)
-                y = self.decode_second(x, mini_pc, no_cast)
-                if cpu:
-                    y = y.cpu()
-                ys.append(y)
-            y = torch.cat(ys, dim=1)
-
-        return y
-
-    def decode_first(self, x):
-        x = self.decoder_in(x)
-        x = self.decoder_layers(x)
-        return x
-
-    def decode_second(self, x, mini_pc, no_cast=False):
-        pc_embeddings = self.point_embed(mini_pc)
-        y, _ = self.decoder_cross_attention(pc_embeddings, x, no_cast=no_cast)
-        y = self.decoder_out(y)
-        return y
-
-    def forward(self, surface, points, no_cast=False, **kwargs):
-        kl, x = self.encode(surface, no_cast=no_cast, **kwargs)
-        x = self.decode(x, points, no_cast=no_cast, **kwargs)[:, :, 0]
-        return {"logits": x, "kl": kl}
openclay/modules/attention.py DELETED
@@ -1,73 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional, Union, Callable
-
-
-class CrossAttentionLayer(nn.Module):
-    __constants__ = ["batch_first", "norm_first", "context_norm"]
-
-    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
-                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
-                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, context_norm=True,
-                 device=None, dtype=None) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super(CrossAttentionLayer, self).__init__()
-        self.multihead_attn = nn.modules.activation.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
-        # Implementation of Feedforward model
-        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
-        self.dropout = nn.Dropout(dropout)
-        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
-
-        self.norm_first = norm_first
-        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-
-        self.dropout2 = nn.Dropout(dropout)
-        self.dropout3 = nn.Dropout(dropout)
-
-        self.context_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) if context_norm else None
-        # Legacy string support for activation function.
-        self.activation = activation
-
-    def forward(self, tgt: Tensor, memory: Tensor, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, need_weights=False, no_cast=False) -> Tensor:
-
-        assert self.norm_first
-
-        if no_cast:
-            tgt = tgt.float()
-            memory = memory.float()
-
-            with torch.autocast("cuda", enabled=False):
-                x = tgt
-                memory = self.context_norm(memory) if self.context_norm is not None else memory
-                y, attn_output_weights = self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, need_weights=need_weights)
-                x = x + y
-                x = x + self._ff_block(self.norm3(x))
-
-                return x, attn_output_weights
-
-        x = tgt
-        memory = self.context_norm(memory) if self.context_norm is not None else memory
-        y, attn_output_weights = self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, need_weights=need_weights)
-        x = x + y
-        x = x + self._ff_block(self.norm3(x))
-
-        return x, attn_output_weights
-
-    # multihead attention block
-    def _mha_block(self, x: Tensor, mem: Tensor,
-                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], need_weights) -> Tensor:
-        x, attn_output_weights = self.multihead_attn(x, mem, mem,
-                                                     attn_mask=attn_mask,
-                                                     key_padding_mask=key_padding_mask,
-                                                     need_weights=need_weights)
-        return self.dropout2(x), attn_output_weights
-
-    # feed forward block
-    def _ff_block(self, x: Tensor) -> Tensor:
-        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
-        return self.dropout3(x)
openclay/modules/control_volume.py DELETED
@@ -1,52 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class ControlVolume(nn.Module):
-    def __init__(self, volume_dim=1, block_dim=384, condition_dim=768, time_embed_dim=2048, downsample_times=1, aggregation_method=None):
-        super().__init__()
-
-        if aggregation_method is None:
-            pass
-        elif aggregation_method == "maxpool":
-            self.proj_maxpool = nn.Linear(volume_dim, volume_dim)
-
-        self.conv_in = nn.Conv3d(volume_dim, block_dim, kernel_size=3, stride=1, padding=1)
-        self.proj_t = nn.Linear(time_embed_dim, block_dim)
-        self.norm = nn.GroupNorm(8, block_dim)
-
-        self.blocks = nn.ModuleList([])
-        for i in range(downsample_times):
-            self.blocks.append(nn.Conv3d(block_dim, block_dim, kernel_size=3, padding=1))
-            self.blocks.append(nn.Conv3d(block_dim, block_dim, kernel_size=3, padding=1, stride=2))
-
-        self.conv_out = nn.Conv3d(block_dim, condition_dim, kernel_size=3, stride=1, padding=1)
-
-        self.volume_dim = volume_dim
-        self.aggregation_method = aggregation_method
-
-    def forward(self, volume, time_embed):
-        """
-        volume: [B, volume_dim, X, Y, Z]
-        time_embed: [B, block_dim]
-        return:
-            [B, condition_dim, X, Y, Z]
-        """
-        B, _, X, Y, Z = volume.shape
-        if self.aggregation_method == "maxpool":
-            volume = self.proj_maxpool(volume.reshape(B, 4, self.volume_dim, X, Y, Z).permute(0, 1, 3, 4, 5, 2)).permute(0, 5, 2, 3, 4, 1)
-            volume = F.max_pool1d(volume.reshape(B, self.volume_dim * X * Y * Z, 4), 4).reshape(B, self.volume_dim, X, Y, Z)
-
-        x = F.silu(self.conv_in(volume))  # [B, block_dim, X, Y, Z]
-
-        time_embed = self.proj_t(time_embed)  # [B, block_dim]
-        x = x + time_embed[:, :, None, None, None]
-
-        x = self.norm(x)
-
-        for block in self.blocks:
-            x = block(x)
-            x = F.silu(x)
-
-        x = self.conv_out(x)  # [B, condition_dim, X, Y, Z]
-        return x
openclay/modules/diag_gaussian.py DELETED
@@ -1,42 +0,0 @@
-import numpy as np
-import torch
-
-class DiagonalGaussianDistribution(object):
-    def __init__(self, mean, logvar, deterministic=False):
-        self.mean = mean
-        self.logvar = logvar
-        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
-        self.deterministic = deterministic
-        self.std = torch.exp(0.5 * self.logvar)
-        self.var = torch.exp(self.logvar)
-        if self.deterministic:
-            self.var = self.std = torch.zeros_like(self.mean).to(device=self.mean.device)
-
-    def sample(self):
-        x = self.mean + self.std * torch.randn(self.mean.shape).to(self.mean)
-        return x
-
-    def kl(self, other=None):
-        if self.deterministic:
-            return torch.Tensor([0.])
-        else:
-            if other is None:
-                return 0.5 * torch.mean(torch.pow(self.mean, 2)
-                                        + self.var - 1.0 - self.logvar,
-                                        dim=[1, 2])
-            else:
-                return 0.5 * torch.mean(
-                    torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
-
-    def nll(self, sample, dims=[1, 2, 3]):
-        if self.deterministic:
-            return torch.Tensor([0.])
-        logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
-
-    def mode(self):
-        return self.mean
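
For reference, the kl() branch with other=None above is the closed form for a diagonal Gaussian measured against a standard normal, applied elementwise with \sigma^2 = \exp(\mathrm{logvar}):

D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big)
  = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - 1 - \log \sigma^2\right)

The one deviation from the textbook form is the reduction: the code averages over the latent dimensions (torch.mean over dims [1, 2]) instead of summing, which rescales the KL term by a constant factor.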
openclay/modules/drop_path.py DELETED
@@ -1,34 +0,0 @@
-import torch
-import torch.nn as nn
-from torch import Tensor
-from timm.models.layers import DropPath
-from typing import Optional
-
-class Struct:
-    pass
-
-class DropPathWrapper(nn.Module):
-    def __init__(self, layer):
-        super().__init__()
-        self.layer = layer
-        self.drop_path = DropPath(drop_prob=0.1)
-
-        self_attn_dummy = Struct()
-        self_attn_dummy.batch_first = True
-        self.self_attn = self_attn_dummy
-
-    def forward(
-        self,
-        src: Tensor,
-        src_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-        is_causal: bool = False
-    ) -> Tensor:
-
-        x = src
-        x_p = self.layer(src, src_mask, src_key_padding_mask, is_causal)
-        p = x_p - x
-
-        y = x + self.drop_path(p)
-
-        return y
openclay/modules/embedding.py DELETED
@@ -1,36 +0,0 @@
-# TODO: add reference to 3dshape2vecset
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-class PointEmbed(nn.Module):
-    def __init__(self, hidden_dim=48, dim=128):
-        super().__init__()
-
-        assert hidden_dim % 6 == 0
-
-        self.embedding_dim = hidden_dim
-        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
-        e = torch.stack([
-            torch.cat([e, torch.zeros(self.embedding_dim // 6),
-                       torch.zeros(self.embedding_dim // 6)]),
-            torch.cat([torch.zeros(self.embedding_dim // 6), e,
-                       torch.zeros(self.embedding_dim // 6)]),
-            torch.cat([torch.zeros(self.embedding_dim // 6),
-                       torch.zeros(self.embedding_dim // 6), e]),
-        ])
-        self.register_buffer("basis", e, persistent=False)  # 3 x 24
-
-        self.mlp = nn.Linear(self.embedding_dim + 3, dim)
-
-    @staticmethod
-    def embed(input, basis):
-        projections = torch.einsum("bnd,de->bne", input, basis)
-        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
-        return embeddings
-
-    def forward(self, input):
-        # input: B x N x 3
-        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2))  # B x N x C
-        return embed
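
The deleted PointEmbed is a Fourier positional encoding over xyz coordinates (per the TODO, in the style of 3DShape2VecSet). With h = hidden_dim (48 by default), each axis gets frequencies 2^k \pi for k = 0, \dots, h/6 - 1, arranged block-diagonally in a basis matrix B \in \mathbb{R}^{3 \times h/2}; the raw point is concatenated back in before a linear map:

\gamma(p) = W\,\big[\,\sin(pB)\;\|\;\cos(pB)\;\|\;p\,\big],
\qquad W \in \mathbb{R}^{\mathrm{dim} \times (h + 3)}

which matches the (h + 3) \to \mathrm{dim} shape of the module's nn.Linear.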
openclay/modules/transformer.py DELETED
@@ -1,116 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from torch.nn.modules.transformer import _get_activation_fn
- from typing import Optional, Union, Callable, List
-
- class ClayTransformerDecoderLayer(nn.TransformerDecoderLayer):
-     __constants__ = ["batch_first", "norm_first"]
-
-     def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
-                  activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
-                  layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
-                  device=None, dtype=None, kdim=None, vdim=None) -> None:
-         factory_kwargs = {"device": device, "dtype": dtype}
-         nn.Module.__init__(self)
-         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
-         self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, kdim=kdim, vdim=vdim, **factory_kwargs)
-         # Implementation of Feedforward model
-         self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
-         self.dropout = lambda x: x
-         self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
-
-         self.norm_first = norm_first
-         self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-         self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-         self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-         self.dropout1 = lambda x: x
-         self.dropout2 = lambda x: x
-         self.dropout3 = lambda x: x
-
-         # Legacy string support for activation function.
-         if isinstance(activation, str):
-             self.activation = _get_activation_fn(activation)
-         else:
-             self.activation = activation
-
-         self.d_model = d_model
-         self.nhead = nhead
-         self.dropout_ = dropout
-         self.batch_first = batch_first
-         self.kdim = kdim
-         self.vdim = vdim
-         self.factory_kwargs = factory_kwargs
-         self._mha_block_second = None
-         self.condition_net_list = []
-
-     def register_condition_net(self, condition_net_list):
-         self.condition_net_list = condition_net_list
-
-     def process_condition_net(self, x_norm, condition_seq, stage):
-         """
-         condition_seq: [(condition_1, condition_1_scale), ...]
-         """
-
-         assert len(condition_seq) == 0 or len(condition_seq) == len(self.condition_net_list), f"len(self.condition_net_list)={len(self.condition_net_list)}, len(condition_seq)={len(condition_seq)}"
-
-         residual_all = 0
-
-         for i in range(len(condition_seq)):
-             condition_net = self.condition_net_list[i]
-             condition, condition_scale = condition_seq[i][:2]
-
-             if condition_scale == 0:
-                 continue
-
-             key_padding_mask = None
-             if len(condition_seq[i]) > 2:
-                 additional_dict = condition_seq[i][2]
-                 assert isinstance(additional_dict, dict)
-                 key_padding_mask = additional_dict.get("key_padding_mask", None)
-
-             if stage == condition_net.stage:
-                 residual = condition_net.process(self.index, x_norm, condition, key_padding_mask=key_padding_mask)
-                 residual_all = residual_all + residual * condition_scale
-
-         return residual_all
-
-     def forward(
-         self,
-         tgt: Tensor,
-         memory_list: List[Tensor],
-         tgt_mask: Optional[Tensor] = None,
-         memory_mask: Optional[Tensor] = None,
-         tgt_key_padding_mask: Optional[Tensor] = None,
-         memory_key_padding_mask: Optional[Tensor] = None,
-         tgt_is_causal: bool = False,
-         memory_is_causal: bool = False,
-     ) -> Tensor:
-         """
-         memory_list = ( condition_text, [(condition_1, condition_1_scale), ...] )
-         """
-
-         x = tgt
-         assert self.norm_first
-
-         memory, condition_seq = memory_list
-
-         x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
-         x_norm = self.norm2(x)
-         x = x + self._mha_block(x_norm, memory, memory_mask, memory_key_padding_mask, memory_is_causal) + self.process_condition_net(x_norm, condition_seq, stage="parallel")
-         x = x + self._ff_block(self.norm3(x))
-
-         x = x + self.process_condition_net(x_norm, condition_seq, stage="postfix")
-         return x
-
-     def _mha_block(self, x: Tensor, mem: Tensor,
-                    attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False, multihead_attn=None) -> Tensor:
-         if multihead_attn is None:
-             multihead_attn = self.multihead_attn
-         x = multihead_attn(x, mem, mem,
-                            attn_mask=attn_mask,
-                            key_padding_mask=key_padding_mask,
-                            is_causal=is_causal,
-                            need_weights=False)[0]
-         return self.dropout2(x)
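
The layer above keeps the pre-norm residual layout of nn.TransformerDecoderLayer (with dropout replaced by identities) and adds pluggable condition nets: each registered net injects a residual either in parallel with the text cross-attention (stage "parallel") or after the feed-forward block (stage "postfix"), scaled per call via condition_seq. A minimal sketch of how the layer appears designed to be driven; DummyConditionNet, layer.index, and all shapes here are illustrative assumptions, not part of the deleted file:

import torch

class DummyConditionNet:
    # Stand-in for a ClayConditionNet: process_condition_net() only needs
    # a `stage` attribute and a `process()` method with this signature.
    stage = "parallel"
    def process(self, index, x_norm, condition, key_padding_mask=None):
        return torch.zeros_like(x_norm)  # a real net would cross-attend x_norm to `condition`

layer = ClayTransformerDecoderLayer(d_model=64, nhead=4, batch_first=True, norm_first=True)
layer.index = 0                                  # read by process_condition_net()
layer.register_condition_net([DummyConditionNet()])

tgt = torch.randn(2, 16, 64)                     # B x N x C latent tokens
memory = torch.randn(2, 77, 64)                  # e.g. CLIP text-encoder hidden states
condition = torch.randn(2, 10, 64)               # extra condition tokens
out = layer(tgt, (memory, [(condition, 1.0)]))   # a scale of 0 skips the net entirely
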
openclay/pipeline_openclay.py DELETED
@@ -1,195 +0,0 @@
- import numpy as np
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- import torch.utils.data
- import tqdm
- import mcubes
- import inspect
- import trimesh
- import gc
- from diffusers import UniPCMultistepScheduler
-
- from .utils import get_grid_tensor
- from .models import ClayVAE, ClayLDM, ClayConditionNet
-
- from transformers import Dinov2Model, BitImageProcessor, CLIPTextModel, CLIPTokenizer
- from diffusers import DiffusionPipeline
-
- class OpenClayPipeline(DiffusionPipeline):
-     def __init__(self,
-                  vae: ClayVAE,
-                  text_encoder: CLIPTextModel,
-                  tokenizer: CLIPTokenizer,
-                  ldm: ClayLDM,
-                  scheduler=None,
-                  ):
-         super().__init__()
-
-
-         if scheduler is None:
-             scheduler = self.get_unipc_scheduler()
-
-         self.register_modules(
-             vae=vae,
-             text_encoder=text_encoder,
-             tokenizer=tokenizer,
-             ldm=ldm,
-             scheduler=scheduler,
-         )
-
-     def get_unipc_scheduler(self):
-         scheduler = UniPCMultistepScheduler(
-             num_train_timesteps=1000,
-             beta_schedule="squaredcos_cap_v2",
-             prediction_type="v_prediction",
-             timestep_spacing="linspace",
-             rescale_betas_zero_snr=True
-         )
-         return scheduler
-
-     def get_timesteps(self, num_inference_steps, strength):
-         # get the original timestep using init_timestep
-         init_timestep = min(int(round(num_inference_steps * strength)), num_inference_steps)
-         t_start = max(num_inference_steps - init_timestep, 0)
-         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
-
-         return timesteps, num_inference_steps - t_start
-
-
-     def vae_decode_latent(self, latent, res=128, mini_batch=129 * 129 * 129, show_progress=True):
-         assert torch.isfinite(latent).all()
-
-         gap = 2 / res
-         grid = get_grid_tensor(res).to(latent)
-         logits = self.vae.decode(latent, grid, mini_batch=mini_batch, show_progress=show_progress)[:, :, 0].cpu()
-
-         logits = logits.view(res + 1, res + 1, res + 1)
-
-         if isinstance(logits, torch.Tensor):
-             logits = logits.cpu().numpy().astype(np.float32)
-         assert isinstance(logits, np.ndarray)
-         verts, faces = mcubes.marching_cubes(logits, 0)
-         verts *= gap
-         verts -= 1
-
-         m = trimesh.Trimesh(verts, faces)
-         return m
-
-     def __call__(self, prompt="", negative_prompt="", sample=None, strength=None,
-                  res=128, rescale_phi=0.7, cfg=7.5, rescale_cfg=True,
-                  num_latents=1024, num_inference_steps=100, timesteps=None,
-                  mini_batch=129**3, seed=42, num=1,
-                  condition_seq=(), show_progress=True,
-                  ):
-
-         """
-         condition_seq: [(condition_1, condition_1_scale), ...]
-         sample_lock_index: [m]
-         """
-
-         device = self.text_encoder.device
-         generator = torch.Generator(device).manual_seed(seed)
-
-         prompt: list = [prompt] if isinstance(prompt, str) else prompt
-         negative_prompt: list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-
-         assert len(prompt) == len(negative_prompt) == num or len(prompt) == 1 or len(negative_prompt) == 1
-         if len(prompt) == 1:
-             prompt = prompt * num
-         if len(negative_prompt) == 1:
-             negative_prompt = negative_prompt * num
-
-         token = self.tokenizer(prompt + negative_prompt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
-         encoder_hidden_states = self.text_encoder(token.to(device))[0].to(self.dtype)
-         encoder_hidden_states = encoder_hidden_states.reshape(2, num, encoder_hidden_states.shape[1], encoder_hidden_states.shape[2]).permute(1, 0, 2, 3).reshape(num * 2, encoder_hidden_states.shape[1], encoder_hidden_states.shape[2])
-         assert encoder_hidden_states.shape[0] == num * 2
-
-         # set step values
-         extra_set_kwargs = {}
-         if "timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()):
-             extra_set_kwargs["timesteps"] = timesteps
-         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
-         timesteps = self.scheduler.timesteps
-
-         noise = torch.randn(num, num_latents, 64, dtype=self.dtype, device=device, generator=generator)
-         if sample is not None:
-             assert torch.isfinite(sample).all()
-             timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
-             t_start = timesteps[0]
-             sample_noisy = self.scheduler.add_noise(sample, noise, torch.tensor(t_start).to(sample.device))
-             sample = sample_noisy
-         else:
-             sample = noise
-
-         condition_dict_cond = {
-             "condition_seq": [
-                 (
-                     condition_scale_dict[0], condition_scale_dict[1][0],
-                     ((condition_scale_dict[2][0] if not isinstance(condition_scale_dict[2], dict) else condition_scale_dict[2]) if len(condition_scale_dict) > 2 else {})
-                 )
-                 for condition_scale_dict in condition_seq
-             ],
-         }
-         condition_dict_uncond = {
-             "condition_seq": [
-                 (
-                     condition_scale_dict[0], condition_scale_dict[1][1],
-                     ((condition_scale_dict[2][1] if not isinstance(condition_scale_dict[2], dict) else condition_scale_dict[2]) if len(condition_scale_dict) > 2 else {})
-                 )
-                 for condition_scale_dict in condition_seq
-             ],
-         }
-
-         extra_step_kwargs = {}
-         if "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()):
-             extra_step_kwargs["generator"] = generator
-
-         for t in tqdm.tqdm(timesteps, "[ ClayLDMPipeline.__call__ ]", disable=not show_progress):
-             # 1. predict noise model_output
-             if isinstance(t, torch.Tensor):
-                 t = t.item()
-             t_tensor = torch.tensor([t], dtype=torch.long, device=device)
-
-             model_output_cond = self.ldm(
-                 sample,
-                 t_tensor.expand(num),
-                 encoder_hidden_states[0::2],
-                 **condition_dict_cond,
-             )
-             model_output_uncond = self.ldm(
-                 sample,
-                 t_tensor.expand(num),
-                 encoder_hidden_states[1::2],
-                 **condition_dict_uncond,
-             )
-
-             model_output_cfg = (model_output_cond - model_output_uncond) * cfg + model_output_uncond
-             if rescale_cfg:
-                 model_output_rescaled = model_output_cfg / model_output_cfg.std(dim=(1, 2), keepdim=True) * model_output_cond.std(dim=(1, 2), keepdim=True)
-                 model_output = rescale_phi * model_output_rescaled + (1 - rescale_phi) * model_output_cfg
-             else:
-                 model_output = model_output_cfg
-
-             # 2. compute previous image: x_t -> x_t-1
-             sample = self.scheduler.step(
-                 model_output[:, None, :, :].permute(0, 3, 1, 2),
-                 t,
-                 sample[:, None, :, :].permute(0, 3, 1, 2),
-                 **extra_step_kwargs
-             ).prev_sample.permute(0, 2, 3, 1)[:, 0, :, :]
-
-             assert torch.isfinite(sample).all(), sample
-
-             gc.collect()
-             torch.cuda.empty_cache()
-
-         mesh_list = []
-         for i in tqdm.tqdm(range(sample.shape[0])):
-             mesh = self.vae_decode_latent(sample[i:i + 1], res=res, mini_batch=mini_batch)
-             mesh.vertices[:, 0] += i % 4 * 2
-             mesh.vertices[:, 2] += i // 4 * 4
-             mesh_list.append(mesh)
-
-         mesh_combined = trimesh.util.concatenate(mesh_list)
-         return mesh_combined
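
The denoising loop above is standard classifier-free guidance, uncond + cfg * (cond - uncond), combined with the std-matching guidance rescale controlled by rescale_phi (the trick popularized by Lin et al., 2023, with the same 0.7 default); the final latents are then decoded to meshes by marching cubes in vae_decode_latent. A minimal usage sketch, assuming the checkpoint is laid out for DiffusionPipeline.from_pretrained at a hypothetical local path:

import torch

pipe = OpenClayPipeline.from_pretrained("path/to/openclay-checkpoint")  # hypothetical path
pipe.to("cuda")

mesh = pipe(
    prompt="a ceramic teapot",
    negative_prompt="",
    cfg=7.5,                  # guidance scale
    rescale_phi=0.7,          # CFG rescaling blend factor
    num_inference_steps=100,
    res=128,                  # marching-cubes grid resolution
    seed=42,
)                             # returns a trimesh.Trimesh
mesh.export("teapot.glb")
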
openclay/utils.py DELETED
@@ -1,80 +0,0 @@
- import numpy as np
- import torch
- import cv2
-
- def get_grid_tensor(res=128):
-
-     gap = 2. / res
-     x = torch.linspace(-1, 1, res + 1)
-     y = torch.linspace(-1, 1, res + 1)
-     z = torch.linspace(-1, 1, res + 1)
-     grid = torch.stack(torch.meshgrid(x, y, z, indexing="ij")).view(3, -1).T[None]
-     return grid
-
-
- def pad_to_square(image, pad_color=None):
-     H, W, C = image.shape
-     max_side = max(H, W)
-     padded_image = np.ones((max_side, max_side, C))
-     if pad_color is None:
-         pad_color = image[0, 0]
-     padded_image[:] = pad_color
-     vertical_offset = (max_side - H) // 2
-     horizontal_offset = (max_side - W) // 2
-     padded_image[vertical_offset:vertical_offset + H, horizontal_offset:horizontal_offset + W, :] = image
-     return padded_image
-
-
- def read_image_square(path):
-     image = cv2.imread(path, -1)
-     return process_image_square(image)
-
- def process_image_square(image):
-     image = image.astype(np.float32) / 255
-
-     if image.shape[2] == 4:  # background
-         fg = image[:, :, 3] > 0.5
-         fg_coord = np.stack(np.where(fg))
-         rc_min = fg_coord.min(axis=1)
-         rc_max = fg_coord.max(axis=1)
-         rc_range = rc_max - rc_min
-         rc_min -= (rc_range * 0.1).astype(int)
-         rc_max += (rc_range * 0.1).astype(int)
-         rc_min = rc_min.clip(0, None)
-         rc_max = rc_max.clip(0, None)
-         image = image[rc_min[0]:rc_max[0], rc_min[1]:rc_max[1]]
-         image = image[:, :, :3] * image[:, :, 3:] + 1 * (1 - image[:, :, 3:])
-     render = image[:, :, ::-1]
-     render = pad_to_square(render)
-
-     return render
-
- def get_center_position(num_voxel):
-     """
-     num_voxel: int
-     return:
-         [X, Y, Z, 3]
-     """
-     center_position = (torch.stack(torch.meshgrid([torch.arange(num_voxel, dtype=torch.float32)] * 3, indexing="ij"), dim=3) + 0.5) / num_voxel * 2 - 1
-     return center_position
-
-
- def geometry_get_voxel(geometry):
-     import pysdf
-
-     res_large = 128
-     voxel_large_or = np.zeros((res_large, res_large, res_large), dtype=bool)
-     center_position_large = get_center_position(res_large).numpy()
-
-     voxel_large = pysdf.SDF(geometry.vertices, geometry.faces).contains(center_position_large.reshape(-1, 3)).reshape(res_large, res_large, res_large)
-
-     voxel_large_or |= voxel_large
-
-     res = 16
-     voxel_or = np.zeros((res, res, res), dtype=bool)
-
-     loc = np.mgrid[:res, :res, :res].transpose(1, 2, 3, 0).reshape(-1, 3)
-     for l in loc:
-         voxel_or[l[0], l[1], l[2]] = voxel_large_or[l[0] * 8:l[0] * 8 + 8, l[1] * 8:l[1] * 8 + 8, l[2] * 8:l[2] * 8 + 8].sum() > 0
-
-     return voxel_or
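
geometry_get_voxel rasterizes a mesh into a dense 128^3 inside/outside grid with pysdf, then max-pools 8^3 blocks down to a coarse 16^3 occupancy volume. A short usage sketch; the icosphere here is only a stand-in for any watertight mesh normalized to [-1, 1]^3:

import trimesh

mesh = trimesh.creation.icosphere(radius=0.8)  # placeholder watertight mesh inside [-1, 1]^3
voxels = geometry_get_voxel(mesh)              # boolean array of shape (16, 16, 16)
print(voxels.sum(), "of", voxels.size, "voxels occupied")
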
requirements.txt CHANGED
@@ -3,16 +3,4 @@ requests
  pillow
  gradio==4.31.2
  requests-toolbelt
- websocket-client
-
- numpy==1.24.1
-
- PyMCubes
- pysdf==0.1.9
- opencv-python==4.9.0.80
- tqdm==4.66.1
- trimesh==4.0.5
- timm==0.9.12
- diffusers==0.29.0
- transformers
- accelerate
+ websocket-client