Spaces: Running on Zero

Revert "test 600m"

This reverts commit 893330bad43179b8a93788b03a11e16e5e39f1ed.
- Rodin.py +34 -44
- app.py +79 -229
- constant.py +0 -3
- openclay/models/__init__.py +0 -3
- openclay/models/condition.py +0 -102
- openclay/models/ldm.py +0 -83
- openclay/models/vae.py +0 -124
- openclay/modules/attention.py +0 -73
- openclay/modules/control_volume.py +0 -52
- openclay/modules/diag_gaussian.py +0 -42
- openclay/modules/drop_path.py +0 -34
- openclay/modules/embedding.py +0 -36
- openclay/modules/transformer.py +0 -116
- openclay/pipeline_openclay.py +0 -195
- openclay/utils.py +0 -80
- requirements.txt +1 -13
Rodin.py
CHANGED
@@ -13,27 +13,30 @@ import gradio as gr
 from requests_toolbelt.multipart.encoder import MultipartEncoder
 from constant import *
 
+@spaces.GPU
+def foo():
+    pass
+
 def login(email, password):
     payload = {'password': password}
     if email:
         payload['email'] = email
-
+
     response = requests.post(f"{BASE_URL}/user/login", json=payload)
     try:
         response_data = response.json()
     except json.JSONDecodeError as e:
         log("ERROR", f"Error in login: {response}")
         raise e
-
+
     if 'error' in response_data and response_data['error']:
         raise Exception(response_data['error'])
     log("INFO", f"Logged successfully")
     user_uuid = response_data['user_uuid']
     token = response_data['token']
-
+
     return user_uuid, token
 
-
 def rodin_history(task_uuid, token):
     headers = {
         'Authorization': f'Bearer {token}'
@@ -41,7 +44,6 @@ def rodin_history(task_uuid, token):
     response = requests.post(f"{BASE_URL}/task/rodin_history", data={"uuid": task_uuid}, headers=headers)
     return response.json()
 
-
 def rodin_preprocess_image(generate_prompt, image, name, token):
     m = MultipartEncoder(
         fields={
@@ -56,7 +58,6 @@ def rodin_preprocess_image(generate_prompt, image, name, token):
     response = requests.post(f"{BASE_URL}/task/rodin_mesh_image_process", data=m, headers=headers)
     return response
 
-
 def crop_image(image, type):
     if image == None:
         raise gr.Error("Please generate the object first")
@@ -77,7 +78,7 @@ def crop_image(image, type):
 # Perform Rodin mesh operation
 def rodin_mesh(prompt, group_uuid, settings, images, name, token):
     images = [convert_base64_to_binary(img) for img in images]
-
+
     m = MultipartEncoder(
         fields={
             'prompt': prompt,
@@ -99,13 +100,12 @@ def rodin_mesh(prompt, group_uuid, settings, images, name, token):
 def convert_base64_to_binary(base64_string):
     if ',' in base64_string:
         base64_string = base64_string.split(',')[1]
-
+
     image_data = base64.b64decode(base64_string)
     image_buffer = io.BytesIO(image_data)
-
+
     return image_buffer
 
-
 def rodin_update(prompt, task_uuid, token, settings):
     headers = {
         'Authorization': f'Bearer {token}'
@@ -113,7 +113,6 @@ def rodin_update(prompt, task_uuid, token, settings):
     response = requests.post(f"{BASE_URL}/task/rodin_update", data={"uuid": task_uuid, "prompt": prompt, "settings": settings}, headers=headers)
     return response
 
-
 def load_image(img_path):
     try:
         image = Image.open(img_path)
@@ -136,11 +135,9 @@ def load_image(img_path):
     image_bytes = byte_io.getvalue()
     return image_bytes
 
-
 def log(level, info_text):
     print(f"[ {level} ] - {time.strftime('%Y%m%d_%H:%M:%S', time.localtime())} - {info_text}")
 
-
 class Generator:
     def __init__(self, user_id, password, token) -> None:
         # _, self.token = login(user_id, password)
@@ -149,11 +146,11 @@ class Generator:
         self.password = password
         self.task_uuid = None
         self.processed_image = None
-
-    def preprocess(self, prompt, image_path, processed_image, task_uuid=""):
-        if image_path
+
+    def preprocess(self, prompt, image_path, processed_image , task_uuid=""):
+        if image_path == None:
             raise gr.Error("Please upload an image first")
-
+
         if processed_image and prompt and (not task_uuid):
             log("INFO", "Using cached image and prompt...")
             return prompt, processed_image
@@ -163,10 +160,10 @@ class Generator:
         while not success:
             if try_times > 3:
                 raise gr.Error("Failed to preprocess image")
-            try_times += 1
+            try_times += 1
             image_file = load_image(image_path)
             log("INFO", "Image loaded, processing...")
-
+
             try:
                 if prompt and task_uuid:
                     res = rodin_preprocess_image(generate_prompt=False, image=image_file, name=os.path.basename(image_path), token=self.token)
@@ -203,13 +200,13 @@ class Generator:
 
         log("INFO", "Image preprocessed successfully")
         return prompt, processed_image
-
+
     def generate_mesh(self, prompt, processed_image, task_uuid=""):
         log("INFO", "Generating mesh...")
         if task_uuid == "":
             settings = {'view_weights': [1]} # Define weights as per your requirements, for multiple images, use multiple values, e,g [0.5, 0.5]
             images = [processed_image] # List of images, all the images should be processed first
-
+
             res = rodin_mesh(prompt=prompt, group_uuid=None, settings=settings, images=images, name="images.jpeg", token=self.token)
             try:
                 mesh_response = res.json()
@@ -218,14 +215,14 @@ class Generator:
             except Exception as e:
                 log("ERROR", f"Error in generating mesh: {e} and response: {res}")
                 raise gr.Error("Error in generating mesh, please try again later.")
-
-            task_uuid = mesh_response['uuid']
+
+            task_uuid = mesh_response['uuid'] # The task_uuid should be same during whole generation process
         else:
             new_prompt = prompt
             settings = {
                 "view_weights": [1],
-                "seed": random.randint(0, 10000),
-                "escore": 5.5,
+                "seed": random.randint(0, 10000), # Customize your seed here
+                "escore": 5.5, # Temprature
             }
             res = rodin_update(new_prompt, task_uuid, self.token, settings)
             try:
@@ -243,7 +240,7 @@ class Generator:
             except Exception as e:
                 log("ERROR", f"Error in generating mesh: {history}")
                 raise gr.Error("Busy connection, please try again later.")
-
+
             response = requests.get(preview_image, stream=True)
             if response.status_code == 200:
                 image = Image.open(response.raw)
@@ -260,32 +257,25 @@ class JobStatusChecker:
         self.subscription_key = subscription_key
         self.sio = socketio.Client(logger=True, engineio_logger=True)
 
-        @self.sio.
-        def connect(
-            print("
+        @self.sio.event
+        def connect():
+            print("Connected to the server.")
 
-        @self.sio.
-        def disconnect(
-            print("
+        @self.sio.event
+        def disconnect():
+            print("Disconnected from server.")
 
         @self.sio.on('message', namespace='*')
         def message(*args, **kwargs):
-            print(f"""[ JobStatusChecker.message ] args = {args}""")
-            safe_to_disconnect = False
             if len(args) > 2:
                 data = args[2]
                 if data.get('jobStatus') == 'Succeeded':
-
-
-                    safe_to_disconnect = True
-
-            if safe_to_disconnect:
-                print("[ JobStatusChecker.message ] Job Succeeded! Please find the SDF image in history")
-                self.sio.disconnect()
+                    print("Job Succeeded! Please find the SDF image in history")
+                    self.sio.disconnect()
             else:
-                print(
+                print("Received event with insufficient arguments.")
 
     def start(self):
-        self.sio.connect(f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
+        self.sio.connect(f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
                          namespaces=['/api/scheduler_socket'], transports='websocket')
-        self.sio.wait()
+        self.sio.wait()
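A note on the re-added no-op above: the revert restores `@spaces.GPU def foo(): pass` at the top of Rodin.py even though generation now happens through the remote Rodin API. On a Space "Running on Zero" (ZeroGPU), the `spaces` package expects at least one `@spaces.GPU`-decorated function at startup, so the dummy appears to exist only to keep the ZeroGPU runtime satisfied. A minimal sketch of the usual pattern (the name `heavy_step` and the tensor math are illustrative, not from this repo):

import spaces
import torch

@spaces.GPU(duration=60)  # ZeroGPU attaches a GPU only while a decorated call runs
def heavy_step(x: torch.Tensor) -> torch.Tensor:
    # CUDA is available inside the decorated call on a ZeroGPU Space
    return (x.to("cuda") * 2).cpu()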
app.py
CHANGED
@@ -1,21 +1,8 @@
 import os
-os.system('
-os.system('
-os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-os.environ["TORCH_LINEAR_FLATTEN_3D"] = "1"
-
-import cv2
-import time
-import numpy as np
-import torch
-
-from openclay.pipeline_openclay import OpenClayPipeline
-from openclay.models import ClayVAE, ClayLDM, ClayConditionNet
-from openclay.utils import process_image_square
-from transformers import Dinov2Model, BitImageProcessor, CLIPTextModel, CLIPTokenizer
+os.system('pip uninstall -y gradio_fake3d')
+os.system('pip install gradio_fake3d-0.0.3-py3-none-any.whl')
 
 import gradio as gr
-import spaces
 import re
 from gradio_fake3d import Fake3D
 from PIL import Image
@@ -23,57 +10,6 @@ from Rodin import Generator, crop_image, log, convert_base64_to_binary
 from constant import *
 
 generator = Generator(USER, PASSWORD, TOKEN)
-os.makedirs(FOLDER_TEMP_MESH, exist_ok=True)
-
-device = torch.device("cuda")
-
-tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-
-image_processor = BitImageProcessor.from_pretrained("facebook/dinov2-giant", torch_dtype=torch.float16)
-image_encoder = Dinov2Model.from_pretrained("facebook/dinov2-giant", torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-
-vae = ClayVAE.from_pretrained("DEEMOSTECH/CLAYV1_VAE", token=ACCESS_TOKEN, torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-ldm = ClayLDM.from_pretrained("DEEMOSTECH/CLAYV1_LDM_MEDIUM", token=ACCESS_TOKEN, torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-condition_net_image = ClayConditionNet.from_pretrained("DEEMOSTECH/CLAYV1_LDM_MEDIUM_CONDITION_IMAGE", token=ACCESS_TOKEN, torch_dtype=torch.float16).eval().requires_grad_(False).to(device)
-ldm.register_condition_net([condition_net_image])
-
-pipe = OpenClayPipeline(
-    vae=vae,
-    text_encoder=text_encoder,
-    tokenizer=tokenizer,
-    ldm=ldm,
-)
-
-
-@spaces.GPU
-def read_image(image, image_processor, image_encoder, size=224):
-    render = process_image_square(image)
-    render = cv2.resize(render, (size, size))
-    image_pixel_values = image_processor(render, return_tensors="pt", do_rescale=False,
-                                         do_resize=False, do_center_crop=False)["pixel_values"][0]
-    image_embeds_patch = image_encoder(image_pixel_values.half().to(image_encoder.device)[None])["last_hidden_state"]
-    return render, image_embeds_patch
-
-@spaces.GPU
-def local_inference(block_prompt, image_pil):
-    image = np.array(image_pil)
-    _, image_embeds_patch = read_image(image, image_processor, image_encoder)
-    mesh = pipe(
-        prompt=block_prompt, negative_prompt='fragmentation.',
-        res=256,
-        num_inference_steps=50,
-        mini_batch=65**3,
-        seed=42, num=1,
-        condition_seq=[(image_embeds_patch, [1, 0])],
-    )
-
-    mesh_path = f"{FOLDER_TEMP_MESH}/{int(time.time())}.glb"
-    os.makedirs(os.path.dirname(mesh_path), exist_ok=True)
-    mesh.export(mesh_path)
-
-    return mesh_path
-
 
 change_button_name = """
 function updateButton(input) {
@@ -83,14 +19,6 @@ function updateButton(input) {
 }
 """
 
-change_button_name_600 = """
-function updateButton(input) {
-    var buttonGenerate = document.getElementById('button_generate_600');
-    buttonGenerate.innerText = 'Redo';
-    return '';
-}
-"""
-
 change_button_name_to_generating = """
 function updateButton(input) {
     var buttonGenerate = document.getElementById('button_generate');
@@ -99,15 +27,6 @@ function updateButton(input) {
 }
 """
 
-change_button_name_to_generating_600 = """
-function updateButton(input) {
-    var buttonGenerate = document.getElementById('button_generate_600');
-    buttonGenerate.innerText = 'Generating...';
-    return '';
-}
-"""
-
-
 reset_button_name = """
 function updateButton(input) {
     var buttonGenerate = document.getElementById('button_generate');
@@ -116,15 +35,6 @@ function updateButton(input) {
 }
 """
 
-reset_button_name_600 = """
-function updateButton(input) {
-    var buttonGenerate = document.getElementById('button_generate_600');
-    buttonGenerate.innerText = 'Generate';
-    return '';
-}
-"""
-
-
 jump_to_rodin = """
 function redirectToGithub(input) {
     if (input.includes('OpenClay')) {
@@ -188,23 +98,18 @@ example = [
     ["assets/46.png"]
 ]
 
-
 def do_nothing(text):
     return ""
 
-
 def handle_selection(selection):
     return "Rodin Gen-1(0525)"
 
-
 def hint_in_prompt(hint, prompt):
     return re.search(fr"{hint[:-1]}", prompt) is not None
 
-
 def prompt_remove_hint(prompt, hint):
    return re.sub(fr"\s*{hint[:-1]}[\.,]*", "", prompt)
 
-
 def handle_hint_change(prompt: str, prompt_hint):
     prompt = prompt.strip()
     if prompt != "" and not prompt.endswith("."):
@@ -218,7 +123,6 @@ def handle_hint_change(prompt: str, prompt_hint):
     prompt = prompt.strip()
     return prompt
 
-
 def handle_prompt_change(prompt):
     hint_list = []
     for _, hint in PROMPT_HINT_LIST:
@@ -227,15 +131,6 @@ def handle_prompt_change(prompt):
 
     return hint_list
 
-def preprocessing(prompt, image_path, processed_image, task_uuid=""):
-    prompt, image_base64 = generator.preprocess(prompt, image_path, processed_image, task_uuid)
-    image_rgb = convert_base64_to_binary(image_base64)
-    image_rgb = cv2.imdecode(np.frombuffer(image_rgb.getvalue(), np.uint8), -1)[...,[2,1,0,3]]
-    image_rgb = Image.fromarray(image_rgb, 'RGBA')
-    # image_rgb = cv2.resize(image_rgb, (256, 256), cv2.INTER_AREA)
-
-    return prompt, image_base64, image_rgb
-
 def clear_task(task_input=None):
     """_summary_
     [cache_task_uuid, block_prompt, block_prompt_hint, fake3d, block_3d]
@@ -243,33 +138,26 @@ def clear_task(task_input=None):
     log("INFO", "Clearing task...")
     return "", "", "", [], "assets/white_image.png"
 
-
 def clear_task_id():
     return ""
 
-
 def return_render(image):
     image = Image.fromarray(image)
     return image, crop_image(image, DEFAULT)
 
-
 def crop_image_default(image):
     return crop_image(image, DEFAULT)
 
-
 def crop_image_metal(image):
     return crop_image(image, METAL)
 
-
 def crop_image_contrast(image):
     return crop_image(image, CONTRAST)
 
-
 def crop_image_normal(image):
     return crop_image(image, NORMAL)
 
-
-with gr.Blocks(css=css) as demo:
+with gr.Blocks() as demo:
     gr.HTML(html_content)
 
     cache_task_uuid = gr.Text(value="", visible=False)
@@ -279,86 +167,66 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Row():
         with gr.Column():
+            block_image = gr.Image(height=256, image_mode="RGB", sources="upload", elem_classes="elem_imageupload", type="filepath")
+            block_model_card = gr.Dropdown(choices=options, label="Model Card", value="Rodin Gen-1(0525)", interactive=True)
             with gr.Group():
-
-
-
-
-
-
-
-
-                    type="filepath"
-                )
-                block_image_masked = gr.Image(
-                    label='Preprocessed',
-                    height=max_height,
-                    elem_id="elem_block_image_crop",
-                    elem_classes="elem_imagebox",
-                    interactive=False,
-                )
-
-
-                block_prompt = gr.Textbox(
-                    value="",
-                    placeholder="Auto generated description of Image",
-                    lines=1,
-                    show_label=True,
-                    label="Prompt",
-                )
-
-                block_prompt_hint = gr.CheckboxGroup(value="Labels", choices=PROMPT_HINT_LIST, show_label=False)
-
-        with gr.Column(elem_id="right_col"):
-            with gr.Group(elem_id="right_col_group"):
-                with gr.Row(elem_id="right_col_group_row"):
-                    with gr.Group(elem_id="right_col_group_row_gleft"):
-                        block_3d = gr.Model3D(
-                            value='./empty.obj',
-                            height=320,
-                            camera_position=(90 + 30, 90 - 15, 3),
-                            zoom_speed=0.2,
-                            pan_speed=0.3,
-                            label="3D Preview (OpenCLAY(600M))",
-                            elem_id="block_3d"
-                        )
-
-                        button_generate_600 = gr.Button(value="Generate", variant="primary", elem_id="button_generate_600")
-
-                    with gr.Group(elem_id="right_col_group_row_gright"):
-                        fake3d = Fake3D(interactive=False,
-                                        # height=320,
-                                        # width=320,
-                                        label="3D Preview (Rodin Gen-1(0525))",
-                                        elem_id="fake3d"
-                                        )
-
-            with gr.Row():
-                button_generate = gr.Button(value="Generate", variant="primary", elem_id="button_generate")
-                button_more = gr.Button(value="Download", variant="primary", link=rodin_url)
+                block_prompt = gr.Textbox(
+                    value="",
+                    placeholder="Auto generated description of Image",
+                    lines=1,
+                    show_label=True,
+                    label="Prompt",
+                )
+                block_prompt_hint = gr.CheckboxGroup(value="Labels", choices=PROMPT_HINT_LIST, show_label=False)
 
+        with gr.Column():
+            with gr.Group():
+                fake3d = Fake3D(interactive=False, label="3D Preview")
+                with gr.Row():
+                    button_generate = gr.Button(value="Generate", variant="primary", elem_id="button_generate")
 
-
-
-
-
-
-
-
-    )
-
+                    with gr.Column(min_width=200, scale=20):
+                        with gr.Row():
+                            block_default = gr.Button("Default", min_width=0)
+                            block_metal = gr.Button("Metal", min_width=0)
+                        with gr.Row():
+                            block_contrast = gr.Button("Contrast", min_width=0)
+                            block_normal = gr.Button("Normal", min_width=0)
+
+            button_more = gr.Button(value="Download from Rodin", variant="primary", link=rodin_url)
+            gr.Markdown("""
+                        **TIPS**:
+                        1. Upload an image to generate 3D geometry.
+                        2. Click Redo to regenerate the model.
+                        3. 4 buttons to switch the view.
+                        4. Swipe to rotate the model.
+                        """)
+    cache_task_uuid = gr.Text(value="", visible=False)
+
+
+    cache_raw_image = gr.Image(visible=False, type="pil")
+    cacha_empty = gr.Text(visible=False)
+    cache_image_base64 = gr.Text(visible=False)
+    block_example = gr.Examples(
+        examples=example,
+        fn=clear_task,
+        inputs=[block_image],
+        outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint, fake3d],
+        run_on_click=True,
+        cache_examples=True,
+        label="Examples"
+    )
 
     block_image.upload(
-        fn=do_nothing,
-        js=change_button_name_to_generating,
+        fn=do_nothing,
+        js=change_button_name_to_generating,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
     ).success(
-        fn=
-        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
-        outputs=[block_prompt, cache_image_base64
+        fn=generator.preprocess,
+        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
+        outputs=[block_prompt, cache_image_base64],
         show_progress="minimal",
        queue=True
     ).success(
@@ -367,36 +235,36 @@ with gr.Blocks(css=css) as demo:
         outputs=[cache_raw_image, cache_task_uuid, fake3d],
         queue=True
     ).success(
-        fn=do_nothing,
-        js=change_button_name,
+        fn=do_nothing,
+        js=change_button_name,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
     )
-
+
     block_image.clear(
-        fn=do_nothing,
-        js=reset_button_name,
+        fn=do_nothing,
+        js=reset_button_name,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
     ).then(
-        fn=clear_task,
-        outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint, fake3d],
+        fn=clear_task,
+        outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint, fake3d],
         show_progress="hidden",
         queue=False
     )
-
+
     button_generate.click(
-        fn=do_nothing,
-        js=change_button_name_to_generating,
+        fn=do_nothing,
+        js=change_button_name_to_generating,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
    ).success(
-        fn=
-        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
-        outputs=[block_prompt, cache_image_base64
+        fn=generator.preprocess,
+        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
+        outputs=[block_prompt, cache_image_base64],
         show_progress="minimal",
         queue=True
     ).success(
@@ -405,46 +273,26 @@ with gr.Blocks(css=css) as demo:
         outputs=[cache_raw_image, cache_task_uuid, fake3d],
         queue=True
     ).then(
-        fn=do_nothing,
-        js=change_button_name,
-        inputs=[cacha_empty],
-        outputs=[cacha_empty],
-        queue=False
-    )
-
-    button_generate_600.click(
-        fn=do_nothing,
-        js=change_button_name_to_generating_600,
-        inputs=[cacha_empty],
-        outputs=[cacha_empty],
-        queue=False
-    ).success(
-        fn=preprocessing,
-        inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
-        outputs=[block_prompt, cache_image_base64, block_image_masked],
-        show_progress="minimal",
-        queue=True
-    ).success(
-        fn=local_inference,
-        inputs=[block_prompt, block_image_masked],
-        outputs=[block_3d],
-        queue=True
-    ).then(
-        fn=do_nothing,
-        js=change_button_name_600,
+        fn=do_nothing,
+        js=change_button_name,
         inputs=[cacha_empty],
         outputs=[cacha_empty],
         queue=False
     )
-
+
+    block_default.click(fn=crop_image_default, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
+    block_metal.click(fn=crop_image_metal, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
+    block_contrast.click(fn=crop_image_contrast, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
+    block_normal.click(fn=crop_image_normal, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
+
     button_more.click()
-
+
     block_prompt_hint.input(
         fn=handle_hint_change, inputs=[block_prompt, block_prompt_hint], outputs=[block_prompt],
         show_progress="hidden",
         queue=False,
     )
-
+
     block_prompt.change(
         fn=handle_prompt_change,
         inputs=[block_prompt],
@@ -452,6 +300,8 @@ with gr.Blocks(css=css) as demo:
         trigger_mode="always_last",
         show_progress="hidden",
     )
+
+    block_model_card.change(fn=handle_selection, inputs=[block_model_card], outputs=[block_model_card], show_progress="hidden", js=jump_to_rodin)
 
 
 if __name__ == "__main__":
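app.py drives the whole UI through chained Gradio events: an event listener such as `.upload()` or `.click()` returns a dependency whose `.success()` step runs only if the previous function finished without raising, while `.then()` runs unconditionally. That is why the `do_nothing` + `js=` button-renaming steps bracket the real work, and why the button label is restored even after a failed generation. A minimal sketch of the same chaining (component names here are illustrative, not from app.py):

import gradio as gr

def preprocess(text):
    if not text:
        raise gr.Error("empty input")  # a raised gr.Error aborts the .success() chain
    return text.upper()

with gr.Blocks() as demo:
    box = gr.Textbox()
    btn = gr.Button("Generate")
    btn.click(fn=preprocess, inputs=box, outputs=box
    ).success(fn=preprocess, inputs=box, outputs=box   # only runs if the step above succeeded
    ).then(fn=lambda: print("always runs"))            # runs regardless, like the button-reset step

demo.launch()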
constant.py
CHANGED
@@ -8,9 +8,6 @@ USER = os.getenv("USER")
 PASSWORD = os.getenv("PASSWORD")
 TOKEN = os.getenv("TOKEN")
 
-ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
-FOLDER_TEMP_MESH = './tmp_mesh'
-
 DEFAULT = [0, 0]
 CONTRAST = [360, 0]
 METAL = [0, 360]
openclay/models/__init__.py
DELETED
@@ -1,3 +0,0 @@
-from .condition import ClayConditionNet
-from .ldm import ClayLDM
-from .vae import ClayVAE
openclay/models/condition.py
DELETED
@@ -1,102 +0,0 @@
-import copy
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from diffusers import ModelMixin, ConfigMixin
-
-from diffusers.configuration_utils import register_to_config
-from ..modules.embedding import PointEmbed
-from ..modules.control_volume import ControlVolume
-from ..utils import get_center_position
-
-class ClayConditionNet(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        kv_dim,
-        ldm_dim, ldm_heads, ldm_depth,
-        on_volume=False, volume_input_dim=1, volume_block_dim=256, aggregation_method=None,
-        on_point=False, point_number=8, point_leading_embed=False, additional_token_length=None,
-        stage="parallel",
-    ):
-        super().__init__()
-        self.multihead_attn_condition_list = nn.ModuleList([copy.deepcopy(
-            nn.MultiheadAttention(ldm_dim, ldm_heads, dropout=0, batch_first=True, kdim=kv_dim, vdim=kv_dim)
-        ) for i in range(ldm_depth)])
-
-        if on_volume:
-            self.condition_volume_point_embed = PointEmbed(dim=kv_dim)
-            self.condition_volume_conv = ControlVolume(volume_dim=volume_input_dim,
-                                                       block_dim=volume_block_dim,
-                                                       condition_dim=kv_dim,
-                                                       time_embed_dim=ldm_dim,
-                                                       downsample_times=1,
-                                                       aggregation_method=aggregation_method)
-
-        if on_point:
-            self.condition_point_point_embed = PointEmbed(dim=kv_dim)
-            if point_number == 0:
-                self.condition_point_token = None
-            else:
-                self.condition_point_token = nn.Parameter(torch.randn(point_number, kv_dim))
-
-        self.on_volume = on_volume
-        self.on_point = on_point
-
-        assert stage in {"parallel", "postfix"}
-        self.stage = stage
-
-    def preprocess_condition(self, condition, additional_dict, time_embed) -> torch.Tensor:
-
-        if self.on_volume:
-            condition_volume = condition
-            Bor1, volume_dim, X, Y, Z = condition_volume.shape
-            assert X == Y == Z == 16
-            condition_volume = self.condition_volume_conv(condition_volume, time_embed)  # [Bor1, condition_dim, X, Y, Z]
-            condition_volume = condition_volume.reshape(-1, condition_volume.shape[1], 8**3).permute(0, 2, 1)  # [Bor1, X*Y*Z, condition_dim]
-
-            center_position = get_center_position(8)[None].to(time_embed).reshape(1, 8**3, 3)  # [1, X*Y*Z, 3]
-            condition_volume = condition_volume + self.condition_volume_point_embed(center_position)  # [Bor1, X*Y*Z, condition_dim]
-            condition = condition_volume
-
-        if self.on_point:
-            point = condition
-            Bor1, M, _ = point.shape
-
-            condition_point = self.condition_point_point_embed(point)
-            if self.condition_point_token is not None:
-                if self.config.point_leading_embed:
-                    condition_point = torch.cat([
-                        condition_point[:, :self.condition_point_token.shape[0]] + self.condition_point_token,
-                        condition_point[:, self.condition_point_token.shape[0]:]
-                    ], dim=1)
-                else:
-                    condition_point = condition_point + self.condition_point_token * additional_dict.get("condition_point_token_scale", 1)
-
-            condition = condition_point
-
-        return condition
-
-    def process(self, index, x, condition, key_padding_mask=None):
-        """
-        x: [B, N, dim]
-        condition: [B, M, condition_dim]
-        """
-
-        if self.stage == "parallel":
-            residual = self.multihead_attn_condition_list[index](x,
-                                                                 condition, condition, need_weights=False,
-                                                                 key_padding_mask=key_padding_mask
-                                                                 )[0]
-        else:
-            x = x + self.multihead_attn_condition_list[index](self.norm2_condition_list[index](x),
-                                                              condition, condition, need_weights=False,
-                                                              key_padding_mask=key_padding_mask
-                                                              )[0]
-
-            residual = self.linear2_condition_list[index](F.gelu(self.linear1_condition_list[index]
-                                                                 (self.norm3_condition_list[index](x)))
-                                                          )
-
-        return residual
openclay/models/ldm.py
DELETED
@@ -1,83 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from diffusers import ModelMixin, ConfigMixin
-from diffusers.configuration_utils import register_to_config
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps
-
-from ..modules.transformer import ClayTransformerDecoderLayer
-
-class ClayLDM(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        depth=24,
-        dim=512,
-        latent_dim=64,
-        heads=8,
-    ):
-        super().__init__()
-
-        timestep_input_dim = dim // 2
-        time_embed_dim = dim
-        self.time_proj = Timesteps(timestep_input_dim, True, 0)
-        self.time_embedding = TimestepEmbedding(
-            timestep_input_dim,
-            time_embed_dim,
-            act_fn="silu",
-            post_act_fn="silu",
-        )
-        self.time_embed_dim = time_embed_dim
-
-        self.proj_in = nn.Linear(latent_dim, dim, bias=False)
-
-        self.layers = nn.TransformerDecoder(
-            ClayTransformerDecoderLayer(dim, heads, dim_feedforward=dim * 4, dropout=0, activation=F.gelu, batch_first=True, norm_first=True, kdim=768, vdim=768, layer_norm_eps=1e-4),
-            depth
-        )
-        for i in range(depth):
-            self.layers.layers[i].index = i
-
-        self.proj_out = nn.Linear(dim, latent_dim, bias=False)
-
-    def register_condition_net(self, condition_net_list):
-        self.condition_net_list = nn.ModuleList(condition_net_list)
-        for layer in self.layers.layers:
-            layer.register_condition_net(condition_net_list)
-
-    def forward(self, sample, t, condition_text, condition_seq=None):
-        """
-        sample: [B, L, C]
-        t: [N]
-        condition: [B, 77, 768]
-        condition_seq: [(condition_1, condition_1_scale), ...]
-        return: [B, L, C]
-        """
-
-        B, L, C = sample.shape
-
-        x = self.proj_in(sample)
-        time_encoding = self.time_proj(t).to(sample)
-        time_embed = self.time_embedding(time_encoding)
-
-        condition_text = condition_text.expand(B, -1, -1)
-
-        if condition_seq is None:
-            condition_seq = []
-        condition_seq = list(condition_seq)
-        for i in range(len(condition_seq)):
-            condition, condition_scale = condition_seq[i][:2]
-            additional_dict = {}
-            if len(condition_seq[i]) > 2:
-                additional_dict = condition_seq[i][2]
-                assert isinstance(additional_dict, dict)
-            condition = self.condition_net_list[i].preprocess_condition(condition, additional_dict, time_embed)
-            condition = condition.expand(B, -1, -1)
-
-            condition_seq[i] = (condition, condition_scale, additional_dict)
-
-        x_aug = torch.cat([x, time_embed[:, None]], dim=1)
-        y = self.layers(x_aug, (condition_text, condition_seq))
-        eps = self.proj_out(y[:, :-1])
-
-        return eps
openclay/models/vae.py
DELETED
@@ -1,124 +0,0 @@
-import tqdm
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from diffusers import ModelMixin, ConfigMixin
-from diffusers.configuration_utils import register_to_config
-# from torch_cluster import fps
-
-from ..modules.embedding import PointEmbed
-from ..modules.attention import CrossAttentionLayer
-from ..modules.drop_path import DropPathWrapper
-from ..modules.diag_gaussian import DiagonalGaussianDistribution
-
-class ClayVAE(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        depth=24,
-        dim=512,
-        latent_dim=64,
-        heads=8,
-        output_dim=1,
-        ratio=0.25,
-    ):
-        super().__init__()
-
-        self.ratio = ratio
-
-        self.point_embed = PointEmbed(dim=dim)
-
-        self.encoder_cross_attention = CrossAttentionLayer(dim, 1, dim_feedforward=dim * 4, dropout=0, activation=F.gelu, batch_first=True, norm_first=True, layer_norm_eps=1e-4)
-        self.encoder_out = nn.Linear(dim, latent_dim * 2)
-
-        self.decoder_in = nn.Linear(latent_dim, dim)
-        self.decoder_layers = nn.TransformerEncoder(
-            DropPathWrapper(
-                nn.TransformerEncoderLayer(dim,
-                                           heads,
-                                           dim_feedforward=dim * 4,
-                                           dropout=0,
-                                           activation=F.gelu,
-                                           batch_first=True,
-                                           norm_first=True,
-                                           layer_norm_eps=1e-4)
-            ),
-            depth)
-
-        self.decoder_cross_attention = CrossAttentionLayer(dim, 1, dim_feedforward=dim * 4, dropout=0, activation=F.gelu, batch_first=True, norm_first=True, layer_norm_eps=1e-4)
-        self.decoder_out = nn.Linear(dim, output_dim)
-
-    def encode(self, pc_data, output_dict=None, attn_mask=None, no_cast=False, **kwargs):
-
-        pc = pc_data[:, :, :3]
-
-        # pc: B x N x 3
-        B, N, D = pc.shape
-
-        pc_flat = pc.reshape(B * N, D)
-
-        batch = torch.arange(B).to(pc.device)
-        batch = torch.repeat_interleave(batch, N)
-
-        ratio = self.ratio
-        idx = fps(pc_flat, batch, ratio=ratio)
-
-        while idx.max() >= pc_flat.shape[0]:
-            idx = fps(pc_flat, batch, ratio=ratio)
-
-        sampled_pc = pc_flat[idx].reshape(B, -1, 3)
-
-        pc_embeddings = self.point_embed(pc)
-        sampled_pc_embeddings = self.point_embed(sampled_pc)
-
-        x, attn_output_weights = self.encoder_cross_attention(sampled_pc_embeddings, pc_embeddings, attn_mask, need_weights=output_dict is not None, no_cast=no_cast)
-        mean, logvar = self.encoder_out(x).chunk(2, dim=-1)
-
-        posterior = DiagonalGaussianDistribution(mean, logvar)
-        x = posterior.sample()
-        kl = posterior.kl()
-
-        if output_dict is not None:
-            output_dict["fps_idx"] = idx
-            output_dict["mean"] = mean
-            output_dict["logvar"] = logvar
-            output_dict["x"] = x
-            output_dict["attn_output_weights"] = attn_output_weights
-
-        return kl, x
-
-    def decode(self, x, pc, mini_batch=None, no_cast=False, show_progress=True, cpu=True, **kwargs):
-
-        x = self.decode_first(x)
-
-        if mini_batch is None:
-            y = self.decode_second(x, pc, no_cast)
-
-        else:
-            ys = []
-            for mini_batch_start in tqdm.tqdm(range(0, pc.shape[1], mini_batch), "[ ClayVAE.decode ]", disable=not (show_progress and pc.shape[1] > mini_batch)):
-                mini_pc = pc[:, mini_batch_start:mini_batch_start + mini_batch].to(x.device)
-                y = self.decode_second(x, mini_pc, no_cast)
-                if cpu:
-                    y = y.cpu()
-                ys.append(y)
-            y = torch.cat(ys, dim=1)
-
-        return y
-
-    def decode_first(self, x):
-        x = self.decoder_in(x)
-        x = self.decoder_layers(x)
-        return x
-
-    def decode_second(self, x, mini_pc, no_cast=False):
-        pc_embeddings = self.point_embed(mini_pc)
-        y, _ = self.decoder_cross_attention(pc_embeddings, x, no_cast=no_cast)
-        y = self.decoder_out(y)
-        return y
-
-    def forward(self, surface, points, no_cast=False, **kwargs):
-        kl, x = self.encode(surface, no_cast=no_cast, **kwargs)
-        x = self.decode(x, points, no_cast=no_cast, **kwargs)[:, :, 0]
-        return {"logits": x, "kl": kl}
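One thing worth flagging in the deleted vae.py: ClayVAE.encode calls fps(...) for farthest-point sampling, but the matching import is commented out ("# from torch_cluster import fps"), so the file as committed could only have run if fps was supplied from elsewhere. The intended usage, assuming torch_cluster, is roughly:

import torch
from torch_cluster import fps  # pip install torch-cluster

pc = torch.rand(2, 1024, 3)                         # B x N x 3 surface points
batch = torch.repeat_interleave(torch.arange(2), 1024)
idx = fps(pc.reshape(-1, 3), batch, ratio=0.25)     # indices of ~25% of points, farthest-first
sampled = pc.reshape(-1, 3)[idx].reshape(2, -1, 3)  # B x (N * ratio) x 3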
openclay/modules/attention.py
DELETED
@@ -1,73 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional, Union, Callable
-
-
-class CrossAttentionLayer(nn.Module):
-    __constants__ = ["batch_first", "norm_first", "context_norm"]
-
-    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
-                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
-                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, context_norm=True,
-                 device=None, dtype=None) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super(CrossAttentionLayer, self).__init__()
-        self.multihead_attn = nn.modules.activation.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
-        # Implementation of Feedforward model
-        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
-        self.dropout = nn.Dropout(dropout)
-        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
-
-        self.norm_first = norm_first
-        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-
-        self.dropout2 = nn.Dropout(dropout)
-        self.dropout3 = nn.Dropout(dropout)
-
-        self.context_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) if context_norm else None
-        # Legacy string support for activation function.
-        self.activation = activation
-
-    def forward(self, tgt: Tensor, memory: Tensor, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, need_weights=False, no_cast=False) -> Tensor:
-
-        assert self.norm_first
-
-        if no_cast:
-            tgt = tgt.float()
-            memory = memory.float()
-
-            with torch.autocast("cuda", enabled=False):
-                x = tgt
-                memory = self.context_norm(memory) if self.context_norm is not None else memory
-                y, attn_output_weights = self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, need_weights=need_weights)
-                x = x + y
-                x = x + self._ff_block(self.norm3(x))
-
-                return x, attn_output_weights
-
-        x = tgt
-        memory = self.context_norm(memory) if self.context_norm is not None else memory
-        y, attn_output_weights = self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, need_weights=need_weights)
-        x = x + y
-        x = x + self._ff_block(self.norm3(x))
-
-        return x, attn_output_weights
-
-    # multihead attention block
-    def _mha_block(self, x: Tensor, mem: Tensor,
-                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], need_weights) -> Tensor:
-        x, attn_output_weights = self.multihead_attn(x, mem, mem,
-                                                     attn_mask=attn_mask,
-                                                     key_padding_mask=key_padding_mask,
-                                                     need_weights=need_weights)
-        return self.dropout2(x), attn_output_weights
-
-    # feed forward block
-    def _ff_block(self, x: Tensor) -> Tensor:
-        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
-        return self.dropout3(x)
-
-
openclay/modules/control_volume.py
DELETED
@@ -1,52 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class ControlVolume(nn.Module):
-    def __init__(self, volume_dim=1, block_dim=384, condition_dim=768, time_embed_dim=2048, downsample_times=1, aggregation_method=None):
-        super().__init__()
-
-        if aggregation_method is None:
-            pass
-        elif aggregation_method == "maxpool":
-            self.proj_maxpool = nn.Linear(volume_dim, volume_dim)
-
-        self.conv_in = nn.Conv3d(volume_dim, block_dim, kernel_size=3, stride=1, padding=1)
-        self.proj_t = nn.Linear(time_embed_dim, block_dim)
-        self.norm = nn.GroupNorm(8, block_dim)
-
-        self.blocks = nn.ModuleList([])
-        for i in range(downsample_times):
-            self.blocks.append(nn.Conv3d(block_dim, block_dim, kernel_size=3, padding=1))
-            self.blocks.append(nn.Conv3d(block_dim, block_dim, kernel_size=3, padding=1, stride=2))
-
-        self.conv_out = nn.Conv3d(block_dim, condition_dim, kernel_size=3, stride=1, padding=1)
-
-        self.volume_dim = volume_dim
-        self.aggregation_method = aggregation_method
-
-    def forward(self, volume, time_embed):
-        """
-        volume: [B, volume_dim, X, Y, Z]
-        time_embed: [B, block_dim]
-        return:
-            [B, condition_dim, X, Y, Z]
-        """
-        B, _, X, Y, Z = volume.shape
-        if self.aggregation_method == "maxpool":
-            volume = self.proj_maxpool(volume.reshape(B, 4, self.volume_dim, X, Y, Z).permute(0, 1, 3, 4, 5, 2)).permute(0, 5, 2, 3, 4, 1)
-            volume = F.max_pool1d(volume.reshape(B, self.volume_dim * X * Y * Z, 4), 4).reshape(B, self.volume_dim, X, Y, Z)
-
-        x = F.silu(self.conv_in(volume))  # [B, block_dim, X, Y, Z]
-
-        time_embed = self.proj_t(time_embed)  # [B, block_dim]
-        x = x + time_embed[:, :, None, None, None]
-
-        x = self.norm(x)
-
-        for block in self.blocks:
-            x = block(x)
-            x = F.silu(x)
-
-        x = self.conv_out(x)  # [B, condition_dim, X, Y, Z]
-        return x
openclay/modules/diag_gaussian.py
DELETED
@@ -1,42 +0,0 @@
import numpy as np
import torch

class DiagonalGaussianDistribution(object):
    def __init__(self, mean, logvar, deterministic=False):
        self.mean = mean
        self.logvar = logvar
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.mean.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(self.mean)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.mean(torch.pow(self.mean, 2)
                                        + self.var - 1.0 - self.logvar,
                                        dim=[1, 2])
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
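This is the standard reparameterized diagonal Gaussian used for VAE latents. A minimal sketch of its use, illustrative only:

import torch

mean = torch.zeros(2, 16, 8)                  # [B, N, C] latent statistics
logvar = torch.zeros(2, 16, 8)
dist = DiagonalGaussianDistribution(mean, logvar)
z = dist.sample()                             # mean + std * eps, same shape as mean
kl = dist.kl()                                # [B] KL against a standard normal, reduced over dims 1 and 2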
openclay/modules/drop_path.py
DELETED
@@ -1,34 +0,0 @@
import torch
import torch.nn as nn
from torch import Tensor
from timm.models.layers import DropPath
from typing import Optional

class Struct:
    pass

class DropPathWrapper(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.drop_path = DropPath(drop_prob=0.1)

        self_attn_dummy = Struct()
        self_attn_dummy.batch_first = True
        self.self_attn = self_attn_dummy

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        is_causal: bool = False
    ) -> Tensor:

        x = src
        x_p = self.layer(src, src_mask, src_key_padding_mask, is_causal)
        p = x_p - x

        y = x + self.drop_path(p)

        return y
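DropPathWrapper recovers the residual branch by subtracting the input from the wrapped layer's output, then re-adds it through stochastic depth; the dummy self_attn struct presumably exists only so nn.TransformerEncoder's batch_first introspection does not fail. A minimal sketch, assuming a stock encoder layer:

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True, norm_first=True)
wrapped = DropPathWrapper(layer)              # residual delta dropped with p=0.1 in training mode
x = torch.randn(2, 10, 64)                    # [B, L, C]
y = wrapped(x)                                # equals layer(x) exactly in eval mode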
openclay/modules/embedding.py
DELETED
@@ -1,36 +0,0 @@
# TODO: add reference to 3dshape2vecset

import numpy as np
import torch
import torch.nn as nn

class PointEmbed(nn.Module):
    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer("basis", e, persistent=False)  # 3 x 24

        self.mlp = nn.Linear(self.embedding_dim + 3, dim)

    @staticmethod
    def embed(input, basis):
        projections = torch.einsum("bnd,de->bne", input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings

    def forward(self, input):
        # input: B x N x 3
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2))  # B x N x C
        return embed
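PointEmbed is a Fourier-feature positional embedding for 3D points (per the TODO, in the style of 3DShape2VecSet): each axis is projected onto a bank of power-of-two frequencies, and the sin/cos features are concatenated with the raw coordinates before a linear projection. A minimal sketch:

import torch

embed = PointEmbed(hidden_dim=48, dim=128)
pts = torch.rand(2, 1024, 3) * 2 - 1          # B x N x 3 points in [-1, 1]^3
feat = embed(pts)                             # B x N x 128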
openclay/modules/transformer.py
DELETED
@@ -1,116 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.transformer import _get_activation_fn
from typing import Optional, Union, Callable, List

class ClayTransformerDecoderLayer(nn.TransformerDecoderLayer):
    __constants__ = ["batch_first", "norm_first"]

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None, kdim=None, vdim=None) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        nn.Module.__init__(self)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, kdim=kdim, vdim=vdim, **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = lambda x: x
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = lambda x: x
        self.dropout2 = lambda x: x
        self.dropout3 = lambda x: x

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

        self.d_model = d_model
        self.nhead = nhead
        self.dropout_ = dropout
        self.batch_first = batch_first
        self.kdim = kdim
        self.vdim = vdim
        self.factory_kwargs = factory_kwargs
        self._mha_block_second = None
        self.condition_net_list = []

    def register_condition_net(self, condition_net_list):
        self.condition_net_list = condition_net_list

    def process_condition_net(self, x_norm, condition_seq, stage):
        """
        condition_seq: [(condition_1, condition_1_scale), ...]
        """

        assert len(condition_seq) == 0 or len(condition_seq) == len(self.condition_net_list), f"len(self.condition_net_list)={len(self.condition_net_list)}, len(condition_seq)={len(condition_seq)}"

        residual_all = 0

        for i in range(len(condition_seq)):
            condition_net = self.condition_net_list[i]
            condition, condition_scale = condition_seq[i][:2]

            if condition_scale == 0:
                continue

            key_padding_mask = None
            if len(condition_seq[i]) > 2:
                additional_dict = condition_seq[i][2]
                assert isinstance(additional_dict, dict)
                key_padding_mask = additional_dict.get("key_padding_mask", None)

            if stage == condition_net.stage:
                residual = condition_net.process(self.index, x_norm, condition, key_padding_mask=key_padding_mask)
                residual_all = residual_all + residual * condition_scale

        return residual_all

    def forward(
        self,
        tgt: Tensor,
        memory_list: List[Tensor],
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        """
        memory_list = ( condition_text, [(condition_1, condition_1_scale), ...] )
        """

        x = tgt
        assert self.norm_first

        memory, condition_seq = memory_list

        x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
        x_norm = self.norm2(x)
        x = x + self._mha_block(x_norm, memory, memory_mask, memory_key_padding_mask, memory_is_causal) + self.process_condition_net(x_norm, condition_seq, stage="parallel")
        x = x + self._ff_block(self.norm3(x))

        x = x + self.process_condition_net(x_norm, condition_seq, stage="postfix")
        return x

    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False, multihead_attn=None) -> Tensor:
        if multihead_attn is None:
            multihead_attn = self.multihead_attn
        x = multihead_attn(x, mem, mem,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           is_causal=is_causal,
                           need_weights=False)[0]
        return self.dropout2(x)
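The layer keeps PyTorch's pre-norm decoder structure but replaces every dropout with an identity lambda and threads extra condition residuals through registered condition nets: "parallel" residuals are added alongside cross-attention, "postfix" residuals after the feed-forward block. With an empty condition_seq it reduces to a plain pre-norm decoder layer, which this illustrative sketch exercises:

import torch

layer = ClayTransformerDecoderLayer(d_model=64, nhead=4, batch_first=True, norm_first=True)
tgt = torch.randn(2, 32, 64)                  # latent tokens [B, L, C]
memory = torch.randn(2, 77, 64)               # e.g. text-encoder hidden states
out = layer(tgt, (memory, []))                # empty condition_seq -> no conditioning residuals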
openclay/pipeline_openclay.py
DELETED
@@ -1,195 +0,0 @@
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torch.utils.data
import tqdm
import mcubes
import inspect
import trimesh
import gc
from diffusers import UniPCMultistepScheduler

from .utils import get_grid_tensor
from .models import ClayVAE, ClayLDM, ClayConditionNet

from transformers import Dinov2Model, BitImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline

class OpenClayPipeline(DiffusionPipeline):
    def __init__(self,
                 vae: ClayVAE,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 ldm: ClayLDM,
                 scheduler=None,
                 ):
        super().__init__()


        if scheduler is None:
            scheduler = self.get_unipc_scheduler()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            ldm=ldm,
            scheduler=scheduler,
        )

    def get_unipc_scheduler(self):
        scheduler = UniPCMultistepScheduler(
            num_train_timesteps=1000,
            beta_schedule="squaredcos_cap_v2",
            prediction_type="v_prediction",
            timestep_spacing="linspace",
            rescale_betas_zero_snr=True
        )
        return scheduler

    def get_timesteps(self, num_inference_steps, strength):
        # get the original timestep using init_timestep
        init_timestep = min(int(round(num_inference_steps * strength)), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]

        return timesteps, num_inference_steps - t_start


    def vae_decode_latent(self, latent, res=128, mini_batch=129 * 129 * 129, show_progress=True):
        assert torch.isfinite(latent).all()

        gap = 2 / res
        grid = get_grid_tensor(res).to(latent)
        logits = self.vae.decode(latent, grid, mini_batch=mini_batch, show_progress=show_progress)[:, :, 0].cpu()

        logits = logits.view(res + 1, res + 1, res + 1)

        if isinstance(logits, torch.Tensor):
            logits = logits.cpu().numpy().astype(np.float32)
        assert isinstance(logits, np.ndarray)
        verts, faces = mcubes.marching_cubes(logits, 0)
        verts *= gap
        verts -= 1

        m = trimesh.Trimesh(verts, faces)
        return m

    def __call__(self, prompt="", negative_prompt="", sample=None, strength=None,
                 res=128, rescale_phi=0.7, cfg=7.5, resacle_cfg=True,
                 num_latents=1024, num_inference_steps=100, timesteps=None,
                 mini_batch=129**3, seed=42, num=1,
                 condition_seq=None, show_progress=True,
                 ):

        """
        condition_seq: [(condition_1, condition_1_scale), ...]
        sample_lock_index: [m]
        """

        device = self.text_encoder.device
        generator = torch.Generator(device).manual_seed(seed)

        prompt: list = [prompt] if isinstance(prompt, str) else prompt
        negative_prompt: list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

        assert len(prompt) == len(negative_prompt) == num or len(prompt) == 1 or len(negative_prompt) == 1
        if len(prompt) == 1:
            prompt = prompt * num
        if len(negative_prompt) == 1:
            negative_prompt = negative_prompt * num

        token = self.tokenizer(prompt + negative_prompt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
        encoder_hidden_states = self.text_encoder(token.to(device))[0].to(self.dtype)
        encoder_hidden_states = encoder_hidden_states.reshape(2, num, encoder_hidden_states.shape[1], encoder_hidden_states.shape[2]).permute(1, 0, 2, 3).reshape(num * 2, encoder_hidden_states.shape[1], encoder_hidden_states.shape[2])
        assert encoder_hidden_states.shape[0] == num * 2

        # set step values
        extra_set_kwargs = {}
        if "timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()):
            extra_set_kwargs["timesteps"] = timesteps
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        timesteps = self.scheduler.timesteps

        noise = torch.randn(num, num_latents, 64, dtype=self.dtype, device=device, generator=generator)
        if sample is not None:
            assert torch.isfinite(sample).all()
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
            t_start = timesteps[0]
            sample_noisy = self.scheduler.add_noise(sample, noise, torch.tensor(t_start).to(sample.device))
            sample = sample_noisy
        else:
            sample = noise

        condition_dict_cond = {
            "condition_seq": [
                (
                    condition_scale_dict[0], condition_scale_dict[1][0],
                    ((condition_scale_dict[2][0] if not isinstance(condition_scale_dict[2], dict) else condition_scale_dict[2]) if len(condition_scale_dict) > 2 else {})
                )
                for condition_scale_dict in condition_seq
            ],
        }
        condition_dict_uncond = {
            "condition_seq": [
                (
                    condition_scale_dict[0], condition_scale_dict[1][1],
                    ((condition_scale_dict[2][1] if not isinstance(condition_scale_dict[2], dict) else condition_scale_dict[2]) if len(condition_scale_dict) > 2 else {})
                )
                for condition_scale_dict in condition_seq
            ],
        }

        extra_step_kwargs = {}
        if "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()):
            extra_step_kwargs["generator"] = generator

        for t in tqdm.tqdm(timesteps, "[ ClayLDMPipeline.__call__ ]", disable=not show_progress):
            # 1. predict noise model_output
            if isinstance(t, torch.Tensor):
                t = t.item()
            t_tensor = torch.tensor([t], dtype=torch.long, device=device)

            model_output_cond = self.ldm(
                sample,
                t_tensor.expand(num),
                encoder_hidden_states[0::2],
                **condition_dict_cond,
            )
            model_output_uncond = self.ldm(
                sample,
                t_tensor.expand(num),
                encoder_hidden_states[1::2],
                **condition_dict_uncond,
            )

            model_output_cfg = (model_output_cond - model_output_uncond) * cfg + model_output_uncond
            if resacle_cfg:
                model_output_rescaled = model_output_cfg / model_output_cfg.std(dim=(1, 2), keepdim=True) * model_output_cond.std(dim=(1, 2), keepdim=True)
                model_output = rescale_phi * model_output_rescaled + (1 - rescale_phi) * model_output_cfg
            else:
                model_output = model_output_cfg

            # 2. compute previous image: x_t -> x_t-1
            sample = self.scheduler.step(
                model_output[:, None, :, :].permute(0, 3, 1, 2),
                t,
                sample[:, None, :, :].permute(0, 3, 1, 2),
                **extra_step_kwargs
            ).prev_sample.permute(0, 2, 3, 1)[:, 0, :, :]

            assert torch.isfinite(sample).all(), sample

        gc.collect()
        torch.cuda.empty_cache()

        mesh_list = []
        for i in tqdm.tqdm(range(sample.shape[0])):
            mesh = self.vae_decode_latent(sample[i:i + 1], res=res, mini_batch=mini_batch)
            mesh.vertices[:, 0] += i % 4 * 2
            mesh.vertices[:, 2] += i // 4 * 4
            mesh_list.append(mesh)

        mesh_combined = trimesh.util.concatenate(mesh_list)
        return mesh_combined
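End to end, __call__ encodes positive and negative prompts, runs classifier-free guidance with an optional std-matching rescale, then decodes each latent to a mesh via marching cubes. A sketch of driving the removed pipeline; vae, text_encoder, tokenizer, and ldm are assumed to be constructed and loaded elsewhere, and condition_seq must be a list (the empty list disables extra conditioning, whereas the None default would break the list comprehensions above):

# Assumes vae, text_encoder, tokenizer, ldm are already built and loaded.
pipe = OpenClayPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, ldm=ldm)
mesh = pipe(prompt="a wooden chair", negative_prompt="low quality",
            num_inference_steps=50, cfg=7.5, res=128, seed=42,
            condition_seq=[])                 # pass [] rather than relying on the None default
mesh.export("chair.glb")                      # result is a trimesh.Trimesh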
openclay/utils.py
DELETED
@@ -1,80 +0,0 @@
import numpy as np
import torch
import cv2

def get_grid_tensor(res=128):

    gap = 2. / res
    x = torch.linspace(-1, 1, res + 1)
    y = torch.linspace(-1, 1, res + 1)
    z = torch.linspace(-1, 1, res + 1)
    grid = torch.stack(torch.meshgrid(x, y, z)).view(3, -1).T[None]
    return grid


def pad_to_square(image, pad_color=None):
    H, W, C = image.shape
    max_side = max(H, W)
    padded_image = np.ones((max_side, max_side, C))
    if pad_color is None:
        pad_color = image[0, 0]
    padded_image[:] = pad_color
    vertical_offset = (max_side - H) // 2
    horizontal_offset = (max_side - W) // 2
    padded_image[vertical_offset:vertical_offset + H, horizontal_offset:horizontal_offset + W, :] = image
    return padded_image


def read_image_square(path):
    image = cv2.imread(path, -1)
    return process_image_square(image)

def process_image_square(image):
    image = image.astype(np.float32) / 255

    if image.shape[2] == 4:  # background
        fg = image[:, :, 3] > 0.5
        fg_coord = np.stack(np.where(fg))
        rc_min = fg_coord.min(axis=1)
        rc_max = fg_coord.max(axis=1)
        rc_range = rc_max - rc_min
        rc_min -= (rc_range * 0.1).astype(int)
        rc_max += (rc_range * 0.1).astype(int)
        rc_min = rc_min.clip(0, None)
        rc_max = rc_max.clip(0, None)
        image = image[rc_min[0]:rc_max[0], rc_min[1]:rc_max[1]]
        image = image[:, :, :3] * image[:, :, 3:] + 1 * (1 - image[:, :, 3:])
    render = image[:, :, ::-1]
    render = pad_to_square(render)

    return render

def get_center_position(num_voxel):
    """
    num_voxel: int
    return:
        [X, Y, Z, 3]
    """
    center_position = (torch.stack(torch.meshgrid([torch.arange(num_voxel, dtype=torch.float32)] * 3, indexing="ij"), dim=3) + 0.5) / num_voxel * 2 - 1
    return center_position


def geometry_get_voxel(geometry):
    import pysdf

    res_large = 128
    voxel_large_or = np.zeros((res_large, res_large, res_large), dtype=bool)
    center_position_large = get_center_position(res_large).numpy()

    voxel_large = pysdf.SDF(geometry.vertices, geometry.faces).contains(center_position_large.reshape(-1, 3)).reshape(res_large, res_large, res_large)

    voxel_large_or |= voxel_large

    res = 16
    voxel_or = np.zeros((res, res, res), dtype=bool)

    loc = np.mgrid[:res, :res, :res].transpose(1, 2, 3, 0).reshape(-1, 3)
    for l in loc:
        voxel_or[l[0], l[1], l[2]] = voxel_large_or[l[0] * 8:l[0] * 8 + 8, l[1] * 8:l[1] * 8 + 8, l[2] * 8:l[2] * 8 + 8].sum() > 0

    return voxel_or
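Two of these helpers are enough to show the coordinate conventions: both query grids and voxel centers live in [-1, 1]^3. A quick illustrative sketch:

import torch

grid = get_grid_tensor(res=64)                # [1, 65**3, 3] corner-aligned query points in [-1, 1]^3
centers = get_center_position(16)             # [16, 16, 16, 3] voxel-center coordinates in [-1, 1]^3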
requirements.txt
CHANGED
@@ -3,16 +3,4 @@ requests
 pillow
 gradio==4.31.2
 requests-toolbelt
-websocket-client
-
-numpy==1.24.1
-
-PyMCubes
-pysdf==0.1.9
-opencv-python==4.9.0.80
-tqdm==4.66.1
-trimesh==4.0.5
-timm==0.9.12
-diffusers==0.29.0
-transformers
-accelerate
+websocket-client