wyysf commited on
Commit
302c8a6
1 Parent(s): 6ce0931

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +333 -338
gradio_app.py CHANGED
@@ -1,339 +1,334 @@
1
- import spaces
2
- import argparse
3
- import os
4
- import json
5
- import torch
6
- import sys
7
- import time
8
- import importlib
9
- import numpy as np
10
- from omegaconf import OmegaConf
11
- from huggingface_hub import hf_hub_download
12
-
13
- from collections import OrderedDict
14
- import trimesh
15
- import gradio as gr
16
- from typing import Any
17
-
18
- proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19
- sys.path.append(os.path.join(proj_dir))
20
-
21
- import tempfile
22
-
23
- from apps.utils import *
24
-
25
- _TITLE = '''CraftsMan: High-fidelity Mesh Generation with 3D Native Generation and Interactive Geometry Refiner'''
26
- _DESCRIPTION = '''
27
- <div>
28
- Select or upload a image, then just click 'Generate'.
29
- <br>
30
- By mimicking the artist/craftsman modeling workflow, we propose CraftsMan (aka 匠心) that uses 3D Latent Set Diffusion Model that directly generate coarse meshes,
31
- then a multi-view normal enhanced image generation model is used to refine the mesh.
32
- We provide the coarse 3D diffusion part here.
33
- <br>
34
- If you found CraftsMan is helpful, please help to ⭐ the <a href='https://github.com/wyysf-98/CraftsMan/' target='_blank'>Github Repo</a>. Thanks!
35
- <a style="display:inline-block; margin-left: .5em" href='https://github.com/wyysf-98/CraftsMan/'><img src='https://img.shields.io/github/stars/wyysf-98/CraftsMan?style=social' /></a>
36
- <br>
37
- *If you have your own multi-view images, you can directly upload it.
38
- </div>
39
- '''
40
- _CITE_ = r"""
41
- ---
42
- 📝 **Citation**
43
- If you find our work useful for your research or applications, please cite using this bibtex:
44
- ```bibtex
45
- @article{craftsman,
46
- author = {Weiyu Li and Jiarui Liu and Rui Chen and Yixun Liang and Xuelin Chen and Ping Tan and Xiaoxiao Long},
47
- title = {CraftsMan: High-fidelity Mesh Generation with 3D Native Generation and Interactive Geometry Refiner},
48
- journal = {arxiv:xxx},
49
- year = {2024},
50
- }
51
- ```
52
- 🤗 **Acknowledgements**
53
- We use <a href='https://github.com/wjakob/instant-meshes' target='_blank'>Instant Meshes</a> to remesh the generated mesh to a lower face count, thanks to the authors for the great work.
54
- 📋 **License**
55
- CraftsMan is under [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html), so any downstream solution and products (including cloud services) that include CraftsMan code or a trained model (both pretrained or custom trained) inside it should be open-sourced to comply with the AGPL conditions. If you have any questions about the usage of CraftsMan, please contact us first.
56
- 📧 **Contact**
57
- If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
58
- """
59
- from apps.third_party.CRM.pipelines import TwoStagePipeline
60
- from apps.third_party.LGM.pipeline_mvdream import MVDreamPipeline
61
-
62
-
63
- model = None
64
- cached_dir = None
65
- stage1_config = OmegaConf.load(f"{parent_dir}/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
66
- stage1_sampler_config = stage1_config.sampler
67
- stage1_model_config = stage1_config.models
68
- stage1_model_config.resume = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth", repo_type="model")
69
- stage1_model_config.config = f"{parent_dir}/apps/third_party/CRM/" + stage1_model_config.config
70
- crm_pipeline = None
71
-
72
- sys.path.append(f"apps/third_party/LGM")
73
- imgaedream_pipeline = None
74
-
75
- generator = None
76
-
77
- @spaces.GPU
78
- def gen_mvimg(
79
- mvimg_model, image, seed, guidance_scale, step, text, neg_text, elevation,
80
- ):
81
- if seed == 0:
82
- seed = np.random.randint(1, 65535)
83
-
84
- if mvimg_model == "CRM":
85
- global crm_pipeline
86
- crm_pipeline.set_seed(seed)
87
- mv_imgs = crm_pipeline(
88
- image,
89
- scale=guidance_scale,
90
- step=step
91
- )["stage1_images"]
92
- return mv_imgs[5], mv_imgs[3], mv_imgs[2], mv_imgs[0]
93
-
94
- elif mvimg_model == "ImageDream":
95
- global imagedream_pipeline, generator
96
- image = np.array(image).astype(np.float32) / 255.0
97
- image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
98
- mv_imgs = imagedream_pipeline(
99
- text,
100
- image,
101
- negative_prompt=neg_text,
102
- guidance_scale=guidance_scale,
103
- num_inference_steps=step,
104
- elevation=elevation,
105
- generator=generator.manual_seed(seed),
106
- )
107
- return mv_imgs[1], mv_imgs[2], mv_imgs[3], mv_imgs[0]
108
-
109
-
110
- @spaces.GPU
111
- def image2mesh(view_front: np.ndarray,
112
- view_right: np.ndarray,
113
- view_back: np.ndarray,
114
- view_left: np.ndarray,
115
- more: bool = False,
116
- scheluder_name: str ="DDIMScheduler",
117
- guidance_scale: int = 7.5,
118
- seed: int = 4,
119
- octree_depth: int = 7):
120
-
121
- sample_inputs = {
122
- "mvimages": [[
123
- Image.fromarray(view_front),
124
- Image.fromarray(view_right),
125
- Image.fromarray(view_back),
126
- Image.fromarray(view_left)
127
- ]]
128
- }
129
-
130
- global model
131
- latents = model.sample(
132
- sample_inputs,
133
- sample_times=1,
134
- guidance_scale=guidance_scale,
135
- return_intermediates=False,
136
- seed=seed
137
-
138
- )[0]
139
-
140
- # decode the latents to mesh
141
- box_v = 1.1
142
- mesh_outputs, _ = model.shape_model.extract_geometry(
143
- latents,
144
- bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
145
- octree_depth=octree_depth
146
- )
147
- assert len(mesh_outputs) == 1, "Only support single mesh output for gradio demo"
148
- mesh = trimesh.Trimesh(mesh_outputs[0][0], mesh_outputs[0][1])
149
- # filepath = f"{cached_dir}/{time.time()}.obj"
150
- filepath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
151
- mesh.export(filepath, include_normals=True)
152
-
153
- if 'Remesh' in more:
154
- remeshed_filepath = tempfile.NamedTemporaryFile(suffix=f"_remeshed.obj", delete=False).name
155
- print("Remeshing with Instant Meshes...")
156
- # target_face_count = int(len(mesh.faces)/10)
157
- target_face_count = 2000
158
- command = f"{proj_dir}/apps/third_party/InstantMeshes {filepath} -f {target_face_count} -o {remeshed_filepath}"
159
- os.system(command)
160
- filepath = remeshed_filepath
161
- # filepath = filepath.replace('.obj', '_remeshed.obj')
162
-
163
- return filepath
164
-
165
- if __name__=="__main__":
166
- parser = argparse.ArgumentParser()
167
- # parser.add_argument("--model_path", type=str, required=True, help="Path to the object file",)
168
- parser.add_argument("--cached_dir", type=str, default="./gradio_cached_dir")
169
- parser.add_argument("--device", type=int, default=0)
170
- args = parser.parse_args()
171
-
172
- cached_dir = args.cached_dir
173
- os.makedirs(args.cached_dir, exist_ok=True)
174
- device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
175
- print(f"using device: {device}")
176
-
177
- # for multi-view images generation
178
- background_choice = OrderedDict({
179
- "Alpha as Mask": "Alpha as Mask",
180
- "Auto Remove Background": "Auto Remove Background",
181
- "Original Image": "Original Image",
182
- })
183
- mvimg_model_config_list = ["CRM", "ImageDream"]
184
- crm_pipeline = TwoStagePipeline(
185
- stage1_model_config,
186
- stage1_sampler_config,
187
- device=device,
188
- dtype=torch.float16
189
- )
190
- imagedream_pipeline = MVDreamPipeline.from_pretrained(
191
- "ashawkey/imagedream-ipmv-diffusers", # remote weights
192
- torch_dtype=torch.float16,
193
- trust_remote_code=True,
194
- )
195
- generator = torch.Generator(device)
196
-
197
-
198
- # for 3D latent set diffusion
199
- ckpt_path = "./ckpts/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model.ckpt"
200
- config_path = "./ckpts/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml"
201
- # ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model.ckpt", repo_type="model")
202
- # config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml", repo_type="model")
203
- scheluder_dict = OrderedDict({
204
- "DDIMScheduler": 'diffusers.schedulers.DDIMScheduler',
205
- # "DPMSolverMultistepScheduler": 'diffusers.schedulers.DPMSolverMultistepScheduler', # not support yet
206
- # "UniPCMultistepScheduler": 'diffusers.schedulers.UniPCMultistepScheduler', # not support yet
207
- })
208
-
209
- # main GUI
210
- custom_theme = gr.themes.Soft(primary_hue="blue").set(
211
- button_secondary_background_fill="*neutral_100",
212
- button_secondary_background_fill_hover="*neutral_200")
213
- custom_css = '''#disp_image {
214
- text-align: center; /* Horizontally center the content */
215
- }'''
216
-
217
- with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
218
- with gr.Row():
219
- with gr.Column(scale=1):
220
- gr.Markdown('# ' + _TITLE)
221
- gr.Markdown(_DESCRIPTION)
222
-
223
- with gr.Row():
224
- with gr.Column(scale=2):
225
- with gr.Column():
226
- # input image
227
- with gr.Row():
228
- image_input = gr.Image(
229
- label="Image Input",
230
- image_mode="RGBA",
231
- sources="upload",
232
- type="pil",
233
- )
234
- run_btn = gr.Button('Generate', variant='primary', interactive=True)
235
-
236
- with gr.Row():
237
- gr.Markdown('''Try a different <b>seed and MV Model</b> for better results. Good Luck :)''')
238
- with gr.Row():
239
- seed = gr.Number(0, label='Seed', show_label=True)
240
- mvimg_model = gr.Dropdown(value="CRM", label="MV Image Model", choices=list(mvimg_model_config_list))
241
- more = gr.CheckboxGroup(["Remesh", "Symmetry(TBD)"], label="More", show_label=False)
242
- with gr.Row():
243
- # input prompt
244
- text = gr.Textbox(label="Prompt (Opt.)", info="only works for ImageDream")
245
-
246
- with gr.Accordion('Advanced options', open=False):
247
- # negative prompt
248
- neg_text = gr.Textbox(label="Negative Prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
249
- # elevation
250
- elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
251
-
252
- with gr.Row():
253
- gr.Examples(
254
- examples=[os.path.join("./apps/examples", i) for i in os.listdir("./apps/examples")],
255
- inputs=[image_input],
256
- examples_per_page=8
257
- )
258
-
259
- with gr.Column(scale=4):
260
- with gr.Row():
261
- output_model_obj = gr.Model3D(
262
- label="Output Model (OBJ Format)",
263
- camera_position=(90.0, 90.0, 3.5),
264
- interactive=False,
265
- )
266
- with gr.Row():
267
- gr.Markdown('''*please note that the model is fliped due to the gradio viewer, please download the obj file and you will get the correct orientation.''')
268
-
269
- with gr.Row():
270
- view_front = gr.Image(label="Front", interactive=True, show_label=True)
271
- view_right = gr.Image(label="Right", interactive=True, show_label=True)
272
- view_back = gr.Image(label="Back", interactive=True, show_label=True)
273
- view_left = gr.Image(label="Left", interactive=True, show_label=True)
274
-
275
- with gr.Accordion('Advanced options', open=False):
276
- with gr.Row(equal_height=True):
277
- run_mv_btn = gr.Button('Only Generate 2D', interactive=True)
278
- run_3d_btn = gr.Button('Only Generate 3D', interactive=True)
279
-
280
- with gr.Accordion('Advanced options (2D)', open=False):
281
- with gr.Row():
282
- foreground_ratio = gr.Slider(
283
- label="Foreground Ratio",
284
- minimum=0.5,
285
- maximum=1.0,
286
- value=1.0,
287
- step=0.05,
288
- )
289
-
290
- with gr.Row():
291
- background_choice = gr.Dropdown(label="Backgroud Choice", value="Auto Remove Background",choices=list(background_choice.keys()))
292
- rmbg_type = gr.Dropdown(label="Backgroud Remove Type", value="rembg",choices=['sam', "rembg"])
293
- backgroud_color = gr.ColorPicker(label="Background Color", value="#FFFFFF", interactive=True)
294
- # backgroud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=True)
295
-
296
- with gr.Row():
297
- mvimg_guidance_scale = gr.Number(value=4.0, minimum=3, maximum=10, label="2D Guidance Scale")
298
- mvimg_steps = gr.Number(value=30, minimum=20, maximum=100, label="2D Sample Steps")
299
-
300
- with gr.Accordion('Advanced options (3D)', open=False):
301
- with gr.Row():
302
- guidance_scale = gr.Number(label="3D Guidance Scale", value=7.5, minimum=3.0, maximum=10.0)
303
- steps = gr.Number(value=50, minimum=20, maximum=100, label="3D Sample Steps")
304
-
305
- with gr.Row():
306
- scheduler = gr.Dropdown(label="scheluder", value="DDIMScheduler",choices=list(scheluder_dict.keys()))
307
- octree_depth = gr.Slider(label="Octree Depth", value=7, minimum=4, maximum=8, step=1)
308
-
309
- gr.Markdown(_CITE_)
310
-
311
- outputs = [output_model_obj]
312
- rmbg = RMBG(device)
313
-
314
- model = load_model(ckpt_path, config_path, device)
315
-
316
- run_btn.click(fn=check_input_image, inputs=[image_input]
317
- ).success(
318
- fn=rmbg.run,
319
- inputs=[rmbg_type, image_input, foreground_ratio, background_choice, backgroud_color],
320
- outputs=[image_input]
321
- ).success(
322
- fn=gen_mvimg,
323
- inputs=[mvimg_model, image_input, seed, mvimg_guidance_scale, mvimg_steps, text, neg_text, elevation],
324
- outputs=[view_front, view_right, view_back, view_left]
325
- ).success(
326
- fn=image2mesh,
327
- inputs=[view_front, view_right, view_back, view_left, more, scheduler, guidance_scale, seed, octree_depth],
328
- outputs=outputs,
329
- api_name="generate_img2obj")
330
- run_mv_btn.click(fn=gen_mvimg,
331
- inputs=[mvimg_model, image_input, seed, mvimg_guidance_scale, mvimg_steps, text, neg_text, elevation],
332
- outputs=[view_front, view_right, view_back, view_left]
333
- )
334
- run_3d_btn.click(fn=image2mesh,
335
- inputs=[view_front, view_right, view_back, view_left, more, scheduler, guidance_scale, seed, octree_depth],
336
- outputs=outputs,
337
- api_name="generate_img2obj")
338
-
339
  demo.queue().launch(share=True, allowed_paths=[args.cached_dir])
 
1
+ import spaces
2
+ import argparse
3
+ import os
4
+ import json
5
+ import torch
6
+ import sys
7
+ import time
8
+ import importlib
9
+ import numpy as np
10
+ from omegaconf import OmegaConf
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from collections import OrderedDict
14
+ import trimesh
15
+ import gradio as gr
16
+ from typing import Any
17
+
18
+ proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19
+ sys.path.append(os.path.join(proj_dir))
20
+
21
+ import tempfile
22
+
23
+ from apps.utils import *
24
+
25
+ _TITLE = '''CraftsMan: High-fidelity Mesh Generation with 3D Native Generation and Interactive Geometry Refiner'''
26
+ _DESCRIPTION = '''
27
+ <div>
28
+ Select or upload a image, then just click 'Generate'.
29
+ <br>
30
+ By mimicking the artist/craftsman modeling workflow, we propose CraftsMan (aka 匠心) that uses 3D Latent Set Diffusion Model that directly generate coarse meshes,
31
+ then a multi-view normal enhanced image generation model is used to refine the mesh.
32
+ We provide the coarse 3D diffusion part here.
33
+ <br>
34
+ If you found CraftsMan is helpful, please help to ⭐ the <a href='https://github.com/wyysf-98/CraftsMan/' target='_blank'>Github Repo</a>. Thanks!
35
+ <a style="display:inline-block; margin-left: .5em" href='https://github.com/wyysf-98/CraftsMan/'><img src='https://img.shields.io/github/stars/wyysf-98/CraftsMan?style=social' /></a>
36
+ <br>
37
+ *If you have your own multi-view images, you can directly upload it.
38
+ </div>
39
+ '''
40
+ _CITE_ = r"""
41
+ ---
42
+ 📝 **Citation**
43
+ If you find our work useful for your research or applications, please cite using this bibtex:
44
+ ```bibtex
45
+ @article{craftsman,
46
+ author = {Weiyu Li and Jiarui Liu and Rui Chen and Yixun Liang and Xuelin Chen and Ping Tan and Xiaoxiao Long},
47
+ title = {CraftsMan: High-fidelity Mesh Generation with 3D Native Generation and Interactive Geometry Refiner},
48
+ journal = {arxiv:xxx},
49
+ year = {2024},
50
+ }
51
+ ```
52
+ 🤗 **Acknowledgements**
53
+ We use <a href='https://github.com/wjakob/instant-meshes' target='_blank'>Instant Meshes</a> to remesh the generated mesh to a lower face count, thanks to the authors for the great work.
54
+ 📋 **License**
55
+ CraftsMan is under [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html), so any downstream solution and products (including cloud services) that include CraftsMan code or a trained model (both pretrained or custom trained) inside it should be open-sourced to comply with the AGPL conditions. If you have any questions about the usage of CraftsMan, please contact us first.
56
+ 📧 **Contact**
57
+ If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
58
+ """
59
+ from apps.third_party.CRM.pipelines import TwoStagePipeline
60
+ from apps.third_party.LGM.pipeline_mvdream import MVDreamPipeline
61
+
62
+
63
+ model = None
64
+ cached_dir = None
65
+ stage1_config = OmegaConf.load(f"{parent_dir}/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
66
+ stage1_sampler_config = stage1_config.sampler
67
+ stage1_model_config = stage1_config.models
68
+ stage1_model_config.resume = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth", repo_type="model")
69
+ stage1_model_config.config = f"{parent_dir}/apps/third_party/CRM/" + stage1_model_config.config
70
+ crm_pipeline = None
71
+
72
+ sys.path.append(f"apps/third_party/LGM")
73
+ imgaedream_pipeline = None
74
+
75
+ @spaces.GPU
76
+ def gen_mvimg(
77
+ mvimg_model, image, seed, guidance_scale, step, text, neg_text, elevation,
78
+ ):
79
+ if seed == 0:
80
+ seed = np.random.randint(1, 65535)
81
+
82
+ if mvimg_model == "CRM":
83
+ global crm_pipeline
84
+ crm_pipeline.set_seed(seed)
85
+ mv_imgs = crm_pipeline(
86
+ image,
87
+ scale=guidance_scale,
88
+ step=step
89
+ )["stage1_images"]
90
+ return mv_imgs[5], mv_imgs[3], mv_imgs[2], mv_imgs[0]
91
+
92
+ elif mvimg_model == "ImageDream":
93
+ global imagedream_pipeline, generator
94
+ image = np.array(image).astype(np.float32) / 255.0
95
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
96
+ mv_imgs = imagedream_pipeline(
97
+ text,
98
+ image,
99
+ negative_prompt=neg_text,
100
+ guidance_scale=guidance_scale,
101
+ num_inference_steps=step,
102
+ elevation=elevation,
103
+ )
104
+ return mv_imgs[1], mv_imgs[2], mv_imgs[3], mv_imgs[0]
105
+
106
+
107
+ @spaces.GPU
108
+ def image2mesh(view_front: np.ndarray,
109
+ view_right: np.ndarray,
110
+ view_back: np.ndarray,
111
+ view_left: np.ndarray,
112
+ more: bool = False,
113
+ scheluder_name: str ="DDIMScheduler",
114
+ guidance_scale: int = 7.5,
115
+ seed: int = 4,
116
+ octree_depth: int = 7):
117
+
118
+ sample_inputs = {
119
+ "mvimages": [[
120
+ Image.fromarray(view_front),
121
+ Image.fromarray(view_right),
122
+ Image.fromarray(view_back),
123
+ Image.fromarray(view_left)
124
+ ]]
125
+ }
126
+
127
+ global model
128
+ latents = model.sample(
129
+ sample_inputs,
130
+ sample_times=1,
131
+ guidance_scale=guidance_scale,
132
+ return_intermediates=False,
133
+ seed=seed
134
+
135
+ )[0]
136
+
137
+ # decode the latents to mesh
138
+ box_v = 1.1
139
+ mesh_outputs, _ = model.shape_model.extract_geometry(
140
+ latents,
141
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
142
+ octree_depth=octree_depth
143
+ )
144
+ assert len(mesh_outputs) == 1, "Only support single mesh output for gradio demo"
145
+ mesh = trimesh.Trimesh(mesh_outputs[0][0], mesh_outputs[0][1])
146
+ # filepath = f"{cached_dir}/{time.time()}.obj"
147
+ filepath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
148
+ mesh.export(filepath, include_normals=True)
149
+
150
+ if 'Remesh' in more:
151
+ remeshed_filepath = tempfile.NamedTemporaryFile(suffix=f"_remeshed.obj", delete=False).name
152
+ print("Remeshing with Instant Meshes...")
153
+ # target_face_count = int(len(mesh.faces)/10)
154
+ target_face_count = 2000
155
+ command = f"{proj_dir}/apps/third_party/InstantMeshes {filepath} -f {target_face_count} -o {remeshed_filepath}"
156
+ os.system(command)
157
+ filepath = remeshed_filepath
158
+ # filepath = filepath.replace('.obj', '_remeshed.obj')
159
+
160
+ return filepath
161
+
162
+ if __name__=="__main__":
163
+ parser = argparse.ArgumentParser()
164
+ # parser.add_argument("--model_path", type=str, required=True, help="Path to the object file",)
165
+ parser.add_argument("--cached_dir", type=str, default="./gradio_cached_dir")
166
+ parser.add_argument("--device", type=int, default=0)
167
+ args = parser.parse_args()
168
+
169
+ cached_dir = args.cached_dir
170
+ os.makedirs(args.cached_dir, exist_ok=True)
171
+ device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
172
+ print(f"using device: {device}")
173
+
174
+ # for multi-view images generation
175
+ background_choice = OrderedDict({
176
+ "Alpha as Mask": "Alpha as Mask",
177
+ "Auto Remove Background": "Auto Remove Background",
178
+ "Original Image": "Original Image",
179
+ })
180
+ mvimg_model_config_list = ["CRM", "ImageDream"]
181
+ crm_pipeline = TwoStagePipeline(
182
+ stage1_model_config,
183
+ stage1_sampler_config,
184
+ device=device,
185
+ dtype=torch.float16
186
+ )
187
+ imagedream_pipeline = MVDreamPipeline.from_pretrained(
188
+ "ashawkey/imagedream-ipmv-diffusers", # remote weights
189
+ torch_dtype=torch.float16,
190
+ trust_remote_code=True,
191
+ )
192
+
193
+ # for 3D latent set diffusion
194
+ ckpt_path = "./ckpts/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model.ckpt"
195
+ config_path = "./ckpts/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml"
196
+ # ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model.ckpt", repo_type="model")
197
+ # config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml", repo_type="model")
198
+ scheluder_dict = OrderedDict({
199
+ "DDIMScheduler": 'diffusers.schedulers.DDIMScheduler',
200
+ # "DPMSolverMultistepScheduler": 'diffusers.schedulers.DPMSolverMultistepScheduler', # not support yet
201
+ # "UniPCMultistepScheduler": 'diffusers.schedulers.UniPCMultistepScheduler', # not support yet
202
+ })
203
+
204
+ # main GUI
205
+ custom_theme = gr.themes.Soft(primary_hue="blue").set(
206
+ button_secondary_background_fill="*neutral_100",
207
+ button_secondary_background_fill_hover="*neutral_200")
208
+ custom_css = '''#disp_image {
209
+ text-align: center; /* Horizontally center the content */
210
+ }'''
211
+
212
+ with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
213
+ with gr.Row():
214
+ with gr.Column(scale=1):
215
+ gr.Markdown('# ' + _TITLE)
216
+ gr.Markdown(_DESCRIPTION)
217
+
218
+ with gr.Row():
219
+ with gr.Column(scale=2):
220
+ with gr.Column():
221
+ # input image
222
+ with gr.Row():
223
+ image_input = gr.Image(
224
+ label="Image Input",
225
+ image_mode="RGBA",
226
+ sources="upload",
227
+ type="pil",
228
+ )
229
+ run_btn = gr.Button('Generate', variant='primary', interactive=True)
230
+
231
+ with gr.Row():
232
+ gr.Markdown('''Try a different <b>seed and MV Model</b> for better results. Good Luck :)''')
233
+ with gr.Row():
234
+ seed = gr.Number(0, label='Seed', show_label=True)
235
+ mvimg_model = gr.Dropdown(value="CRM", label="MV Image Model", choices=list(mvimg_model_config_list))
236
+ more = gr.CheckboxGroup(["Remesh", "Symmetry(TBD)"], label="More", show_label=False)
237
+ with gr.Row():
238
+ # input prompt
239
+ text = gr.Textbox(label="Prompt (Opt.)", info="only works for ImageDream")
240
+
241
+ with gr.Accordion('Advanced options', open=False):
242
+ # negative prompt
243
+ neg_text = gr.Textbox(label="Negative Prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
244
+ # elevation
245
+ elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
246
+
247
+ with gr.Row():
248
+ gr.Examples(
249
+ examples=[os.path.join("./apps/examples", i) for i in os.listdir("./apps/examples")],
250
+ inputs=[image_input],
251
+ examples_per_page=8
252
+ )
253
+
254
+ with gr.Column(scale=4):
255
+ with gr.Row():
256
+ output_model_obj = gr.Model3D(
257
+ label="Output Model (OBJ Format)",
258
+ camera_position=(90.0, 90.0, 3.5),
259
+ interactive=False,
260
+ )
261
+ with gr.Row():
262
+ gr.Markdown('''*please note that the model is fliped due to the gradio viewer, please download the obj file and you will get the correct orientation.''')
263
+
264
+ with gr.Row():
265
+ view_front = gr.Image(label="Front", interactive=True, show_label=True)
266
+ view_right = gr.Image(label="Right", interactive=True, show_label=True)
267
+ view_back = gr.Image(label="Back", interactive=True, show_label=True)
268
+ view_left = gr.Image(label="Left", interactive=True, show_label=True)
269
+
270
+ with gr.Accordion('Advanced options', open=False):
271
+ with gr.Row(equal_height=True):
272
+ run_mv_btn = gr.Button('Only Generate 2D', interactive=True)
273
+ run_3d_btn = gr.Button('Only Generate 3D', interactive=True)
274
+
275
+ with gr.Accordion('Advanced options (2D)', open=False):
276
+ with gr.Row():
277
+ foreground_ratio = gr.Slider(
278
+ label="Foreground Ratio",
279
+ minimum=0.5,
280
+ maximum=1.0,
281
+ value=1.0,
282
+ step=0.05,
283
+ )
284
+
285
+ with gr.Row():
286
+ background_choice = gr.Dropdown(label="Backgroud Choice", value="Auto Remove Background",choices=list(background_choice.keys()))
287
+ rmbg_type = gr.Dropdown(label="Backgroud Remove Type", value="rembg",choices=['sam', "rembg"])
288
+ backgroud_color = gr.ColorPicker(label="Background Color", value="#FFFFFF", interactive=True)
289
+ # backgroud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=True)
290
+
291
+ with gr.Row():
292
+ mvimg_guidance_scale = gr.Number(value=4.0, minimum=3, maximum=10, label="2D Guidance Scale")
293
+ mvimg_steps = gr.Number(value=30, minimum=20, maximum=100, label="2D Sample Steps")
294
+
295
+ with gr.Accordion('Advanced options (3D)', open=False):
296
+ with gr.Row():
297
+ guidance_scale = gr.Number(label="3D Guidance Scale", value=7.5, minimum=3.0, maximum=10.0)
298
+ steps = gr.Number(value=50, minimum=20, maximum=100, label="3D Sample Steps")
299
+
300
+ with gr.Row():
301
+ scheduler = gr.Dropdown(label="scheluder", value="DDIMScheduler",choices=list(scheluder_dict.keys()))
302
+ octree_depth = gr.Slider(label="Octree Depth", value=7, minimum=4, maximum=8, step=1)
303
+
304
+ gr.Markdown(_CITE_)
305
+
306
+ outputs = [output_model_obj]
307
+ rmbg = RMBG(device)
308
+
309
+ model = load_model(ckpt_path, config_path, device)
310
+
311
+ run_btn.click(fn=check_input_image, inputs=[image_input]
312
+ ).success(
313
+ fn=rmbg.run,
314
+ inputs=[rmbg_type, image_input, foreground_ratio, background_choice, backgroud_color],
315
+ outputs=[image_input]
316
+ ).success(
317
+ fn=gen_mvimg,
318
+ inputs=[mvimg_model, image_input, seed, mvimg_guidance_scale, mvimg_steps, text, neg_text, elevation],
319
+ outputs=[view_front, view_right, view_back, view_left]
320
+ ).success(
321
+ fn=image2mesh,
322
+ inputs=[view_front, view_right, view_back, view_left, more, scheduler, guidance_scale, seed, octree_depth],
323
+ outputs=outputs,
324
+ api_name="generate_img2obj")
325
+ run_mv_btn.click(fn=gen_mvimg,
326
+ inputs=[mvimg_model, image_input, seed, mvimg_guidance_scale, mvimg_steps, text, neg_text, elevation],
327
+ outputs=[view_front, view_right, view_back, view_left]
328
+ )
329
+ run_3d_btn.click(fn=image2mesh,
330
+ inputs=[view_front, view_right, view_back, view_left, more, scheduler, guidance_scale, seed, octree_depth],
331
+ outputs=outputs,
332
+ api_name="generate_img2obj")
333
+
 
 
 
 
 
334
  demo.queue().launch(share=True, allowed_paths=[args.cached_dir])