gokaygokay committed

Commit 3d535fa
1 Parent(s): 52a5733

Upload 43 files

Files changed (43)
  1. .gitattributes +39 -35
  2. README.md +12 -12
  3. app.py +131 -154
  4. demo_files/comp.gif +3 -0
  5. demo_files/examples/animal_character.png +3 -0
  6. demo_files/examples/animal_character_2.png +3 -0
  7. demo_files/examples/axe.png +0 -0
  8. demo_files/examples/chair1.png +0 -0
  9. demo_files/examples/character1.png +0 -0
  10. demo_files/examples/otter_samurai.png +0 -0
  11. demo_files/examples/raccoon_wizard.png +0 -0
  12. demo_files/examples/stylized-rocks.png +0 -0
  13. demo_files/examples/tree.png +0 -0
  14. demo_files/hdri/abandoned_tiled_room_1k.hdr +0 -0
  15. demo_files/hdri/metro_noord_1k.hdr +0 -0
  16. demo_files/hdri/neon_photostudio_1k.hdr +0 -0
  17. demo_files/hdri/peppermint_powerplant_1k.hdr +0 -0
  18. demo_files/hdri/rainforest_trail_1k.hdr +0 -0
  19. demo_files/hdri/studio_small_08_1k.hdr +0 -0
  20. demo_files/hdri/urban_alley_01_1k.hdr +0 -0
  21. demo_files/scatterplot.jpg +0 -0
  22. demo_files/teaser.gif +3 -0
  23. flux_lora.py +109 -0
  24. load/tets/160_tets.npz +3 -0
  25. requirements.txt +19 -11
  26. sf3d/box_uv_unwrap.py +610 -0
  27. sf3d/models/camera.py +32 -0
  28. sf3d/models/global_estimator/multi_head_estimator.py +118 -0
  29. sf3d/models/image_estimator/clip_based_estimator.py +168 -0
  30. sf3d/models/isosurface.py +229 -0
  31. sf3d/models/mesh.py +172 -0
  32. sf3d/models/network.py +195 -0
  33. sf3d/models/tokenizers/dinov2.py +1196 -0
  34. sf3d/models/tokenizers/image.py +99 -0
  35. sf3d/models/tokenizers/triplane.py +49 -0
  36. sf3d/models/transformers/attention.py +31 -0
  37. sf3d/models/transformers/backbone.py +515 -0
  38. sf3d/models/utils.py +292 -0
  39. sf3d/system.py +482 -0
  40. sf3d/texture_baker.py +87 -0
  41. sf3d/texture_baker.slang +93 -0
  42. sf3d/utils.py +91 -0
  43. stable_fast.py +355 -0
.gitattributes CHANGED
@@ -1,35 +1,39 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ demo_files/comp.gif filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/animal_character_2.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/animal_character.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/teaser.gif filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: FLUX.1-dev + Captioner
- emoji: 🐨
- colorFrom: blue
- colorTo: indigo
- sdk: gradio
- sdk_version: 4.37.2
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: FLUX.1-dev + Captioner
+ emoji: 🐨
+ colorFrom: blue
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 4.37.2
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,154 +1,131 @@
- import spaces
- import gradio as gr
- import torch
- from PIL import Image
- from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
- from diffusers import DiffusionPipeline
- import random
- import numpy as np
- import os
- import subprocess
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
-
- # Initialize models
- device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.bfloat16
-
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
-
- # FLUX.1-dev model
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token=huggingface_token).to(device)
-
- # Initialize Florence model
- florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
- florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
-
- # Prompt Enhancer
- enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 2048
-
- # Florence caption function
- @spaces.GPU
- def florence_caption(image):
-     # Convert image to PIL if it's not already
-     if not isinstance(image, Image.Image):
-         image = Image.fromarray(image)
-
-     inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
-     generated_ids = florence_model.generate(
-         input_ids=inputs["input_ids"],
-         pixel_values=inputs["pixel_values"],
-         max_new_tokens=1024,
-         early_stopping=False,
-         do_sample=False,
-         num_beams=3,
-     )
-     generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-     parsed_answer = florence_processor.post_process_generation(
-         generated_text,
-         task="<MORE_DETAILED_CAPTION>",
-         image_size=(image.width, image.height)
-     )
-     return parsed_answer["<MORE_DETAILED_CAPTION>"]
-
- # Prompt Enhancer function
- def enhance_prompt(input_prompt):
-     result = enhancer_long("Enhance the description: " + input_prompt)
-     enhanced_text = result[0]['summary_text']
-     return enhanced_text
-
- @spaces.GPU(duration=190)
- def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
-     if image is not None:
-         # Convert image to PIL if it's not already
-         if not isinstance(image, Image.Image):
-             image = Image.fromarray(image)
-
-         prompt = florence_caption(image)
-         print(prompt)
-     else:
-         prompt = text_prompt
-
-     if use_enhancer:
-         prompt = enhance_prompt(prompt)
-
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         generator=generator,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         guidance_scale=guidance_scale
-     ).images[0]
-
-     return image, prompt, seed
-
- custom_css = """
- .input-group, .output-group {
-     border: 1px solid #e0e0e0;
-     border-radius: 10px;
-     padding: 20px;
-     margin-bottom: 20px;
-     background-color: #f9f9f9;
- }
- .submit-btn {
-     background-color: #2980b9 !important;
-     color: white !important;
- }
- .submit-btn:hover {
-     background-color: #3498db !important;
- }
- """
-
- title = """<h1 align="center">FLUX.1-dev with Florence-2 Captioner and Prompt Enhancer</h1>
- <p><center>
- <a href="https://huggingface.co/black-forest-labs/FLUX.1-dev" target="_blank">[FLUX.1-dev Model]</a>
- <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
- <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
- <p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
- </center></p>
- """
-
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
-     gr.HTML(title)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             with gr.Group(elem_classes="input-group"):
-                 input_image = gr.Image(label="Input Image (Florence-2 Captioner)")
-
-             with gr.Accordion("Advanced Settings", open=False):
-                 text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
-                 use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
-                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                 width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
-                 height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
-                 guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5)
-                 num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
-
-             generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
-
-         with gr.Column(scale=1):
-             with gr.Group(elem_classes="output-group"):
-                 output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
-                 final_prompt = gr.Textbox(label="Final Prompt Used")
-                 used_seed = gr.Number(label="Seed Used")
-
-     generate_btn.click(
-         fn=process_workflow,
-         inputs=[
-             input_image, text_prompt, use_enhancer, seed, randomize_seed,
-             width, height, guidance_scale, num_inference_steps
-         ],
-         outputs=[output_image, final_prompt, used_seed]
-     )
-
- demo.launch(debug=True)
+ import os
+ import tempfile
+ import time
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+ from diffusers import FluxPipeline
+ from huggingface_hub import hf_hub_download
+ from sf3d.system import SF3D
+ import sf3d.utils as sf3d_utils
+ from gradio_litmodel3d import LitModel3D
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.bfloat16
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+ # Set up environment and cache
+ cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
+ os.environ["HF_HUB_CACHE"] = cache_path
+ os.environ["HF_HOME"] = cache_path
+
+ if not os.path.exists(cache_path):
+     os.makedirs(cache_path, exist_ok=True)
+
+ # Initialize Flux pipeline
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=huggingface_token)
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
+ pipe.fuse_lora(lora_scale=0.125)
+ pipe.to(device="cuda", dtype=torch.bfloat16)
+
+ # Initialize SF3D model
+ sf3d_model = SF3D.from_pretrained(
+     "stabilityai/stable-fast-3d",
+     config_name="config.yaml",
+     weight_name="model.safetensors",
+     token=huggingface_token,
+ )
+ sf3d_model.eval().cuda()
+
+ # Constants for SF3D
+ COND_WIDTH, COND_HEIGHT = 512, 512
+ COND_DISTANCE, COND_FOVY_DEG = 1.6, 40
+ BACKGROUND_COLOR = [0.5, 0.5, 0.5]
+
+ c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
+ intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
+     COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
+ )
+
+ def generate_image(prompt, height, width, steps, scales, seed):
+     with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+         return pipe(
+             prompt=[prompt],
+             generator=torch.Generator().manual_seed(int(seed)),
+             num_inference_steps=int(steps),
+             guidance_scale=float(scales),
+             height=int(height),
+             width=int(width),
+             max_sequence_length=256
+         ).images[0]
+
+ def create_batch(input_image: Image.Image) -> dict:
+     img_cond = torch.from_numpy(
+         np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32) / 255.0
+     ).float().clip(0, 1)
+     mask_cond = img_cond[:, :, -1:]
+     rgb_cond = torch.lerp(
+         torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
+     )
+
+     batch_elem = {
+         "rgb_cond": rgb_cond,
+         "mask_cond": mask_cond,
+         "c2w_cond": c2w_cond.unsqueeze(0),
+         "intrinsic_cond": intrinsic.unsqueeze(0),
+         "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
+     }
+     return {k: v.unsqueeze(0) for k, v in batch_elem.items()}
+
+ def generate_3d_model(input_image):
+     with torch.no_grad():
+         with torch.autocast(device_type="cuda", dtype=torch.float16):
+             model_batch = create_batch(input_image)
+             model_batch = {k: v.cuda() for k, v in model_batch.items()}
+             trimesh_mesh, _ = sf3d_model.generate_mesh(model_batch, 1024)
+             trimesh_mesh = trimesh_mesh[0]
+
+     tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
+     trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
+     return tmp_file.name
+
+ def process_and_generate(prompt, height, width, steps, scales, seed):
+     # Generate image from prompt
+     generated_image = generate_image(prompt, height, width, steps, scales, seed)
+
+     # Generate 3D model from the image
+     glb_file = generate_3d_model(generated_image)
+
+     return generated_image, glb_file
+
+ # Gradio interface
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# Text-to-3D Model Generator")
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Your Image Description", lines=3)
+             with gr.Accordion("Advanced Settings", open=False):
+                 height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
+                 width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
+                 steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
+                 scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
+                 seed = gr.Number(label="Seed", value=3413, precision=0)
+
+             generate_btn = gr.Button("Generate 3D Model", variant="primary")
+
+         with gr.Column(scale=4):
+             output_image = gr.Image(label="Generated Image")
+             output_3d = LitModel3D(label="3D Model", clear_color=[0.0, 0.0, 0.0, 0.0])
+
+     generate_btn.click(
+         process_and_generate,
+         inputs=[prompt, height, width, steps, scales, seed],
+         outputs=[output_image, output_3d]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
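For context, a minimal sketch of how the two new stages compose outside Gradio. The function names are the ones defined in app.py above; calling them standalone is an assumption, and note that create_batch() reads the last image channel as an alpha mask, so the 3D stage expects RGBA input (e.g. after background removal) rather than the raw RGB output of the Flux pipeline:

# Hypothetical standalone use of the two stages defined in app.py.
image = generate_image("a raccoon wizard, 3d render",
                       height=1024, width=1024, steps=8, scales=3.5, seed=3413)
# create_batch() slices channel -1 as the foreground mask, so convert to RGBA.
glb_path = generate_3d_model(image.convert("RGBA"))
print(glb_path)  # path to a temporary .glb export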
demo_files/comp.gif ADDED

Git LFS Details

  • SHA256: 1d5e060d90f29889c55c1c5681dbeb4b4c2408709d18f7451bb0a6f02c6e9bc5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
demo_files/examples/animal_character.png ADDED

Git LFS Details

  • SHA256: 5949f60c651e71a41b7291197f91bb8be2c8861472765fc884e604e18b7806a0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
demo_files/examples/animal_character_2.png ADDED

Git LFS Details

  • SHA256: ffc3f10c629afd64798d38dad2cc419eb343c7106149426f78634a91367bf031
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
demo_files/examples/axe.png ADDED
demo_files/examples/chair1.png ADDED
demo_files/examples/character1.png ADDED
demo_files/examples/otter_samurai.png ADDED
demo_files/examples/raccoon_wizard.png ADDED
demo_files/examples/stylized-rocks.png ADDED
demo_files/examples/tree.png ADDED
demo_files/hdri/abandoned_tiled_room_1k.hdr ADDED
Binary file (478 kB).

demo_files/hdri/metro_noord_1k.hdr ADDED
Binary file (467 kB).

demo_files/hdri/neon_photostudio_1k.hdr ADDED
Binary file (438 kB).

demo_files/hdri/peppermint_powerplant_1k.hdr ADDED
Binary file (473 kB).

demo_files/hdri/rainforest_trail_1k.hdr ADDED
Binary file (512 kB).

demo_files/hdri/studio_small_08_1k.hdr ADDED
Binary file (412 kB).

demo_files/hdri/urban_alley_01_1k.hdr ADDED
Binary file (458 kB).
 
demo_files/scatterplot.jpg ADDED
demo_files/teaser.gif ADDED

Git LFS Details

  • SHA256: 1d5dcb4fbe710e94c0fa70cc2c783d66e327222cb5e74839cfd003e619bc2e1d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.81 MB
flux_lora.py ADDED
@@ -0,0 +1,109 @@
+ import spaces
+ import argparse
+ import os
+ import time
+ from os import path
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
+ os.environ["HF_HUB_CACHE"] = cache_path
+ os.environ["HF_HOME"] = cache_path
+
+ import gradio as gr
+ import torch
+ from diffusers import FluxPipeline
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ class timer:
+     def __init__(self, method_name="timed process"):
+         self.method = method_name
+     def __enter__(self):
+         self.start = time.time()
+         print(f"{self.method} starts")
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         end = time.time()
+         print(f"{self.method} took {str(round(end - self.start, 2))}s")
+
+ if not path.exists(cache_path):
+     os.makedirs(cache_path, exist_ok=True)
+
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
+ pipe.fuse_lora(lora_scale=0.125)
+ pipe.to(device="cuda", dtype=torch.bfloat16)
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         <div style="text-align: center; max-width: 650px; margin: 0 auto;">
+             <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">Hyper-FLUX-8steps-LoRA</h1>
+             <p style="font-size: 1rem; margin-bottom: 1.5rem;">AutoML team from ByteDance</p>
+         </div>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             with gr.Group():
+                 prompt = gr.Textbox(
+                     label="Your Image Description",
+                     placeholder="E.g., A serene landscape with mountains and a lake at sunset",
+                     lines=3
+                 )
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Group():
+                         with gr.Row():
+                             height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
+                             width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
+
+                         with gr.Row():
+                             steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
+                             scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
+
+                     seed = gr.Number(label="Seed (for reproducibility)", value=3413, precision=0)
+
+                 generate_btn = gr.Button("Generate Image", variant="primary", scale=1)
+
+         with gr.Column(scale=4):
+             output = gr.Image(label="Your Generated Image")
+
+     gr.Markdown(
+         """
+         <div style="max-width: 650px; margin: 2rem auto; padding: 1rem; border-radius: 10px; background-color: #f0f0f0;">
+             <h2 style="font-size: 1.5rem; margin-bottom: 1rem;">How to Use</h2>
+             <ol style="padding-left: 1.5rem;">
+                 <li>Enter a detailed description of the image you want to create.</li>
+                 <li>Adjust advanced settings if desired (tap to expand).</li>
+                 <li>Tap "Generate Image" and wait for your creation!</li>
+             </ol>
+             <p style="margin-top: 1rem; font-style: italic;">Tip: Be specific in your description for best results!</p>
+         </div>
+         """
+     )
+
+     @spaces.GPU
+     def process_image(height, width, steps, scales, prompt, seed):
+         global pipe
+         with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
+             return pipe(
+                 prompt=[prompt],
+                 generator=torch.Generator().manual_seed(int(seed)),
+                 num_inference_steps=int(steps),
+                 guidance_scale=float(scales),
+                 height=int(height),
+                 width=int(width),
+                 max_sequence_length=256
+             ).images[0]
+
+     generate_btn.click(
+         process_image,
+         inputs=[height, width, steps, scales, prompt, seed],
+         outputs=output
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
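The timer class above is a plain context manager; a quick usage sketch (the prompt and settings are arbitrary):

# Wrap any block to print its wall-clock duration.
with timer("flux inference"):
    result = pipe("a lighthouse at dusk, oil painting",
                  num_inference_steps=8, guidance_scale=3.5).images[0]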
load/tets/160_tets.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
+ size 15408790
requirements.txt CHANGED
@@ -1,11 +1,19 @@
- spaces
- huggingface_hub
- accelerate
- git+https://github.com/huggingface/diffusers.git
- torch==2.4.0
- torchvision==0.19.0
- transformers==4.42.4
- xformers
- sentencepiece
- timm
- einops
+ torch>=2.1.2
+ torchvision>=0.16.2
+ einops>=0.7.0
+ jaxtyping>=0.2.31
+ omegaconf>=2.3.0
+ transformers>=4.43.3
+ slangtorch>=1.2.2
+ open_clip_torch>=2.24.0
+ trimesh>=4.4.1
+ numpy>=1.26.4
+ huggingface-hub>=0.23.4
+ rembg[gpu]>=2.0.57
+ gradio-litmodel3d>=0.0.1
+ accelerate
+ diffusers>=0.30.0
+ invisible_watermark
+ xformers
+ sentencepiece
+ peft
sf3d/box_uv_unwrap.py ADDED
@@ -0,0 +1,610 @@
+ import math
+ from typing import Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from jaxtyping import Float, Integer
+ from torch import Tensor
+
+ from sf3d.models.utils import dot, triangle_intersection_2d
+
+
+ def _box_assign_vertex_to_cube_face(
+     vertex_positions: Float[Tensor, "Nv 3"],
+     vertex_normals: Float[Tensor, "Nv 3"],
+     triangle_idxs: Integer[Tensor, "Nf 3"],
+     bbox: Float[Tensor, "2 3"],
+ ) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
+     # Test to not have a scaled model to fit the space better
+     # bbox_min = bbox[:1].mean(-1, keepdim=True)
+     # bbox_max = bbox[1:].mean(-1, keepdim=True)
+     # v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
+
+     # Create a [0, 1] normalized vertex position
+     v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
+     # And to [-1, 1]
+     v_pos_normalized = 2.0 * v_pos_normalized - 1.0
+
+     # Get all vertex positions for each triangle.
+     # Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
+     v0 = v_pos_normalized[triangle_idxs[:, 0]]
+     v1 = v_pos_normalized[triangle_idxs[:, 1]]
+     v2 = v_pos_normalized[triangle_idxs[:, 2]]
+     tri_stack = torch.stack([v0, v1, v2], dim=1)
+
+     vn0 = vertex_normals[triangle_idxs[:, 0]]
+     vn1 = vertex_normals[triangle_idxs[:, 1]]
+     vn2 = vertex_normals[triangle_idxs[:, 2]]
+     tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
+
+     # Just average the normals per face
+     face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
+
+     # Now decide, based on the face normal, into which box map we project
+     # abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
+     abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
+
+     axis = torch.tensor(
+         [
+             [1, 0, 0],  # 0
+             [-1, 0, 0],  # 1
+             [0, 1, 0],  # 2
+             [0, -1, 0],  # 3
+             [0, 0, 1],  # 4
+             [0, 0, -1],  # 5
+         ],
+         device=face_normal.device,
+         dtype=face_normal.dtype,
+     )
+     face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
+     index = face_normal_axis.argmax(-1)
+
+     max_axis, uc, vc = (
+         torch.ones_like(abs_x),
+         torch.zeros_like(tri_stack[..., :1]),
+         torch.zeros_like(tri_stack[..., :1]),
+     )
+     mask_pos_x = index == 0
+     max_axis[mask_pos_x] = abs_x[mask_pos_x]
+     uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
+     vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
+
+     mask_neg_x = index == 1
+     max_axis[mask_neg_x] = abs_x[mask_neg_x]
+     uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
+     vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
+
+     mask_pos_y = index == 2
+     max_axis[mask_pos_y] = abs_y[mask_pos_y]
+     uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
+     vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
+
+     mask_neg_y = index == 3
+     max_axis[mask_neg_y] = abs_y[mask_neg_y]
+     uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
+     vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
+
+     mask_pos_z = index == 4
+     max_axis[mask_pos_z] = abs_z[mask_pos_z]
+     uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
+     vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
+
+     mask_neg_z = index == 5
+     max_axis[mask_neg_z] = abs_z[mask_neg_z]
+     uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
+     vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
+
+     # UC from [-1, 1] to [0, 1]
+     max_dim_div = max_axis.max(dim=0, keepdims=True).values
+     uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
+     vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
+
+     uv = torch.stack([uc, vc], dim=-1)
+
+     return uv, index
+
+
+ def _assign_faces_uv_to_atlas_index(
+     vertex_positions: Float[Tensor, "Nv 3"],
+     triangle_idxs: Integer[Tensor, "Nf 3"],
+     face_uv: Float[Tensor, "Nf 3 2"],
+     face_index: Integer[Tensor, "Nf 3"],
+ ) -> Integer[Tensor, "Nf"]:  # noqa: F821
+     triangle_pos = vertex_positions[triangle_idxs]
+     # We need to perform 3 overlap checks.
+     # The first set is placed in the upper two thirds of the UV atlas.
+     # Conceptually, these are the directly visible surfaces from each cube side.
+     # The second set is placed in the lower third and the left half of the UV atlas.
+     # This is the first set of occluded surfaces. They are also saved in the projected fashion.
+     # The third pass finds all unassigned faces. They are placed in the bottom right half of
+     # the UV atlas in a scattered fashion.
+     assign_idx = face_index.clone()
+     for overlap_step in range(3):
+         overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
+         for i in range(overlap_step * 6, (overlap_step + 1) * 6):
+             mask = assign_idx == i
+             if not mask.any():
+                 continue
+             # Get all elements belonging to the projection face
+             uv_triangle = face_uv[mask]
+             cur_triangle_pos = triangle_pos[mask]
+             # Find the center of the uv coordinates
+             center_uv = uv_triangle.mean(dim=1, keepdim=True)
+             # And also the radius of the triangle
+             uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
+
+             potentially_overlapping_mask = (
+                 # Find all close triangles
+                 (center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
+                 # Do not select the same element, by offsetting with a large-valued identity matrix
+                 + torch.eye(
+                     uv_triangle.shape[0],
+                     device=uv_triangle.device,
+                     dtype=uv_triangle.dtype,
+                 ).unsqueeze(-1)
+                 * 1000
+             )
+             # Mark all potentially overlapping triangles to reduce the number of triangle intersection tests
+             potentially_overlapping_mask = (
+                 potentially_overlapping_mask
+                 <= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
+             ).squeeze(-1)
+             overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
+
+             # Only unique triangles (A|B and B|A should be the same)
+             f = torch.min(overlap_coords, dim=-1).values
+             s = torch.max(overlap_coords, dim=-1).values
+             overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
+             first, second = overlap_coords.unbind(-1)
+
+             # Get the triangles
+             tri_1 = uv_triangle[first]
+             tri_2 = uv_triangle[second]
+
+             # Perform the actual test with the reduced number of potentially overlapping triangles
+             its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)
+
+             # We now need to detect which triangles are the occluded ones.
+             # We always assume the first to be the visible one (the others should move).
+             # In the previous step we used a lexicographical sort to get the unique pairs;
+             # in this one we sort based on the orthographic projection.
+             ax = 0 if i < 2 else 1 if i < 4 else 2
+             use_max = i % 2 == 1
+
+             tri1_c = cur_triangle_pos[first].mean(dim=1)
+             tri2_c = cur_triangle_pos[second].mean(dim=1)
+
+             mark_first = (
+                 (tri1_c[..., ax] > tri2_c[..., ax])
+                 if use_max
+                 else (tri1_c[..., ax] < tri2_c[..., ax])
+             )
+             first[mark_first] = second[mark_first]
+
+             # Lastly, the same index can be tested multiple times.
+             # If one test marks it as overlapping, we keep it marked as such.
+             # We do this by testing if it has been marked at least once.
+             unique_idx, rev_idx = torch.unique(first, return_inverse=True)
+
+             add = torch.zeros_like(unique_idx, dtype=torch.float32)
+             add.index_add_(0, rev_idx, its.float())
+             its_mask = add > 0
+
+             # And fill it in the overlapping indicator
+             idx = torch.where(mask)[0][unique_idx]
+             overlapping_indicator[idx] = its_mask
+
+         # Move the index to the overlap regions (shift by 6)
+         assign_idx[overlapping_indicator] += 6
+
+     # We do not care about the correct face placement after the first 2 slices
+     max_idx = 6 * 2
+     return assign_idx.clamp(0, max_idx)
+
+
+ def _find_slice_offset_and_scale(
+     index: Integer[Tensor, "Nf"],  # noqa: F821
+ ) -> Tuple[
+     Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"]  # noqa: F821
+ ]:
+     # 6 due to the 6 cube faces
+     off = 1 / 3
+     dupl_off = 1 / 6
+
+     # Here, we need to decide how to pack the textures in the case of overlap
+     def x_offset_calc(x, i):
+         offset_calc = i // 6
+         # Initial coordinates - just a 3x2 grid
+         if offset_calc == 0:
+             return off * x
+         else:
+             # Smaller 3x2 grid, plus an eventual shift to the right for the second overlap
+             return dupl_off * x + min(offset_calc - 1, 1) * 0.5
+
+     def y_offset_calc(x, i):
+         offset_calc = i // 6
+         # Initial coordinates - just a 3x2 grid
+         if offset_calc == 0:
+             return off * x
+         else:
+             # Smaller coordinates in the lowest row
+             return dupl_off * x + off * 2
+
+     offset_x = torch.zeros_like(index, dtype=torch.float32)
+     offset_y = torch.zeros_like(index, dtype=torch.float32)
+     offset_x_vals = [0, 1, 2, 0, 1, 2]
+     offset_y_vals = [0, 0, 0, 1, 1, 1]
+     for i in range(index.max().item() + 1):
+         mask = index == i
+         if not mask.any():
+             continue
+         offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
+         offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
+
+     div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
+     # All overlap elements are saved at half scale
+     div_x[index >= 6] = 6
+     div_y = div_x.clone()  # Same for y
+     # Except for the random overlaps
+     div_x[index >= 12] = 2
+     # But the random overlaps are saved in a large block in the lower third
+     div_y[index >= 12] = 3
+
+     return offset_x, offset_y, div_x, div_y
+
+
+ def rotation_flip_matrix_2d(
+     rad: float, flip_x: bool = False, flip_y: bool = False
+ ) -> Float[Tensor, "2 2"]:
+     cos = math.cos(rad)
+     sin = math.sin(rad)
+     rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
+     flip_mat = torch.tensor(
+         [
+             [-1 if flip_x else 1, 0],
+             [0, -1 if flip_y else 1],
+         ],
+         dtype=torch.float32,
+     )
+
+     return flip_mat @ rot_mat
+
+
+ def calculate_tangents(
+     vertex_positions: Float[Tensor, "Nv 3"],
+     vertex_normals: Float[Tensor, "Nv 3"],
+     triangle_idxs: Integer[Tensor, "Nf 3"],
+     face_uv: Float[Tensor, "Nf 3 2"],
+ ) -> Float[Tensor, "Nf 3 4"]:  # noqa: F821
+     vn_idx = [None] * 3
+     pos = [None] * 3
+     tex = face_uv.unbind(1)
+     for i in range(0, 3):
+         pos[i] = vertex_positions[triangle_idxs[:, i]]
+         # t_nrm_idx is always the same as t_pos_idx
+         vn_idx[i] = triangle_idxs[:, i]
+
+     tangents = torch.zeros_like(vertex_normals)
+     tansum = torch.zeros_like(vertex_normals)
+
+     # Compute tangent space for each triangle
+     duv1 = tex[1] - tex[0]
+     duv2 = tex[2] - tex[0]
+     dpos1 = pos[1] - pos[0]
+     dpos2 = pos[2] - pos[0]
+
+     tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
+
+     denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
+
+     # Avoid division by zero for degenerated texture coordinates
+     denom_safe = denom.clip(1e-6)
+     tang = tng_nom / denom_safe
+
+     # Update all 3 vertices
+     for i in range(0, 3):
+         idx = vn_idx[i][:, None].repeat(1, 3)
+         tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
+         tansum.scatter_add_(
+             0, idx, torch.ones_like(tang)
+         )  # tansum[n_i] = tansum[n_i] + 1
+     # Also normalize it. Here we do not normalize the individual triangles first, so larger-area
+     # triangles influence the tangent space more
+     tangents = tangents / tansum
+
+     # Normalize and make sure the tangent is perpendicular to the normal
+     tangents = F.normalize(tangents, dim=1)
+     tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)
+
+     return tangents
+
+
+ def _rotate_uv_slices_consistent_space(
+     vertex_positions: Float[Tensor, "Nv 3"],
+     vertex_normals: Float[Tensor, "Nv 3"],
+     triangle_idxs: Integer[Tensor, "Nf 3"],
+     uv: Float[Tensor, "Nf 3 2"],
+     index: Integer[Tensor, "Nf"],  # noqa: F821
+ ):
+     tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
+     pos_stack = torch.stack(
+         [
+             -vertex_positions[..., 1],
+             vertex_positions[..., 0],
+             torch.zeros_like(vertex_positions[..., 0]),
+         ],
+         dim=-1,
+     )
+     expected_tangents = F.normalize(
+         torch.linalg.cross(
+             vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
+         ),
+         dim=-1,
+     )
+
+     actual_tangents = tangents[triangle_idxs]
+     expected_tangents = expected_tangents[triangle_idxs]
+
+     def rotation_matrix_2d(theta):
+         c, s = torch.cos(theta), torch.sin(theta)
+         return torch.tensor([[c, -s], [s, c]])
+
+     # Now find the rotation
+     index_mod = index % 6  # Shouldn't happen. Just for safety
+     for i in range(6):
+         mask = index_mod == i
+         if not mask.any():
+             continue
+
+         actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
+         expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
+
+         dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
+         cross_product = (
+             actual_mean_tangent[0] * expected_mean_tangent[1]
+             - actual_mean_tangent[1] * expected_mean_tangent[0]
+         )
+         angle = torch.atan2(cross_product, dot_product)
+
+         rot_matrix = rotation_matrix_2d(angle).to(mask.device)
+         # Center the uv coordinates to be in the range of -1 to 1 and 0-centered
+         uv_cur = uv[mask] * 2 - 1  # Center it first
+         # Rotate it
+         uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
+
+         # Rescale uv[mask] to be within the 0-1 range
+         uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
+
+     return uv
+
+
+ def _handle_slice_uvs(
+     uv: Float[Tensor, "Nf 3 2"],
+     index: Integer[Tensor, "Nf"],  # noqa: F821
+     island_padding: float,
+     max_index: int = 6 * 2,
+ ) -> Float[Tensor, "Nf 3 2"]:  # noqa: F821
+     uc, vc = uv.unbind(-1)
+
+     # Get the second slice (the first overlap)
+     index_filter = [index == i for i in range(6, max_index)]
+
+     # Normalize them to always fully fill the atlas patch
+     for i, fi in enumerate(index_filter):
+         if fi.sum() > 0:
+             # Scale the slice, but only up to a factor of 2.
+             # This keeps the texture resolution in line with the first slice (half the space in UV)
+             uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
+             vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
+
+     uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
+     vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
+
+     return torch.stack([uc_padded, vc_padded], dim=-1)
+
+
+ def _handle_remaining_uvs(
+     uv: Float[Tensor, "Nf 3 2"],
+     index: Integer[Tensor, "Nf"],  # noqa: F821
+     island_padding: float,
+ ) -> Float[Tensor, "Nf 3 2"]:
+     uc, vc = uv.unbind(-1)
+     # Get all remaining elements
+     remaining_filter = index >= 6 * 2
+     squares_left = remaining_filter.sum()
+
+     if squares_left == 0:
+         return uv
+
+     uc = uc[remaining_filter]
+     vc = vc[remaining_filter]
+
+     # All remaining triangles are distributed in a rectangle.
+     # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
+     ratio = 0.5 * (1 / 3)  # 1.5
+     # sqrt(744/(0.5*(1/3)))
+
+     mult = math.sqrt(squares_left / ratio)
+     num_square_width = int(math.ceil(0.5 * mult))
+     num_square_height = int(math.ceil(squares_left / num_square_width))
+
+     width = 1 / num_square_width
+     height = 1 / num_square_height
+
+     # The idea is again to keep the texture resolution consistent with the first slice.
+     # This only occupies half the region in the texture chart, but the scaling on the squares
+     # assumes full coverage.
+     clip_val = min(width, height) * 1.5
+     # Now normalize the UVs, taking the maximum scaling into account
+     uc = (uc - uc.min(dim=1, keepdim=True).values) / (
+         uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
+     ).clip(clip_val)
+     vc = (vc - vc.min(dim=1, keepdim=True).values) / (
+         vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
+     ).clip(clip_val)
+     # Add a small padding
+     uc = (
+         uc * (1 - island_padding * num_square_width * 0.5)
+         + island_padding * num_square_width * 0.25
+     ).clip(0, 1)
+     vc = (
+         vc * (1 - island_padding * num_square_height * 0.5)
+         + island_padding * num_square_height * 0.25
+     ).clip(0, 1)
+
+     uc = uc * width
+     vc = vc * height
+
+     # And calculate offsets for each element
+     idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
+     x_idx = idx % num_square_width
+     y_idx = idx // num_square_width
+     # And move each triangle to its own spot
+     uc = uc + x_idx[:, None] * width
+     vc = vc + y_idx[:, None] * height
+
+     uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
+     vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
+
+     uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
+
+     return uv
+
+
+ def _distribute_individual_uvs_in_atlas(
+     face_uv: Float[Tensor, "Nf 3 2"],
+     assigned_faces: Integer[Tensor, "Nf"],  # noqa: F821
+     offset_x: Float[Tensor, "Nf"],  # noqa: F821
+     offset_y: Float[Tensor, "Nf"],  # noqa: F821
+     div_x: Float[Tensor, "Nf"],  # noqa: F821
+     div_y: Float[Tensor, "Nf"],  # noqa: F821
+     island_padding: float,
+ ):
+     # Place the slice first
+     placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
+     # Then handle the remaining overlap elements
+     placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)
+
+     uc, vc = placed_uv.unbind(-1)
+     uc = uc / div_x[:, None] + offset_x[:, None]
+     vc = vc / div_y[:, None] + offset_y[:, None]
+
+     uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
+
+     return uv
+
+
+ def _get_unique_face_uv(
+     uv: Float[Tensor, "Nf 3 2"],
+ ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]:  # noqa: F821
+     unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
+     # And add the face to uv index mapping
+     vtex_idx = unique_idx.view(-1, 3)
+
+     return unique_uv, vtex_idx
+
+
+ def _align_mesh_with_main_axis(
+     vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
+ ) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
+     # Use PCA to find the two main axes (the third is derived by cross product).
+     # Set the random seed so it's repeatable
+     torch.manual_seed(0)
+     _, _, v = torch.pca_lowrank(vertex_positions, q=2)
+     main_axis, seconday_axis = v[:, 0], v[:, 1]
+
+     main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
+     # Orthogonalize the second axis
+     seconday_axis: Float[Tensor, "3"] = F.normalize(
+         seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
+     )
+     # Create a perpendicular third axis
+     third_axis: Float[Tensor, "3"] = F.normalize(
+         torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
+     )
+
+     # Check to which canonical axis each aligns
+     main_axis_max_idx = main_axis.abs().argmax().item()
+     seconday_axis_max_idx = seconday_axis.abs().argmax().item()
+     third_axis_max_idx = third_axis.abs().argmax().item()
+
+     # Now sort the axes based on the argmax so they align with the canonical axes.
+     # If two axes have the same argmax, move one of them
+     all_possible_axis = {0, 1, 2}
+     cur_index = 1
+     while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
+         # Find the missing axis
+         missing_axis = all_possible_axis - set(
+             [main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
+         )
+         missing_axis = missing_axis.pop()
+         # Just assign it to the third axis, as it had the smallest contribution to the
+         # overall shape
+         if cur_index == 1:
+             third_axis_max_idx = missing_axis
+         elif cur_index == 2:
+             seconday_axis_max_idx = missing_axis
+         else:
+             raise ValueError("Could not find 3 unique axes")
+         cur_index += 1
+
+     if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
+         raise ValueError("Could not find 3 unique axes")
+
+     axes = [None] * 3
+     axes[main_axis_max_idx] = main_axis
+     axes[seconday_axis_max_idx] = seconday_axis
+     axes[third_axis_max_idx] = third_axis
+     # Create the rotation matrix from the individual axes
+     rot_mat = torch.stack(axes, dim=1).T
+
+     # Now rotate the vertex positions and vertex normals so the mesh aligns with the main axes
+     vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
+     vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
+
+     return vertex_positions, vertex_normals
+
+
+ def box_projection_uv_unwrap(
+     vertex_positions: Float[Tensor, "Nv 3"],
+     vertex_normals: Float[Tensor, "Nv 3"],
+     triangle_idxs: Integer[Tensor, "Nf 3"],
+     island_padding: float,
+ ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]:  # noqa: F821
+     # Align the mesh with the main axis directions first
+     vertex_positions, vertex_normals = _align_mesh_with_main_axis(
+         vertex_positions, vertex_normals
+     )
+
+     bbox: Float[Tensor, "2 3"] = torch.stack(
+         [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
+     )
+     # First decide in which cube face the triangle is placed
+     face_uv, face_index = _box_assign_vertex_to_cube_face(
+         vertex_positions, vertex_normals, triangle_idxs, bbox
+     )
+
+     # Rotate the UV islands so that they align with the radial z tangent space
+     face_uv = _rotate_uv_slices_consistent_space(
+         vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
+     )
+
+     # Then find where the face is placed in the atlas.
+     # This has to detect potential overlaps
+     assigned_atlas_index = _assign_faces_uv_to_atlas_index(
+         vertex_positions, triangle_idxs, face_uv, face_index
+     )
+
+     # Then figure out the final place in the atlas based on the assignment
+     offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
+         assigned_atlas_index
+     )
+
+     # Next distribute the faces in the uv atlas
+     placed_uv = _distribute_individual_uvs_in_atlas(
+         face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
+     )
+
+     # And get the unique per-triangle UV coordinates
+     return _get_unique_face_uv(placed_uv)
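A hypothetical smoke test for the unwrapper above, using a hand-built tetrahedron; the mesh, the crude outward normals, and the padding value are all illustrative, not part of the repo:

import torch
import torch.nn.functional as F
from sf3d.box_uv_unwrap import box_projection_uv_unwrap

# Four vertices, four faces; normals point away from the centroid.
v = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
n = F.normalize(v - v.mean(dim=0), dim=-1)
f = torch.tensor([[0, 1, 2], [0, 3, 1], [0, 2, 3], [1, 3, 2]])
uv, uv_idx = box_projection_uv_unwrap(v, n, f, island_padding=0.01)
print(uv.shape, uv_idx.shape)  # unique UV table and per-face UV indices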
sf3d/models/camera.py ADDED
@@ -0,0 +1,32 @@
+ from dataclasses import dataclass, field
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+
+ from sf3d.models.utils import BaseModule
+
+
+ class LinearCameraEmbedder(BaseModule):
+     @dataclass
+     class Config(BaseModule.Config):
+         in_channels: int = 25
+         out_channels: int = 768
+         conditions: List[str] = field(default_factory=list)
+
+     cfg: Config
+
+     def configure(self) -> None:
+         self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
+
+     def forward(self, **kwargs):
+         cond_tensors = []
+         for cond_name in self.cfg.conditions:
+             assert cond_name in kwargs
+             cond = kwargs[cond_name]
+             # cond in shape (B, Nv, ...)
+             cond_tensors.append(cond.view(*cond.shape[:2], -1))
+         cond_tensor = torch.cat(cond_tensors, dim=-1)
+         assert cond_tensor.shape[-1] == self.cfg.in_channels
+         embedding = self.linear(cond_tensor)
+         return embedding
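This embedder lines up with the conditioning batch built in app.py above: a flattened 4x4 c2w matrix (16 values) plus a 3x3 normalized intrinsic (9 values) gives the default 25 input channels. A sketch, assuming BaseModule accepts a plain dict as its config:

import torch
from sf3d.models.camera import LinearCameraEmbedder

embedder = LinearCameraEmbedder(
    {"in_channels": 25,  # 16 (c2w) + 9 (normalized intrinsics)
     "out_channels": 768,
     "conditions": ["c2w_cond", "intrinsic_normed_cond"]}
)
emb = embedder(
    c2w_cond=torch.eye(4).reshape(1, 1, 4, 4),               # (B, Nv, 4, 4)
    intrinsic_normed_cond=torch.eye(3).reshape(1, 1, 3, 3),  # (B, Nv, 3, 3)
)
print(emb.shape)  # (1, 1, 768)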
sf3d/models/global_estimator/multi_head_estimator.py ADDED
@@ -0,0 +1,118 @@
+ from dataclasses import dataclass, field
+ from typing import Any, List, Optional
+
+ import torch.nn as nn
+ from jaxtyping import Float
+ from torch import Tensor
+
+ from sf3d.models.network import get_activation
+ from sf3d.models.utils import BaseModule
+
+
+ @dataclass
+ class HeadSpec:
+     name: str
+     out_channels: int
+     n_hidden_layers: int
+     output_activation: Optional[str] = None
+     output_bias: float = 0.0
+     add_to_decoder_features: bool = False
+     shape: Optional[list[int]] = None
+
+
+ class MultiHeadEstimator(BaseModule):
+     @dataclass
+     class Config(BaseModule.Config):
+         triplane_features: int = 1024
+
+         n_layers: int = 2
+         hidden_features: int = 512
+         activation: str = "relu"
+
+         pool: str = "max"
+         # Literal["mean", "max"] = "mean"  # noqa: F821
+
+         heads: List[HeadSpec] = field(default_factory=lambda: [])
+
+     cfg: Config
+
+     def configure(self):
+         layers = []
+         cur_features = self.cfg.triplane_features * 3
+         for _ in range(self.cfg.n_layers):
+             layers.append(
+                 nn.Conv2d(
+                     cur_features,
+                     self.cfg.hidden_features,
+                     kernel_size=3,
+                     padding=0,
+                     stride=2,
+                 )
+             )
+             layers.append(self.make_activation(self.cfg.activation))
+
+             cur_features = self.cfg.hidden_features
+
+         self.layers = nn.Sequential(*layers)
+
+         assert len(self.cfg.heads) > 0
+         heads = {}
+         for head in self.cfg.heads:
+             head_layers = []
+             for i in range(head.n_hidden_layers):
+                 head_layers += [
+                     nn.Linear(
+                         self.cfg.hidden_features,
+                         self.cfg.hidden_features,
+                     ),
+                     self.make_activation(self.cfg.activation),
+                 ]
+             head_layers += [
+                 nn.Linear(
+                     self.cfg.hidden_features,
+                     head.out_channels,
+                 ),
+             ]
+             heads[head.name] = nn.Sequential(*head_layers)
+         self.heads = nn.ModuleDict(heads)
+
+     def make_activation(self, activation):
+         if activation == "relu":
+             return nn.ReLU(inplace=True)
+         elif activation == "silu":
+             return nn.SiLU(inplace=True)
+         else:
+             raise NotImplementedError
+
+     def forward(
+         self,
+         triplane: Float[Tensor, "B 3 F Ht Wt"],
+     ) -> dict[str, Any]:
+         x = self.layers(
+             triplane.reshape(
+                 triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
+             )
+         )
+
+         if self.cfg.pool == "max":
+             x = x.amax(dim=[-2, -1])
+         elif self.cfg.pool == "mean":
+             x = x.mean(dim=[-2, -1])
+         else:
+             raise NotImplementedError
+
+         out = {
+             ("decoder_" if head.add_to_decoder_features else "")
+             + head.name: get_activation(head.output_activation)(
+                 self.heads[head.name](x) + head.output_bias
+             )
+             for head in self.cfg.heads
+         }
+         for head in self.cfg.heads:
+             if head.shape:
+                 head_name = (
+                     "decoder_" if head.add_to_decoder_features else ""
+                 ) + head.name
+                 out[head_name] = out[head_name].reshape(*head.shape)
+
+         return out
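A hypothetical single-head instantiation of the estimator; the head name, feature sizes, and triplane resolution below are illustrative, not SF3D's shipped config, and the dict-as-config convention is again an assumption:

import torch
from sf3d.models.global_estimator.multi_head_estimator import MultiHeadEstimator

est = MultiHeadEstimator(
    {"triplane_features": 64,
     "heads": [{"name": "scale", "out_channels": 1, "n_hidden_layers": 1}]}
)
# The three triplane planes (B, 3, F, Ht, Wt) are folded into the channel dim,
# convolved down, pooled globally, then fed to each head's MLP.
out = est(torch.randn(1, 3, 64, 32, 32))
print(out["scale"].shape)  # (1, 1)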
sf3d/models/image_estimator/clip_based_estimator.py ADDED
@@ -0,0 +1,168 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import open_clip
5
+ import torch
6
+ import torch.nn as nn
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+ from torchvision.transforms import Normalize
10
+
11
+ from sf3d.models.network import get_activation
12
+ from sf3d.models.utils import BaseModule
13
+
14
+
15
+ @dataclass
16
+ class HeadSpec:
17
+ name: str
18
+ out_channels: int
19
+ n_hidden_layers: int
20
+ output_activation: Optional[str] = None
21
+ output_bias: float = 0.0
22
+ add_to_decoder_features: bool = False
23
+ shape: Optional[list[int]] = None
24
+
25
+
26
+ class ClipBasedHeadEstimator(BaseModule):
27
+ @dataclass
28
+ class Config(BaseModule.Config):
29
+ model: str = "ViT-B-32"
30
+ pretrain: str = "laion2b_s34b_b79k"
31
+
32
+ distribution: str = "beta"
33
+
34
+ # ["mean", "mode", "sample", "sample_mean"]
35
+ distribution_eval: str = "mode"
36
+
37
+ activation: str = "relu"
38
+ hidden_features: int = 512
39
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
40
+
41
+ cfg: Config
42
+
43
+ def configure(self):
44
+ self.model, _, self.preprocess = open_clip.create_model_and_transforms(
45
+ self.cfg.model, pretrained=self.cfg.pretrain
46
+ )
47
+ self.model.eval()
48
+
49
+ # Do not add the weights in self.model to the optimizer
50
+ for param in self.model.parameters():
51
+ param.requires_grad = False
52
+
53
+ assert len(self.cfg.heads) > 0
54
+ heads = {}
55
+ for head in self.cfg.heads:
56
+ head_layers = []
57
+
58
+ for i in range(head.n_hidden_layers):
59
+ head_layers += [
60
+ nn.Linear(
61
+ self.cfg.hidden_features,
62
+ self.cfg.hidden_features,
63
+ ),
64
+ self.make_activation(self.cfg.activation),
65
+ ]
66
+
67
+ head_layers = [nn.Sequential(*head_layers)]
68
+ head_layers += [
69
+ nn.Sequential(
70
+ nn.Linear(
71
+ self.cfg.hidden_features,
72
+ self.cfg.hidden_features,
73
+ ),
74
+ self.make_activation(self.cfg.activation),
75
+ nn.Linear(self.cfg.hidden_features, 1),
76
+ )
77
+ for _ in range(2)
78
+ ]
79
+ heads[head.name] = nn.ModuleList(head_layers)
80
+ self.heads = nn.ModuleDict(heads)
81
+
82
+ def make_activation(self, activation):
83
+ if activation == "relu":
84
+ return nn.ReLU(inplace=True)
85
+ elif activation == "silu":
86
+ return nn.SiLU(inplace=True)
87
+ else:
88
+ raise NotImplementedError
89
+
90
+ def forward(
91
+ self,
92
+ cond_image: Float[Tensor, "B 1 H W 3"],
93
+ sample: bool = True,
94
+ ) -> dict[str, Any]:
95
+ # Run the model
96
+ # Resize cond_image to 224
97
+ cond_image = nn.functional.interpolate(
98
+ cond_image.flatten(0, 1).permute(0, 3, 1, 2),
99
+ size=(224, 224),
100
+ mode="bilinear",
101
+ align_corners=False,
102
+ )
103
+ cond_image = Normalize(
104
+ mean=open_clip.constants.OPENAI_DATASET_MEAN,
105
+ std=open_clip.constants.OPENAI_DATASET_STD,
106
+ )(cond_image)
107
+ image_features = self.model.encode_image(cond_image)
108
+
109
+ # Run the heads
110
+ outputs = {}
111
+
112
+ for head_dict in self.cfg.heads:
113
+ head_name = head_dict.name
114
+ shared_head, d1_h, d2_h = self.heads[head_name]
115
+ shared_features = shared_head(image_features)
116
+ d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
117
+ if self.cfg.distribution == "normal":
118
+ mean = d1
119
+ var = d2
120
+ if mean.shape[-1] == 1:
121
+ outputs[head_name] = torch.distributions.Normal(
122
+ mean + head_dict.output_bias,
123
+ torch.nn.functional.softplus(var),
124
+ )
125
+ else:
126
+ outputs[head_name] = torch.distributions.MultivariateNormal(
127
+ mean + head_dict.output_bias,
128
+ torch.nn.functional.softplus(var).diag_embed(),
129
+ )
130
+ elif self.cfg.distribution == "beta":
131
+ outputs[head_name] = torch.distributions.Beta(
132
+ torch.nn.functional.softplus(d1 + head_dict.output_bias),
133
+ torch.nn.functional.softplus(d2 + head_dict.output_bias),
134
+ )
135
+ else:
136
+ raise NotImplementedError
137
+
138
+ if sample:
139
+ for head_dict in self.cfg.heads:
140
+ head_name = head_dict.name
141
+ dist = outputs[head_name]
142
+
143
+ if self.cfg.distribution_eval == "mean":
144
+ out = dist.mean
145
+ elif self.cfg.distribution_eval == "mode":
146
+ out = dist.mode
147
+ elif self.cfg.distribution_eval == "sample_mean":
148
+ out = dist.sample([10]).mean(0)  # average over the sample dimension (dim 0)
149
+ else:
150
+ # use rsample if gradient is needed
151
+ out = dist.rsample() if self.training else dist.sample()
152
+
153
+ outputs[head_name] = get_activation(head_dict.output_activation)(out)
154
+ outputs[f"{head_name}_dist"] = dist
155
+
156
+ for head in self.cfg.heads:
157
+ if head.shape:
158
+ if not sample:
159
+ raise ValueError(
160
+ "Cannot reshape non-sampled probabilisitic outputs"
161
+ )
162
+ outputs[head.name] = outputs[head.name].reshape(*head.shape)
163
+
164
+ if head.add_to_decoder_features:
165
+ outputs[f"decoder_{head.name}"] = outputs[head.name]
166
+ del outputs[head.name]
167
+
168
+ return outputs
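For reference, a minimal usage sketch of this estimator (illustrative only, not part of the uploaded file; it assumes `BaseModule` parses a plain dict into `Config`, threestudio-style, and it downloads the open_clip weights on first use):

```python
import torch

from sf3d.models.image_estimator.clip_based_estimator import ClipBasedHeadEstimator

# One Beta-distributed scalar head; the head name and values are illustrative.
estimator = ClipBasedHeadEstimator(
    {
        "distribution": "beta",
        "heads": [
            {
                "name": "roughness",
                "out_channels": 1,
                "n_hidden_layers": 2,
                "output_activation": None,
            }
        ],
    }
)

cond = torch.rand(1, 1, 256, 256, 3)  # B 1 H W 3, values in [0, 1]
with torch.no_grad():
    out = estimator(cond, sample=True)
print(out["roughness"].shape)       # sampled value per batch element
print(type(out["roughness_dist"]))  # the underlying torch.distributions.Beta
```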
sf3d/models/isosurface.py ADDED
@@ -0,0 +1,229 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from .mesh import Mesh
10
+
11
+
12
+ class IsosurfaceHelper(nn.Module):
13
+ points_range: Tuple[float, float] = (0, 1)
14
+
15
+ @property
16
+ def grid_vertices(self) -> Float[Tensor, "N 3"]:
17
+ raise NotImplementedError
18
+
19
+ @property
20
+ def requires_instance_per_batch(self) -> bool:
21
+ return False
22
+
23
+
24
+ class MarchingTetrahedraHelper(IsosurfaceHelper):
25
+ def __init__(self, resolution: int, tets_path: str):
26
+ super().__init__()
27
+ self.resolution = resolution
28
+ self.tets_path = tets_path
29
+
30
+ self.triangle_table: Float[Tensor, "..."]
31
+ self.register_buffer(
32
+ "triangle_table",
33
+ torch.as_tensor(
34
+ [
35
+ [-1, -1, -1, -1, -1, -1],
36
+ [1, 0, 2, -1, -1, -1],
37
+ [4, 0, 3, -1, -1, -1],
38
+ [1, 4, 2, 1, 3, 4],
39
+ [3, 1, 5, -1, -1, -1],
40
+ [2, 3, 0, 2, 5, 3],
41
+ [1, 4, 0, 1, 5, 4],
42
+ [4, 2, 5, -1, -1, -1],
43
+ [4, 5, 2, -1, -1, -1],
44
+ [4, 1, 0, 4, 5, 1],
45
+ [3, 2, 0, 3, 5, 2],
46
+ [1, 3, 5, -1, -1, -1],
47
+ [4, 1, 2, 4, 3, 1],
48
+ [3, 0, 4, -1, -1, -1],
49
+ [2, 0, 1, -1, -1, -1],
50
+ [-1, -1, -1, -1, -1, -1],
51
+ ],
52
+ dtype=torch.long,
53
+ ),
54
+ persistent=False,
55
+ )
56
+ self.num_triangles_table: Integer[Tensor, "..."]
57
+ self.register_buffer(
58
+ "num_triangles_table",
59
+ torch.as_tensor(
60
+ [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
61
+ ),
62
+ persistent=False,
63
+ )
64
+ self.base_tet_edges: Integer[Tensor, "..."]
65
+ self.register_buffer(
66
+ "base_tet_edges",
67
+ torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
68
+ persistent=False,
69
+ )
70
+
71
+ tets = np.load(self.tets_path)
72
+ self._grid_vertices: Float[Tensor, "..."]
73
+ self.register_buffer(
74
+ "_grid_vertices",
75
+ torch.from_numpy(tets["vertices"]).float(),
76
+ persistent=False,
77
+ )
78
+ self.indices: Integer[Tensor, "..."]
79
+ self.register_buffer(
80
+ "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
81
+ )
82
+
83
+ self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
84
+
85
+ center_indices, boundary_indices = self.get_center_boundary_index(
86
+ self._grid_vertices
87
+ )
88
+ self.center_indices: Integer[Tensor, "..."]
89
+ self.register_buffer("center_indices", center_indices, persistent=False)
90
+ self.boundary_indices: Integer[Tensor, "..."]
91
+ self.register_buffer("boundary_indices", boundary_indices, persistent=False)
92
+
93
+ def get_center_boundary_index(self, verts):
94
+ magn = torch.sum(verts**2, dim=-1)
95
+
96
+ center_idx = torch.argmin(magn)
97
+ boundary_pos = verts == verts.max()
98
+ boundary_neg = verts == verts.min()
99
+
100
+ boundary = torch.bitwise_or(boundary_pos, boundary_neg)
101
+ boundary = torch.sum(boundary.float(), dim=-1)
102
+
103
+ boundary_idx = torch.nonzero(boundary)
104
+ return center_idx, boundary_idx.squeeze(dim=-1)
105
+
106
+ def normalize_grid_deformation(
107
+ self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
108
+ ) -> Float[Tensor, "Nv 3"]:
109
+ return (
110
+ (self.points_range[1] - self.points_range[0])
111
+ / self.resolution # half tet size is approximately 1 / self.resolution
112
+ * torch.tanh(grid_vertex_offsets)
113
+ ) # FIXME: hard-coded activation
114
+
115
+ @property
116
+ def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
117
+ return self._grid_vertices
118
+
119
+ @property
120
+ def all_edges(self) -> Integer[Tensor, "Ne 2"]:
121
+ if self._all_edges is None:
122
+ # compute edges on the GPU, otherwise this is very slow (mainly due to the unique operation)
123
+ edges = torch.tensor(
124
+ [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
125
+ dtype=torch.long,
126
+ device=self.indices.device,
127
+ )
128
+ _all_edges = self.indices[:, edges].reshape(-1, 2)
129
+ _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
130
+ _all_edges = torch.unique(_all_edges_sorted, dim=0)
131
+ self._all_edges = _all_edges
132
+ return self._all_edges
133
+
134
+ def sort_edges(self, edges_ex2):
135
+ with torch.no_grad():
136
+ order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
137
+ order = order.unsqueeze(dim=1)
138
+
139
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
140
+ b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
141
+
142
+ return torch.stack([a, b], -1)
143
+
144
+ def _forward(self, pos_nx3, sdf_n, tet_fx4):
145
+ with torch.no_grad():
146
+ occ_n = sdf_n > 0
147
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
148
+ occ_sum = torch.sum(occ_fx4, -1)
149
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
150
+ occ_sum = occ_sum[valid_tets]
151
+
152
+ # find all vertices
153
+ all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
154
+ all_edges = self.sort_edges(all_edges)
155
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
156
+
157
+ unique_edges = unique_edges.long()
158
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
159
+ mapping = (
160
+ torch.ones(
161
+ (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
162
+ )
163
+ * -1
164
+ )
165
+ mapping[mask_edges] = torch.arange(
166
+ mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
167
+ )
168
+ idx_map = mapping[idx_map] # map edges to verts
169
+
170
+ interp_v = unique_edges[mask_edges]
171
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
172
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
173
+ edges_to_interp_sdf[:, -1] *= -1
174
+
175
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
176
+
177
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
178
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
179
+
180
+ idx_map = idx_map.reshape(-1, 6)
181
+
182
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
183
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
184
+ num_triangles = self.num_triangles_table[tetindex]
185
+
186
+ # Generate triangle indices
187
+ faces = torch.cat(
188
+ (
189
+ torch.gather(
190
+ input=idx_map[num_triangles == 1],
191
+ dim=1,
192
+ index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
193
+ ).reshape(-1, 3),
194
+ torch.gather(
195
+ input=idx_map[num_triangles == 2],
196
+ dim=1,
197
+ index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
198
+ ).reshape(-1, 3),
199
+ ),
200
+ dim=0,
201
+ )
202
+
203
+ return verts, faces
204
+
205
+ def forward(
206
+ self,
207
+ level: Float[Tensor, "N3 1"],
208
+ deformation: Optional[Float[Tensor, "N3 3"]] = None,
209
+ ) -> Mesh:
210
+ if deformation is not None:
211
+ grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
212
+ deformation
213
+ )
214
+ else:
215
+ grid_vertices = self.grid_vertices
216
+
217
+ v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
218
+
219
+ mesh = Mesh(
220
+ v_pos=v_pos,
221
+ t_pos_idx=t_pos_idx,
222
+ # extras
223
+ grid_vertices=grid_vertices,
224
+ tet_edges=self.all_edges,
225
+ grid_level=level,
226
+ grid_deformation=deformation,
227
+ )
228
+
229
+ return mesh
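A minimal sketch of how this helper is driven (illustrative, not part of the uploaded file; it assumes the tetrahedral grid shipped with this upload at `load/tets/160_tets.npz` and uses a sphere SDF evaluated at the grid vertices):

```python
import torch

from sf3d.models.isosurface import MarchingTetrahedraHelper

helper = MarchingTetrahedraHelper(160, "load/tets/160_tets.npz")

# Sphere SDF, positive inside (matching the occ_n = sdf > 0 convention above),
# centered on the tet grid.
center = helper.grid_vertices.mean(dim=0)
sdf = 0.3 - (helper.grid_vertices - center).norm(dim=-1, keepdim=True)

mesh = helper(sdf)  # no deformation offsets
print(mesh.v_pos.shape, mesh.t_pos_idx.shape)
```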
sf3d/models/mesh.py ADDED
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from jaxtyping import Float, Integer
8
+ from torch import Tensor
9
+
10
+ from sf3d.box_uv_unwrap import box_projection_uv_unwrap
11
+ from sf3d.models.utils import dot
12
+
13
+
14
+ class Mesh:
15
+ def __init__(
16
+ self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
17
+ ) -> None:
18
+ self.v_pos: Float[Tensor, "Nv 3"] = v_pos
19
+ self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
20
+ self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
21
+ self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
22
+ self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
23
+ self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
24
+ self.extras: Dict[str, Any] = {}
25
+ for k, v in kwargs.items():
26
+ self.add_extra(k, v)
27
+
28
+ def add_extra(self, k, v) -> None:
29
+ self.extras[k] = v
30
+
31
+ @property
32
+ def requires_grad(self):
33
+ return self.v_pos.requires_grad
34
+
35
+ @property
36
+ def v_nrm(self):
37
+ if self._v_nrm is None:
38
+ self._v_nrm = self._compute_vertex_normal()
39
+ return self._v_nrm
40
+
41
+ @property
42
+ def v_tng(self):
43
+ if self._v_tng is None:
44
+ self._v_tng = self._compute_vertex_tangent()
45
+ return self._v_tng
46
+
47
+ @property
48
+ def v_tex(self):
49
+ if self._v_tex is None:
50
+ self.unwrap_uv()
51
+ return self._v_tex
52
+
53
+ @property
54
+ def edges(self):
55
+ if self._edges is None:
56
+ self._edges = self._compute_edges()
57
+ return self._edges
58
+
59
+ def _compute_vertex_normal(self):
60
+ i0 = self.t_pos_idx[:, 0]
61
+ i1 = self.t_pos_idx[:, 1]
62
+ i2 = self.t_pos_idx[:, 2]
63
+
64
+ v0 = self.v_pos[i0, :]
65
+ v1 = self.v_pos[i1, :]
66
+ v2 = self.v_pos[i2, :]
67
+
68
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
69
+
70
+ # Splat face normals to vertices
71
+ v_nrm = torch.zeros_like(self.v_pos)
72
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
73
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
74
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
75
+
76
+ # Normalize, replace zero (degenerated) normals with some default value
77
+ v_nrm = torch.where(
78
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
79
+ )
80
+ v_nrm = F.normalize(v_nrm, dim=1)
81
+
82
+ if torch.is_anomaly_enabled():
83
+ assert torch.all(torch.isfinite(v_nrm))
84
+
85
+ return v_nrm
86
+
87
+ def _compute_vertex_tangent(self):
88
+ vn_idx = [None] * 3
89
+ pos = [None] * 3
90
+ tex = [None] * 3
91
+ for i in range(0, 3):
92
+ pos[i] = self.v_pos[self.t_pos_idx[:, i]]
93
+ tex[i] = self.v_tex[self.t_pos_idx[:, i]]
94
+ # t_nrm_idx is always the same as t_pos_idx
95
+ vn_idx[i] = self.t_pos_idx[:, i]
96
+
97
+ tangents = torch.zeros_like(self.v_nrm)
98
+ tansum = torch.zeros_like(self.v_nrm)
99
+
100
+ # Compute tangent space for each triangle
101
+ duv1 = tex[1] - tex[0]
102
+ duv2 = tex[2] - tex[0]
103
+ dpos1 = pos[1] - pos[0]
104
+ dpos2 = pos[2] - pos[0]
105
+
106
+ tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
107
+
108
+ denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
109
+
110
+ # Avoid division by zero for degenerated texture coordinates
111
+ denom_safe = denom.clip(1e-6)
112
+ tang = tng_nom / denom_safe
113
+
114
+ # Update all 3 vertices
115
+ for i in range(0, 3):
116
+ idx = vn_idx[i][:, None].repeat(1, 3)
117
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
118
+ tansum.scatter_add_(
119
+ 0, idx, torch.ones_like(tang)
120
+ ) # tansum[n_i] = tansum[n_i] + 1
121
+ # Also normalize it. Here we do not normalize the individual triangles first so larger area
122
+ # triangles influence the tangent space more
123
+ tangents = tangents / tansum
124
+
125
+ # Normalize and make sure tangent is perpendicular to normal
126
+ tangents = F.normalize(tangents, dim=1)
127
+ tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
128
+
129
+ if torch.is_anomaly_enabled():
130
+ assert torch.all(torch.isfinite(tangents))
131
+
132
+ return tangents
133
+
134
+ @torch.no_grad()
135
+ def unwrap_uv(
136
+ self,
137
+ island_padding: float = 0.02,
138
+ ) -> None:
139
+ uv, indices = box_projection_uv_unwrap(
140
+ self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
141
+ )
142
+
143
+ # Store per-vertex UVs.
144
+ # This means we need to duplicate some vertices at the seams
145
+ individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
146
+ individual_faces = torch.arange(
147
+ individual_vertices.shape[0],
148
+ device=individual_vertices.device,
149
+ dtype=self.t_pos_idx.dtype,
150
+ ).reshape(-1, 3)
151
+ uv_flat = uv[indices].reshape((-1, 2))
152
+ # uv_flat[:, 1] = 1 - uv_flat[:, 1]
153
+
154
+ self.v_pos = individual_vertices
155
+ self.t_pos_idx = individual_faces
156
+ self._v_tex = uv_flat
157
+ self._v_nrm = self._compute_vertex_normal()
158
+ self._v_tng = self._compute_vertex_tangent()
159
+
160
+ def _compute_edges(self):
161
+ # Compute edges
162
+ edges = torch.cat(
163
+ [
164
+ self.t_pos_idx[:, [0, 1]],
165
+ self.t_pos_idx[:, [1, 2]],
166
+ self.t_pos_idx[:, [2, 0]],
167
+ ],
168
+ dim=0,
169
+ )
170
+ edges = edges.sort()[0]
171
+ edges = torch.unique(edges, dim=0)
172
+ return edges
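A quick sketch of the lazily computed attributes on a single triangle (illustrative only, not part of the uploaded file):

```python
import torch

from sf3d.models.mesh import Mesh

v = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
f = torch.tensor([[0, 1, 2]], dtype=torch.long)
mesh = Mesh(v, f)

print(mesh.v_nrm)   # per-vertex normals; all (0, 0, 1) for this triangle
print(mesh.edges)   # unique undirected edges: (0, 1), (0, 2), (1, 2)
# Accessing mesh.v_tex would trigger unwrap_uv(), which re-indexes the
# vertices (duplicating them at UV seams) before storing per-vertex UVs.
```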
sf3d/models/network.py ADDED
@@ -0,0 +1,195 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from jaxtyping import Float
9
+ from torch import Tensor
10
+ from torch.autograd import Function
11
+ from torch.cuda.amp import custom_bwd, custom_fwd
12
+
13
+ from sf3d.models.utils import BaseModule, normalize
14
+
15
+
16
+ class PixelShuffleUpsampleNetwork(BaseModule):
17
+ @dataclass
18
+ class Config(BaseModule.Config):
19
+ in_channels: int = 1024
20
+ out_channels: int = 40
21
+ scale_factor: int = 4
22
+
23
+ conv_layers: int = 4
24
+ conv_kernel_size: int = 3
25
+
26
+ cfg: Config
27
+
28
+ def configure(self) -> None:
29
+ layers = []
30
+ output_channels = self.cfg.out_channels * self.cfg.scale_factor**2
31
+
32
+ in_channels = self.cfg.in_channels
33
+ for i in range(self.cfg.conv_layers):
34
+ cur_out_channels = (
35
+ in_channels if i != self.cfg.conv_layers - 1 else output_channels
36
+ )
37
+ layers.append(
38
+ nn.Conv2d(
39
+ in_channels,
40
+ cur_out_channels,
41
+ self.cfg.conv_kernel_size,
42
+ padding=(self.cfg.conv_kernel_size - 1) // 2,
43
+ )
44
+ )
45
+ if i != self.cfg.conv_layers - 1:
46
+ layers.append(nn.ReLU(inplace=True))
47
+
48
+ layers.append(nn.PixelShuffle(self.cfg.scale_factor))
49
+
50
+ self.upsample = nn.Sequential(*layers)
51
+
52
+ def forward(
53
+ self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
54
+ ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
55
+ return rearrange(
56
+ self.upsample(
57
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
58
+ ),
59
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
60
+ Np=3,
61
+ )
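The channel arithmetic here: the last conv emits `out_channels * scale_factor**2` channels, and `nn.PixelShuffle` trades them for spatial resolution. A shape check (illustrative only; assumes `BaseModule` accepts a dict config and that the defaults above apply):

```python
import torch

from sf3d.models.network import PixelShuffleUpsampleNetwork

net = PixelShuffleUpsampleNetwork({})  # defaults: 1024 -> 40 channels, 4x upsample
x = torch.rand(2, 3, 1024, 32, 32)     # B Np Ci Hp Wp
print(net(x).shape)                    # torch.Size([2, 3, 40, 128, 128])
```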
62
+
63
+
64
+ class _TruncExp(Function): # pylint: disable=abstract-method
65
+ # Implementation from torch-ngp:
66
+ # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
67
+ @staticmethod
68
+ @custom_fwd(cast_inputs=torch.float32)
69
+ def forward(ctx, x): # pylint: disable=arguments-differ
70
+ ctx.save_for_backward(x)
71
+ return torch.exp(x)
72
+
73
+ @staticmethod
74
+ @custom_bwd
75
+ def backward(ctx, g): # pylint: disable=arguments-differ
76
+ x = ctx.saved_tensors[0]
77
+ return g * torch.exp(torch.clamp(x, max=15))
78
+
79
+
80
+ trunc_exp = _TruncExp.apply
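A quick numeric check of the truncation (illustrative only): the forward pass is a plain `exp`, while the backward clamps the input at 15 so gradients cannot overflow:

```python
import torch

from sf3d.models.network import trunc_exp

x = torch.tensor([20.0], requires_grad=True)
y = trunc_exp(x)                # forward: exp(20), still finite in fp32
y.backward(torch.ones_like(y))
print(x.grad)                   # exp(15) ~= 3.27e6 rather than exp(20)
```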
81
+
82
+
83
+ def get_activation(name) -> Callable:
84
+ if name is None:
85
+ return lambda x: x
86
+ name = name.lower()
87
+ if name == "none" or name == "linear" or name == "identity":
88
+ return lambda x: x
89
+ elif name == "lin2srgb":
90
+ return lambda x: torch.where(
91
+ x > 0.0031308,
92
+ torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
93
+ 12.92 * x,
94
+ ).clamp(0.0, 1.0)
95
+ elif name == "exp":
96
+ return lambda x: torch.exp(x)
97
+ elif name == "shifted_exp":
98
+ return lambda x: torch.exp(x - 1.0)
99
+ elif name == "trunc_exp":
100
+ return trunc_exp
101
+ elif name == "shifted_trunc_exp":
102
+ return lambda x: trunc_exp(x - 1.0)
103
+ elif name == "sigmoid":
104
+ return lambda x: torch.sigmoid(x)
105
+ elif name == "tanh":
106
+ return lambda x: torch.tanh(x)
107
+ elif name == "shifted_softplus":
108
+ return lambda x: F.softplus(x - 1.0)
109
+ elif name == "scale_-11_01":
110
+ return lambda x: x * 0.5 + 0.5
111
+ elif name == "negative":
112
+ return lambda x: -x
113
+ elif name == "normalize_channel_last":
114
+ return lambda x: normalize(x)
115
+ elif name == "normalize_channel_first":
116
+ return lambda x: normalize(x, dim=1)
117
+ else:
118
+ try:
119
+ return getattr(F, name)
120
+ except AttributeError:
121
+ raise ValueError(f"Unknown activation function: {name}")
122
+
123
+
124
+ @dataclass
125
+ class HeadSpec:
126
+ name: str
127
+ out_channels: int
128
+ n_hidden_layers: int
129
+ output_activation: Optional[str] = None
130
+ out_bias: float = 0.0
131
+
132
+
133
+ class MaterialMLP(BaseModule):
134
+ @dataclass
135
+ class Config(BaseModule.Config):
136
+ in_channels: int = 120
137
+ n_neurons: int = 64
138
+ activation: str = "silu"
139
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
140
+
141
+ cfg: Config
142
+
143
+ def configure(self) -> None:
144
+ assert len(self.cfg.heads) > 0
145
+ heads = {}
146
+ for head in self.cfg.heads:
147
+ head_layers = []
148
+ for i in range(head.n_hidden_layers):
149
+ head_layers += [
150
+ nn.Linear(
151
+ self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
152
+ self.cfg.n_neurons,
153
+ ),
154
+ self.make_activation(self.cfg.activation),
155
+ ]
156
+ head_layers += [
157
+ nn.Linear(
158
+ self.cfg.n_neurons,
159
+ head.out_channels,
160
+ ),
161
+ ]
162
+ heads[head.name] = nn.Sequential(*head_layers)
163
+ self.heads = nn.ModuleDict(heads)
164
+
165
+ def make_activation(self, activation):
166
+ if activation == "relu":
167
+ return nn.ReLU(inplace=True)
168
+ elif activation == "silu":
169
+ return nn.SiLU(inplace=True)
170
+ else:
171
+ raise NotImplementedError
172
+
173
+ def keys(self):
174
+ return self.heads.keys()
175
+
176
+ def forward(
177
+ self, x, include: Optional[List] = None, exclude: Optional[List] = None
178
+ ):
179
+ if include is not None and exclude is not None:
180
+ raise ValueError("Cannot specify both include and exclude.")
181
+ if include is not None:
182
+ heads = [h for h in self.cfg.heads if h.name in include]
183
+ elif exclude is not None:
184
+ heads = [h for h in self.cfg.heads if h.name not in exclude]
185
+ else:
186
+ heads = self.cfg.heads
187
+
188
+ out = {
189
+ head.name: get_activation(head.output_activation)(
190
+ self.heads[head.name](x) + head.out_bias
191
+ )
192
+ for head in heads
193
+ }
194
+
195
+ return out
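A minimal usage sketch (illustrative only; same dict-config assumption as above, with hypothetical head names):

```python
import torch

from sf3d.models.network import MaterialMLP

mlp = MaterialMLP(
    {
        "in_channels": 120,
        "heads": [
            {"name": "metallic", "out_channels": 1, "n_hidden_layers": 2,
             "output_activation": "sigmoid"},
            {"name": "roughness", "out_channels": 1, "n_hidden_layers": 2,
             "output_activation": "sigmoid"},
        ],
    }
)

x = torch.rand(4096, 120)           # per-point features
out = mlp(x, include=["metallic"])  # run only the selected head
print(out["metallic"].shape)        # torch.Size([4096, 1])
```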
sf3d/models/tokenizers/dinov2.py ADDED
@@ -0,0 +1,1196 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DINOv2 model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BackboneOutput,
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
36
+ from transformers.pytorch_utils import (
37
+ find_pruneable_heads_and_indices,
38
+ prune_linear_layer,
39
+ )
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.utils.backbone_utils import BackboneMixin
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ # General docstring
52
+ _CONFIG_FOR_DOC = "Dinov2Config"
53
+
54
+ # Base docstring
55
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
57
+
58
+ # Image classification docstring
59
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
60
+
61
+
62
+ DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "facebook/dinov2-base",
64
+ # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
65
+ ]
66
+
67
+
68
+ class Dinov2Embeddings(nn.Module):
69
+ """
70
+ Construct the CLS token, mask token, position and patch embeddings.
71
+ """
72
+
73
+ def __init__(self, config: Dinov2Config) -> None:
74
+ super().__init__()
75
+
76
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
77
+ # register as mask token as it's not used in optimization
78
+ # to avoid the use of find_unused_parameters_true
79
+ # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
80
+ self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
81
+ self.patch_embeddings = Dinov2PatchEmbeddings(config)
82
+ num_patches = self.patch_embeddings.num_patches
83
+ self.position_embeddings = nn.Parameter(
84
+ torch.randn(1, num_patches + 1, config.hidden_size)
85
+ )
86
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
87
+ self.config = config
88
+
89
+ def interpolate_pos_encoding(
90
+ self, embeddings: torch.Tensor, height: int, width: int
91
+ ) -> torch.Tensor:
92
+ """
93
+ This method allows interpolating the pre-trained position encodings so that the model can be used on
94
+ higher-resolution images.
95
+
96
+ Source:
97
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
98
+ """
99
+
100
+ num_patches = embeddings.shape[1] - 1
101
+ num_positions = self.position_embeddings.shape[1] - 1
102
+ if num_patches == num_positions and height == width:
103
+ return self.position_embeddings
104
+ class_pos_embed = self.position_embeddings[:, 0]
105
+ patch_pos_embed = self.position_embeddings[:, 1:]
106
+ dim = embeddings.shape[-1]
107
+ height = height // self.config.patch_size
108
+ width = width // self.config.patch_size
109
+ # we add a small number to avoid floating point error in the interpolation
110
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
111
+ height, width = height + 0.1, width + 0.1
112
+ patch_pos_embed = patch_pos_embed.reshape(
113
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
114
+ )
115
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
116
+ patch_pos_embed = nn.functional.interpolate(
117
+ patch_pos_embed,
118
+ scale_factor=(
119
+ height / math.sqrt(num_positions),
120
+ width / math.sqrt(num_positions),
121
+ ),
122
+ mode="bicubic",
123
+ align_corners=False,
124
+ )
125
+ if (
126
+ int(height) != patch_pos_embed.shape[-2]
127
+ or int(width) != patch_pos_embed.shape[-1]
128
+ ):
129
+ raise ValueError(
130
+ "Width or height does not match with the interpolated position embeddings"
131
+ )
132
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
133
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
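The resize logic in numbers (illustrative values, not tied to a specific checkpoint): if a checkpoint stores a 14x14 grid of patch position embeddings (224px training at patch size 16), a 448px input yields a 28x28 patch grid, so the stored grid is bicubically upsampled by 2x before being added to the tokens:

```python
trained_grid = 224 // 16        # 14x14 positions stored in the checkpoint
new_grid = 448 // 16            # 28x28 patches for the larger input
print(new_grid / trained_grid)  # 2.0 -> scale_factor of the interpolation
```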
134
+
135
+ def forward(
136
+ self,
137
+ pixel_values: torch.Tensor,
138
+ bool_masked_pos: Optional[torch.Tensor] = None,
139
+ ) -> torch.Tensor:
140
+ batch_size, _, height, width = pixel_values.shape
141
+ patch_embeddings = self.patch_embeddings(pixel_values)
142
+ embeddings = patch_embeddings
143
+
144
+ if bool_masked_pos is not None:
145
+ embeddings = torch.where(
146
+ bool_masked_pos.unsqueeze(-1),
147
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
148
+ embeddings,
149
+ )
150
+
151
+ # add the [CLS] token to the embedded patch tokens
152
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
153
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
154
+
155
+ # add positional encoding to each token
156
+ embeddings = embeddings + self.interpolate_pos_encoding(
157
+ embeddings, height, width
158
+ )
159
+
160
+ embeddings = self.dropout(embeddings)
161
+
162
+ return embeddings
163
+
164
+
165
+ class Dinov2PatchEmbeddings(nn.Module):
166
+ """
167
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
168
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
169
+ Transformer.
170
+ """
171
+
172
+ def __init__(self, config):
173
+ super().__init__()
174
+ image_size, patch_size = config.image_size, config.patch_size
175
+ num_channels, hidden_size = config.num_channels, config.hidden_size
176
+
177
+ image_size = (
178
+ image_size
179
+ if isinstance(image_size, collections.abc.Iterable)
180
+ else (image_size, image_size)
181
+ )
182
+ patch_size = (
183
+ patch_size
184
+ if isinstance(patch_size, collections.abc.Iterable)
185
+ else (patch_size, patch_size)
186
+ )
187
+ num_patches = (image_size[1] // patch_size[1]) * (
188
+ image_size[0] // patch_size[0]
189
+ )
190
+ self.image_size = image_size
191
+ self.patch_size = patch_size
192
+ self.num_channels = num_channels
193
+ self.num_patches = num_patches
194
+
195
+ self.projection = nn.Conv2d(
196
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
197
+ )
198
+
199
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
200
+ """
201
+ num_channels = pixel_values.shape[1]
202
+ if num_channels != self.num_channels:
203
+ raise ValueError(
204
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
205
+ f" Expected {self.num_channels} but got {num_channels}."
206
+ )
207
+ """
208
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
209
+ return embeddings
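A shape sketch of the tokenization step (illustrative numbers): with kernel size equal to stride equal to the patch size, each patch becomes exactly one token.

```python
import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=14, stride=14)  # patch_size 14, hidden 768
x = torch.rand(1, 3, 224, 224)
tokens = proj(x).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 256, 768]) -- a 16x16 patch grid
```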
210
+
211
+
212
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
213
+ class Dinov2SelfAttention(nn.Module):
214
+ def __init__(self, config: Dinov2Config) -> None:
215
+ super().__init__()
216
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
217
+ config, "embedding_size"
218
+ ):
219
+ raise ValueError(
220
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
221
+ f"heads {config.num_attention_heads}."
222
+ )
223
+
224
+ self.num_attention_heads = config.num_attention_heads
225
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
226
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
227
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
228
+
229
+ self.query = nn.Linear(
230
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
231
+ )
232
+ self.key = nn.Linear(
233
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
234
+ )
235
+ self.value = nn.Linear(
236
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
237
+ )
238
+
239
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
240
+
241
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
242
+ new_x_shape = x.size()[:-1] + (
243
+ self.num_attention_heads,
244
+ self.attention_head_size,
245
+ )
246
+ x = x.view(new_x_shape)
247
+ return x.permute(0, 2, 1, 3)
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states,
252
+ head_mask: Optional[torch.Tensor] = None,
253
+ output_attentions: bool = False,
254
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
255
+ mixed_query_layer = self.query(hidden_states)
256
+
257
+ if hasattr(F, "scaled_dot_product_attention"):
258
+ assert head_mask is None and not output_attentions
259
+ new_size = hidden_states.size()[:-1] + (
260
+ self.num_attention_heads,
261
+ self.attention_head_size,
262
+ )
263
+ key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
264
+ value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
265
+ query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
266
+ context_layer = F.scaled_dot_product_attention(
267
+ query_layer,
268
+ key_layer,
269
+ value_layer,
270
+ dropout_p=self.attention_probs_dropout_prob if self.training else 0.0,
271
+ is_causal=False,
272
+ )
273
+ context_layer = context_layer.transpose(1, 2).reshape(
274
+ *hidden_states.size()[:-1], -1
275
+ )
276
+ else:
277
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
278
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
279
+ query_layer = self.transpose_for_scores(mixed_query_layer)
280
+
281
+ # Take the dot product between "query" and "key" to get the raw attention scores.
282
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
283
+
284
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
285
+
286
+ # Normalize the attention scores to probabilities.
287
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
288
+
289
+ # This is actually dropping out entire tokens to attend to, which might
290
+ # seem a bit unusual, but is taken from the original Transformer paper.
291
+ attention_probs = self.dropout(attention_probs)
292
+
293
+ # Mask heads if we want to
294
+ if head_mask is not None:
295
+ attention_probs = attention_probs * head_mask
296
+
297
+ context_layer = torch.matmul(attention_probs, value_layer)
298
+
299
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
300
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
301
+ context_layer = context_layer.view(new_context_layer_shape)
302
+
303
+ outputs = (
304
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
305
+ )
306
+
307
+ return outputs
308
+
309
+
310
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
311
+ class Dinov2SelfOutput(nn.Module):
312
+ """
313
+ The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
314
+ layernorm applied before each block.
315
+ """
316
+
317
+ def __init__(self, config: Dinov2Config) -> None:
318
+ super().__init__()
319
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
320
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
321
+
322
+ def forward(
323
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
324
+ ) -> torch.Tensor:
325
+ hidden_states = self.dense(hidden_states)
326
+ hidden_states = self.dropout(hidden_states)
327
+
328
+ return hidden_states
329
+
330
+
331
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
332
+ class Dinov2Attention(nn.Module):
333
+ def __init__(self, config: Dinov2Config) -> None:
334
+ super().__init__()
335
+ self.attention = Dinov2SelfAttention(config)
336
+ self.output = Dinov2SelfOutput(config)
337
+ self.pruned_heads = set()
338
+
339
+ def prune_heads(self, heads: Set[int]) -> None:
340
+ if len(heads) == 0:
341
+ return
342
+ heads, index = find_pruneable_heads_and_indices(
343
+ heads,
344
+ self.attention.num_attention_heads,
345
+ self.attention.attention_head_size,
346
+ self.pruned_heads,
347
+ )
348
+
349
+ # Prune linear layers
350
+ self.attention.query = prune_linear_layer(self.attention.query, index)
351
+ self.attention.key = prune_linear_layer(self.attention.key, index)
352
+ self.attention.value = prune_linear_layer(self.attention.value, index)
353
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
354
+
355
+ # Update hyper params and store pruned heads
356
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
357
+ heads
358
+ )
359
+ self.attention.all_head_size = (
360
+ self.attention.attention_head_size * self.attention.num_attention_heads
361
+ )
362
+ self.pruned_heads = self.pruned_heads.union(heads)
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ head_mask: Optional[torch.Tensor] = None,
368
+ output_attentions: bool = False,
369
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
370
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
371
+
372
+ attention_output = self.output(self_outputs[0], hidden_states)
373
+
374
+ outputs = (attention_output,) + self_outputs[
375
+ 1:
376
+ ] # add attentions if we output them
377
+ return outputs
378
+
379
+
380
+ class Dinov2LayerScale(nn.Module):
381
+ def __init__(self, config) -> None:
382
+ super().__init__()
383
+ self.lambda1 = nn.Parameter(
384
+ config.layerscale_value * torch.ones(config.hidden_size)
385
+ )
386
+
387
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
388
+ return hidden_state * self.lambda1
389
+
390
+
391
+ # Copied from transformers.models.beit.modeling_beit.drop_path
392
+ def drop_path(
393
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
394
+ ) -> torch.Tensor:
395
+ """
396
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
397
+
398
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
399
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
400
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
401
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
402
+ argument.
403
+ """
404
+ if drop_prob == 0.0 or not training:
405
+ return input
406
+ keep_prob = 1 - drop_prob
407
+ shape = (input.shape[0],) + (1,) * (
408
+ input.ndim - 1
409
+ ) # work with diff dim tensors, not just 2D ConvNets
410
+ random_tensor = keep_prob + torch.rand(
411
+ shape, dtype=input.dtype, device=input.device
412
+ )
413
+ random_tensor.floor_() # binarize
414
+ output = input.div(keep_prob) * random_tensor
415
+ return output
416
+
417
+
418
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
419
+ class Dinov2DropPath(nn.Module):
420
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
421
+
422
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
423
+ super().__init__()
424
+ self.drop_prob = drop_prob
425
+
426
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
427
+ return drop_path(hidden_states, self.drop_prob, self.training)
428
+
429
+ def extra_repr(self) -> str:
430
+ return "p={}".format(self.drop_prob)
431
+
432
+
433
+ class Dinov2MLP(nn.Module):
434
+ def __init__(self, config) -> None:
435
+ super().__init__()
436
+ in_features = out_features = config.hidden_size
437
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
438
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
439
+ if isinstance(config.hidden_act, str):
440
+ self.activation = ACT2FN[config.hidden_act]
441
+ else:
442
+ self.activation = config.hidden_act
443
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
444
+
445
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
446
+ hidden_state = self.fc1(hidden_state)
447
+ hidden_state = self.activation(hidden_state)
448
+ hidden_state = self.fc2(hidden_state)
449
+ return hidden_state
450
+
451
+
452
+ class Dinov2SwiGLUFFN(nn.Module):
453
+ def __init__(self, config) -> None:
454
+ super().__init__()
455
+ in_features = out_features = config.hidden_size
456
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
457
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
458
+
459
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
460
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
461
+
462
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
463
+ hidden_state = self.weights_in(hidden_state)
464
+ x1, x2 = hidden_state.chunk(2, dim=-1)
465
+ hidden = nn.functional.silu(x1) * x2
466
+ return self.weights_out(hidden)
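The hidden width above is shrunk to roughly 2/3 of the plain-MLP width (SwiGLU carries three weight matrices instead of two, so this keeps parameter counts comparable) and rounded up to a multiple of 8. For example, with hidden_size 768 and mlp_ratio 4:

```python
hidden = int(768 * 4)                       # 3072, the plain-MLP width
hidden = (int(hidden * 2 / 3) + 7) // 8 * 8
print(hidden)                               # 2048
```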
467
+
468
+
469
+ class Dinov2Layer(nn.Module):
470
+ """This corresponds to the Block class in the original implementation."""
471
+
472
+ def __init__(self, config: Dinov2Config) -> None:
473
+ super().__init__()
474
+
475
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476
+ self.norm1_modulation = None
477
+ self.attention = Dinov2Attention(config)
478
+ self.layer_scale1 = Dinov2LayerScale(config)
479
+ self.drop_path1 = (
480
+ Dinov2DropPath(config.drop_path_rate)
481
+ if config.drop_path_rate > 0.0
482
+ else nn.Identity()
483
+ )
484
+
485
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
486
+ self.norm2_modulation = None
487
+
488
+ if config.use_swiglu_ffn:
489
+ self.mlp = Dinov2SwiGLUFFN(config)
490
+ else:
491
+ self.mlp = Dinov2MLP(config)
492
+ self.layer_scale2 = Dinov2LayerScale(config)
493
+ self.drop_path2 = (
494
+ Dinov2DropPath(config.drop_path_rate)
495
+ if config.drop_path_rate > 0.0
496
+ else nn.Identity()
497
+ )
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: torch.Tensor,
502
+ head_mask: Optional[torch.Tensor] = None,
503
+ modulation_cond: Optional[torch.Tensor] = None,
504
+ output_attentions: bool = False,
505
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
506
+ hidden_states_norm = self.norm1(hidden_states)
507
+ if self.norm1_modulation is not None:
508
+ assert modulation_cond is not None
509
+ hidden_states_norm = self.norm1_modulation(
510
+ hidden_states_norm, modulation_cond
511
+ )
512
+ self_attention_outputs = self.attention(
513
+ hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
514
+ head_mask,
515
+ output_attentions=output_attentions,
516
+ )
517
+ attention_output = self_attention_outputs[0]
518
+
519
+ attention_output = self.layer_scale1(attention_output)
520
+ outputs = self_attention_outputs[
521
+ 1:
522
+ ] # add self attentions if we output attention weights
523
+
524
+ # first residual connection
525
+ hidden_states = attention_output + hidden_states
526
+
527
+ # in Dinov2, layernorm is also applied after self-attention
528
+ layer_output = self.norm2(hidden_states)
529
+ if self.norm2_modulation is not None:
530
+ assert modulation_cond is not None
531
+ layer_output = self.norm2_modulation(layer_output, modulation_cond)
532
+ layer_output = self.mlp(layer_output)
533
+ layer_output = self.layer_scale2(layer_output)
534
+
535
+ # second residual connection
536
+ layer_output = layer_output + hidden_states
537
+
538
+ outputs = (layer_output,) + outputs
539
+
540
+ return outputs
541
+
542
+ def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
543
+ self.norm1_modulation = norm1_mod
544
+ self.norm2_modulation = norm2_mod
545
+
546
+
547
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
548
+ class Dinov2Encoder(nn.Module):
549
+ def __init__(self, config: Dinov2Config) -> None:
550
+ super().__init__()
551
+ self.config = config
552
+ self.layer = nn.ModuleList(
553
+ [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
554
+ )
555
+ self.gradient_checkpointing = False
556
+
557
+ def forward(
558
+ self,
559
+ hidden_states: torch.Tensor,
560
+ head_mask: Optional[torch.Tensor] = None,
561
+ modulation_cond: Optional[torch.Tensor] = None,
562
+ output_attentions: bool = False,
563
+ output_hidden_states: bool = False,
564
+ return_dict: bool = True,
565
+ ) -> Union[tuple, BaseModelOutput]:
566
+ all_hidden_states = () if output_hidden_states else None
567
+ all_self_attentions = () if output_attentions else None
568
+
569
+ for i, layer_module in enumerate(self.layer):
570
+ if output_hidden_states:
571
+ all_hidden_states = all_hidden_states + (hidden_states,)
572
+
573
+ layer_head_mask = head_mask[i] if head_mask is not None else None
574
+
575
+ if self.gradient_checkpointing and self.training:
576
+
577
+ def create_custom_forward(module):
578
+ def custom_forward(*inputs):
579
+ return module(*inputs, output_attentions)
580
+
581
+ return custom_forward
582
+
583
+ layer_outputs = torch.utils.checkpoint.checkpoint(
584
+ create_custom_forward(layer_module),
585
+ hidden_states,
586
+ layer_head_mask,
587
+ modulation_cond,
588
+ use_reentrant=False,
589
+ )
590
+ else:
591
+ layer_outputs = layer_module(
592
+ hidden_states, layer_head_mask, modulation_cond, output_attentions
593
+ )
594
+
595
+ hidden_states = layer_outputs[0]
596
+
597
+ if output_attentions:
598
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
599
+
600
+ if output_hidden_states:
601
+ all_hidden_states = all_hidden_states + (hidden_states,)
602
+
603
+ if not return_dict:
604
+ return tuple(
605
+ v
606
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
607
+ if v is not None
608
+ )
609
+ return BaseModelOutput(
610
+ last_hidden_state=hidden_states,
611
+ hidden_states=all_hidden_states,
612
+ attentions=all_self_attentions,
613
+ )
614
+
615
+
616
+ class Dinov2PreTrainedModel(PreTrainedModel):
617
+ """
618
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
619
+ models.
620
+ """
621
+
622
+ config_class = Dinov2Config
623
+ base_model_prefix = "dinov2"
624
+ main_input_name = "pixel_values"
625
+ supports_gradient_checkpointing = True
626
+
627
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
628
+ """Initialize the weights"""
629
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
630
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
631
+ # `trunc_normal_cpu` not implemented in `half` issues
632
+ module.weight.data = nn.init.trunc_normal_(
633
+ module.weight.data.to(torch.float32),
634
+ mean=0.0,
635
+ std=self.config.initializer_range,
636
+ ).to(module.weight.dtype)
637
+ if module.bias is not None:
638
+ module.bias.data.zero_()
639
+ elif isinstance(module, nn.LayerNorm):
640
+ module.bias.data.zero_()
641
+ module.weight.data.fill_(1.0)
642
+ elif isinstance(module, Dinov2Embeddings):
643
+ module.position_embeddings.data = nn.init.trunc_normal_(
644
+ module.position_embeddings.data.to(torch.float32),
645
+ mean=0.0,
646
+ std=self.config.initializer_range,
647
+ ).to(module.position_embeddings.dtype)
648
+
649
+ module.cls_token.data = nn.init.trunc_normal_(
650
+ module.cls_token.data.to(torch.float32),
651
+ mean=0.0,
652
+ std=self.config.initializer_range,
653
+ ).to(module.cls_token.dtype)
654
+
655
+ def _set_gradient_checkpointing(
656
+ self, module: Dinov2Encoder, value: bool = False
657
+ ) -> None:
658
+ if isinstance(module, Dinov2Encoder):
659
+ module.gradient_checkpointing = value
660
+
661
+
662
+ DINOV2_START_DOCSTRING = r"""
663
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
664
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
665
+ behavior.
666
+
667
+ Parameters:
668
+ config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
669
+ Initializing with a config file does not load the weights associated with the model, only the
670
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
671
+ """
672
+
673
+ DINOV2_BASE_INPUTS_DOCSTRING = r"""
674
+ Args:
675
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
676
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
677
+ [`BitImageProcessor.preprocess`] for details.
678
+
679
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
680
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
681
+ pre-training.
682
+
683
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
684
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
685
+
686
+ - 1 indicates the head is **not masked**,
687
+ - 0 indicates the head is **masked**.
688
+
689
+ output_attentions (`bool`, *optional*):
690
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
691
+ tensors for more detail.
692
+ output_hidden_states (`bool`, *optional*):
693
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
694
+ more detail.
695
+ return_dict (`bool`, *optional*):
696
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
697
+ """
698
+
699
+ DINOV2_INPUTS_DOCSTRING = r"""
700
+ Args:
701
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
702
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
703
+ [`BitImageProcessor.preprocess`] for details.
704
+
705
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
706
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
707
+
708
+ - 1 indicates the head is **not masked**,
709
+ - 0 indicates the head is **masked**.
710
+
711
+ output_attentions (`bool`, *optional*):
712
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
713
+ tensors for more detail.
714
+ output_hidden_states (`bool`, *optional*):
715
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
716
+ more detail.
717
+ return_dict (`bool`, *optional*):
718
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
719
+ """
720
+
721
+
722
+ @dataclass
723
+ class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
724
+ patch_embeddings: Optional[torch.FloatTensor] = None
725
+
726
+
727
+ @add_start_docstrings(
728
+ "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
729
+ DINOV2_START_DOCSTRING,
730
+ )
731
+ class Dinov2Model(Dinov2PreTrainedModel):
732
+ def __init__(self, config: Dinov2Config):
733
+ super().__init__(config)
734
+ self.config = config
735
+
736
+ self.embeddings = Dinov2Embeddings(config)
737
+ self.encoder = Dinov2Encoder(config)
738
+
739
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
745
+ return self.embeddings.patch_embeddings
746
+
747
+ def expand_input_channels(self, extra_input_channels: int) -> None:
748
+ if extra_input_channels == 0:
749
+ return
750
+ conv_old = self.embeddings.patch_embeddings.projection
751
+ conv_new = nn.Conv2d(
752
+ self.config.num_channels + extra_input_channels,
753
+ self.config.hidden_size,
754
+ kernel_size=self.config.patch_size,
755
+ stride=self.config.patch_size,
756
+ ).to(self.device)
757
+ with torch.no_grad():
758
+ conv_new.weight[:, :3] = conv_old.weight
759
+ conv_new.bias = conv_old.bias
760
+ self.embeddings.patch_embeddings.projection = conv_new
761
+ del conv_old
762
+
763
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
764
+ """
765
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
766
+ class PreTrainedModel
767
+ """
768
+ for layer, heads in heads_to_prune.items():
769
+ self.encoder.layer[layer].attention.prune_heads(heads)
770
+
771
+ @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
772
+ @add_code_sample_docstrings(
773
+ checkpoint=_CHECKPOINT_FOR_DOC,
774
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            modulation_cond=modulation_cond,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            head_outputs = (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        return CustomBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            patch_embeddings=embedding_output,
        )

    def set_gradient_checkpointing(self, value: bool = False) -> None:
        self._set_gradient_checkpointing(self.encoder, value)


@add_start_docstrings(
    """
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2 = Dinov2Model(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.dinov2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size

        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1:]

        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)

        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            # move labels to the correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

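The classification head above concatenates the [CLS] token with the mean of the patch tokens, which is why the classifier is built on config.hidden_size * 2 features. A minimal shape check with toy sizes (not the real config values):

import torch

hidden = torch.randn(2, 257, 768)  # (batch, 1 [CLS] + 256 patches, hidden_size)
features = torch.cat([hidden[:, 0], hidden[:, 1:].mean(dim=1)], dim=1)
print(features.shape)  # torch.Size([2, 1536]) == hidden_size * 2
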
@add_start_docstrings(
    """
    Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.num_features = [
            config.hidden_size for _ in range(config.num_hidden_layers + 1)
        ]
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )

        embedding_output = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    hidden_state = hidden_state[:, 1:, :].reshape(
                        batch_size, width // patch_size, height // patch_size, -1
                    )
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )


class CustomPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(
        self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
    ):
        super().__init__()

        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )
        patch_size = (
            patch_size
            if isinstance(patch_size, collections.abc.Iterable)
            else (patch_size, patch_size)
        )
        num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


class CustomEmbeddings(nn.Module):
    """
    Construct the CLS token, mask token, position and patch embeddings.
    """

    def __init__(
        self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
    ) -> None:
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.hidden_size = hidden_size

        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))

        self.patch_embeddings = CustomPatchEmbeddings(
            image_size, patch_size, num_channels, hidden_size
        )
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches + 1, self.hidden_size)
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        This method interpolates the pre-trained position encodings so the model can be used on
        higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid a floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1
        patch_pos_embed = patch_pos_embed.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (
            int(height) != patch_pos_embed.shape[-2]
            or int(width) != patch_pos_embed.shape[-1]
        ):
            raise ValueError(
                "Width or height does not match with the interpolated position embeddings"
            )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        patch_embeddings = self.patch_embeddings(pixel_values)
        embeddings = patch_embeddings

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(
            embeddings, height, width
        )

        return embeddings
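interpolate_pos_encoding above resamples the learned patch-position grid with bicubic interpolation, which is what lets the tokenizer run at resolutions other than the pre-training one. The same round trip on a toy embedding (all sizes here are made up for illustration):

import torch
import torch.nn.functional as F

dim, grid, new_grid = 8, 4, 8               # hypothetical dims: a 4x4 grid resampled to 8x8
pos = torch.randn(1, grid * grid + 1, dim)  # [CLS] position + 16 patch positions
cls_pos, patch_pos = pos[:, :1], pos[:, 1:]

# reshape to a 2D grid, resample, flatten back: the same steps as the method above
patch_pos = patch_pos.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
patch_pos = F.interpolate(
    patch_pos, scale_factor=new_grid / grid, mode="bicubic", align_corners=False
)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, dim)
print(torch.cat([cls_pos, patch_pos], dim=1).shape)  # torch.Size([1, 65, 8])
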
sf3d/models/tokenizers/image.py ADDED
@@ -0,0 +1,99 @@
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float
from torch import Tensor

from sf3d.models.tokenizers.dinov2 import Dinov2Model
from sf3d.models.transformers.attention import Modulation
from sf3d.models.utils import BaseModule


class DINOV2SingleImageTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        width: int = 512
        height: int = 512
        modulation_cond_dim: int = 768

    cfg: Config

    def configure(self) -> None:
        self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)

        for p in self.model.parameters():
            p.requires_grad_(False)
        self.model.eval()

        self.model.set_gradient_checkpointing(False)

        # add modulation
        modulations = []
        for layer in self.model.encoder.layer:
            norm1_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            norm2_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
            modulations += [norm1_modulation, norm2_modulation]
        self.modulations = nn.ModuleList(modulations)

        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
        model = self.model

        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        out = model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )
        local_features = out.last_hidden_state
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
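DINOV2SingleImageTokenizer accepts either a packed (B, C, H, W) batch or a multi-view (B, N, C, H, W) batch and returns channel-first tokens for the downstream cross-attention. A hedged usage sketch; it downloads the DINOv2 weights on first use, and the condition width must match modulation_cond_dim:

import torch
from sf3d.models.tokenizers.image import DINOV2SingleImageTokenizer

tokenizer = DINOV2SingleImageTokenizer({"modulation_cond_dim": 768})
images = torch.rand(2, 3, 512, 512)  # packed batch of two RGB images in [0, 1]
cond = torch.zeros(2, 768)           # per-image modulation embedding
tokens = tokenizer(images, modulation_cond=cond)
print(tokens.shape)                  # (2, hidden_size, 1 + num_patches)
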
sf3d/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,49 @@
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from sf3d.models.utils import BaseModule


class TriplaneLearnablePositionalEmbedding(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        plane_size: int = 96
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        self.embeddings = nn.Parameter(
            torch.randn(
                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
                dtype=torch.float32,
            )
            * 1
            / math.sqrt(self.cfg.num_channels)
        )

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        return rearrange(
            repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
        )

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        batch_size, Ct, Nt = tokens.shape
        assert Nt == self.cfg.plane_size**2 * 3
        assert Ct == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )
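The embedding flattens three learned Hp x Wp planes into a single token sequence, and detokenize restores the per-plane layout. A quick shape round trip with a deliberately small plane (the real config uses plane_size=96, num_channels=1024):

import torch
from sf3d.models.tokenizers.triplane import TriplaneLearnablePositionalEmbedding

tok = TriplaneLearnablePositionalEmbedding({"plane_size": 4, "num_channels": 8})
tokens = tok(batch_size=2)       # (2, 8, 3 * 4 * 4)
planes = tok.detokenize(tokens)  # (2, 3, 8, 4, 4)
print(tokens.shape, planes.shape)
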
sf3d/models/transformers/attention.py ADDED
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn


class Modulation(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        if single_layer:
            self.linear1 = nn.Identity()
        else:
            self.linear1 = nn.Linear(condition_dim, condition_dim)

        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        # Only zero init the last linear layer
        if zero_init:
            nn.init.zeros_(self.linear2.weight)
            nn.init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        emb = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x
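Modulation is AdaLN-style conditioning: the condition vector is projected to a per-channel (scale, shift) pair and applied as x * (1 + scale) + shift, so with zero_init=True the layer starts out as an identity and the pretrained DINOv2 behavior is preserved at initialization. A quick check of that property:

import torch
from sf3d.models.transformers.attention import Modulation

mod = Modulation(embedding_dim=16, condition_dim=8, zero_init=True, single_layer=True)
x = torch.randn(2, 10, 16)  # (batch, tokens, channels)
cond = torch.randn(2, 8)
print(torch.allclose(mod(x, cond), x))  # True: zero-initialized modulation is a no-op
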
sf3d/models/transformers/backbone.py ADDED
@@ -0,0 +1,515 @@
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from sf3d.models.utils import BaseModule


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, scale: float = 1.0):
        args = ()
        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        kv_dim=None,
        num_heads=16,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        kv_dim = dim if not kv_dim else kv_dim
        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x_q, x_kv):
        B, N_q, C = x_q.shape
        B, N_kv, _ = x_kv.shape
        # [B, N_q, C] -> [B, N_q, H, C/H]
        q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
        # [B, N_kv, C] -> [B, N_kv, H, C/H]
        k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
        v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)

        # attention
        x = torch.nn.functional.scaled_dot_product_attention(
            q.permute(0, 2, 1, 3),
            k.permute(0, 2, 1, 3),
            v.permute(0, 2, 1, 3),
            attn_mask=None,
            dropout_p=self.attn_drop,
            scale=self.scale,
        ).permute(0, 2, 1, 3)

        # [B, N_q, H, C/H] -> [B, N_q, C]
        x = x.reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        act_fn = GEGLU(dim, inner_dim)
        self.net = nn.ModuleList([])
        self.net.append(act_fn)
        self.net.append(nn.Dropout(dropout))
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for module in self.net:
            x = module(x)
        return x


class BasicBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        kv_dim: Optional[int] = None,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = CrossAttention(
            dim,
            kv_dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = CrossAttention(
            dim,
            kv_dim=kv_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=ff_drop)

    def forward(self, z, x):
        z_norm = self.norm1(z)
        z = z + self.attn1(z_norm, z_norm)
        # TODO: do we need the second attention when x is None?
        z_norm = self.norm2(z)
        z = z + self.attn2(z_norm, x if x is not None else z_norm)
        z_norm = self.norm3(z)
        z = z + self.ff(z_norm)
        return z


class SingleStreamTransformer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 16
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False

    cfg: Config

    def configure(self) -> None:
        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        inner_dim = self.num_attention_heads * self.attention_head_dim

        # Define input layers
        self.norm = torch.nn.GroupNorm(
            num_groups=self.cfg.norm_num_groups,
            num_channels=self.cfg.in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = nn.Linear(self.cfg.in_channels, inner_dim)

        # Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicBlock(
                    inner_dim,
                    kv_dim=self.cfg.cross_attention_dim,
                    num_heads=self.num_attention_heads,
                    qkv_bias=self.cfg.attention_bias,
                    proj_drop=self.cfg.dropout,
                    ff_drop=self.cfg.dropout,
                )
                for d in range(self.cfg.num_layers)
            ]
        )

        # Define output layers
        self.proj_out = nn.Linear(inner_dim, self.cfg.in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states = self.proj_in(hidden_states)
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states)
        hidden_states = self.proj_out(hidden_states).permute(0, 2, 1).contiguous()
        # TODO: do we really need to add the residual?
        hidden_states = hidden_states + residual
        return hidden_states


class FuseBlock(nn.Module):
    """
    Fuse X into Z with cross-attention.
    """

    def __init__(
        self,
        dim_z: int,
        dim_x: int,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
        norm_x_input: bool = True,
    ):
        super().__init__()
        self.norm_x_input = norm_x_input
        if self.norm_x_input:
            self.norm_x = nn.LayerNorm(dim_x)
        self.attn = CrossAttention(
            dim_z,
            kv_dim=dim_x,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm_z1 = nn.LayerNorm(dim_z)
        self.norm_z2 = nn.LayerNorm(dim_z)
        self.ff = FeedForward(dim_z, dropout=ff_drop)

    def forward(self, z, x):
        # TODO: do we need to normalize x?
        z = z + self.attn(self.norm_z1(z), self.norm_x(x) if self.norm_x_input else x)
        z = z + self.ff(self.norm_z2(z))
        return z


@torch.no_grad()
def get_triplane_attention_mask(res):
    N = 3 * res * res
    attn_mask = torch.zeros(3, res, res, 3, res, res)

    i, j = torch.meshgrid(torch.arange(res), torch.arange(res))

    attn_mask[0, i, j, 1, i, :] = 1.0
    attn_mask[0, i, j, 2, j, :] = 1.0
    attn_mask[1, i, j, 0, i, :] = 1.0
    attn_mask[1, i, j, 2, :, j] = 1.0
    attn_mask[2, i, j, 0, :, i] = 1.0
    attn_mask[2, i, j, 1, :, j] = 1.0
    attn_mask = attn_mask.bool()

    attn_bias = torch.empty_like(attn_mask, dtype=torch.float)
    attn_bias.masked_fill_(attn_mask, 0.0)
    attn_bias.masked_fill_(~attn_mask, float("-inf"))

    return attn_bias.reshape(N, N)

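get_triplane_attention_mask builds a sparse attention bias: a token on one plane may only attend to tokens on the other two planes that share a coordinate with it (entries are 0 where attention is allowed and -inf elsewhere, including within a plane). Inspecting it at a tiny resolution:

import torch
from sf3d.models.transformers.backbone import get_triplane_attention_mask

bias = get_triplane_attention_mask(2)      # three 2x2 planes -> a 12x12 bias
print(bias.shape)                          # torch.Size([12, 12])
print((bias == 0).float().mean().item())   # 1/3: each token sees 2 * res of the 3 * res**2 tokens
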
class TriplaneAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        resolution: int,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        full_attention: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.resolution = resolution
        self.full_attention = full_attention
        self.attn_mask = (
            get_triplane_attention_mask(resolution) if not full_attention else None
        )

    def forward(self, x):
        B, N, C = x.shape
        # [B, N, C] -> [B, N, H, C/H]
        q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads)
        k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads)
        v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads)

        # detokenize the planes
        assert N == self.resolution**2 * 3
        attn_bias = (
            self.attn_mask.to(q)
            .unsqueeze(0)
            .unsqueeze(0)
            .expand(B, self.num_heads, -1, -1)
            if not self.full_attention
            else None
        )

        # full attention
        x = torch.nn.functional.scaled_dot_product_attention(
            q.permute(0, 2, 1, 3),
            k.permute(0, 2, 1, 3),
            v.permute(0, 2, 1, 3),
            attn_mask=attn_bias,
            dropout_p=self.attn_drop,
            scale=self.scale,
        ).permute(0, 2, 1, 3)

        # [B, N_q, H, C/H] -> [B, N_q, C]
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TwoStreamBlock(nn.Module):
    def __init__(
        self,
        dim_latent: int,
        dim_input: int,
        num_basic_blocks: int = 4,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
        norm_x_input: bool = True,
        dim_cross: Optional[int] = None,
    ):
        super().__init__()

        # Define the fuse block that fuses the input into the latent
        self.fuse_block_in = FuseBlock(
            dim_latent,
            dim_input,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            ff_drop=ff_drop,
            norm_x_input=norm_x_input,
        )

        # Define the transformer blocks that process the latent
        self.transformer_block = nn.ModuleList(
            [
                BasicBlock(
                    dim_latent,
                    kv_dim=dim_cross,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    proj_drop=proj_drop,
                    ff_drop=ff_drop,
                )
                for _ in range(num_basic_blocks)
            ]
        )

        # Define the fuse block that fuses the latent into the input
        self.fuse_block_out = FuseBlock(
            dim_input,
            dim_latent,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            ff_drop=ff_drop,
            norm_x_input=norm_x_input,
        )

    def forward(self, latent, input, cross_input):
        latent = self.fuse_block_in(latent, input)
        for block in self.transformer_block:
            latent = block(latent, cross_input)
        input = self.fuse_block_out(input, latent)
        return latent, input


class TwoStreamInterleaveTransformer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        raw_triplane_channels: int = 1024
        triplane_channels: int = 1024
        raw_image_channels: int = 1024
        num_latents: int = 1792
        num_blocks: int = 4
        num_basic_blocks: int = 3
        dropout: float = 0.0
        latent_init_std: float = 0.02
        norm_num_groups: int = 32
        attention_bias: bool = False
        norm_x_input: bool = False
        cross_attention_dim: int = 1024
        mix_latent: bool = True

    cfg: Config

    def configure(self) -> None:
        self.mix_latent = self.cfg.mix_latent

        # Define the dimensions
        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        self.num_latents = self.cfg.num_latents
        self.latent_dim = self.num_attention_heads * self.attention_head_dim

        # Define input layers
        if self.cfg.norm_num_groups > 0:
            self.norm_triplane = torch.nn.GroupNorm(
                num_groups=self.cfg.norm_num_groups,
                num_channels=self.cfg.raw_triplane_channels,
                eps=1e-6,
                affine=True,
            )
        else:
            self.norm_triplane = nn.LayerNorm(self.cfg.raw_triplane_channels)
        self.proj_triplane = nn.Linear(
            self.cfg.raw_triplane_channels, self.cfg.triplane_channels
        )
        if self.mix_latent:
            self.norm_image = nn.LayerNorm(self.cfg.raw_image_channels)
            self.proj_image = nn.Linear(self.cfg.raw_image_channels, self.latent_dim)
        self.norm_latent = nn.LayerNorm(self.latent_dim)
        self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim)

        # Define the latents
        self.latent_init = nn.Parameter(
            torch.zeros(1, self.num_latents, self.latent_dim)
        )
        nn.init.normal_(self.latent_init, std=self.cfg.latent_init_std)

        # Define the transformer blocks
        self.main_blocks = nn.ModuleList(
            [
                TwoStreamBlock(
                    self.latent_dim,
                    self.cfg.triplane_channels,
                    num_basic_blocks=self.cfg.num_basic_blocks,
                    num_heads=self.num_attention_heads,
                    qkv_bias=self.cfg.attention_bias,
                    proj_drop=self.cfg.dropout,
                    ff_drop=self.cfg.dropout,
                    norm_x_input=self.cfg.norm_x_input,
                    dim_cross=self.cfg.cross_attention_dim,
                )
                for _ in range(self.cfg.num_blocks)
            ]
        )

        # Define output layers
        self.proj_out = nn.Linear(
            self.cfg.triplane_channels, self.cfg.raw_triplane_channels
        )

    def forward(self, hidden_states, encoder_hidden_states, **kwargs):
        # hidden_states: [B, triplane_dim, N_triplane], the triplane tokens
        # encoder_hidden_states: [B, N_image, image_dim], the image tokens
        if isinstance(self.norm_triplane, nn.GroupNorm):
            triplane_tokens = self.norm_triplane(hidden_states)
            triplane_tokens = triplane_tokens.permute(
                0, 2, 1
            )  # [B, N_triplane, triplane_dim]
        elif isinstance(self.norm_triplane, nn.LayerNorm):
            triplane_tokens = self.norm_triplane(hidden_states.permute(0, 2, 1))
        else:
            raise ValueError("Unknown normalization layer")
        triplane_tokens = self.proj_triplane(triplane_tokens)
        if self.mix_latent:
            image_tokens = self.norm_image(
                encoder_hidden_states
            )  # [B, N_image, image_dim]
            image_tokens = self.proj_image(image_tokens)
        init_latents = self.latent_init.expand(
            hidden_states.shape[0], -1, -1
        )  # [B, N_latent_init, latent_dim]
        init_latents = self.norm_latent(init_latents)
        init_latents = self.proj_latent(init_latents)
        if self.mix_latent:
            latent_tokens = torch.cat(
                [image_tokens, init_latents], dim=1
            )  # [B, N_latent, latent_dim]
        else:
            latent_tokens = init_latents

        # forward the main blocks
        for block in self.main_blocks:
            latent_tokens, triplane_tokens = block(
                latent_tokens, triplane_tokens, encoder_hidden_states
            )

        # project the triplane tokens back to the original dimension
        triplane_tokens = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous()
        triplane_tokens = triplane_tokens + hidden_states
        return triplane_tokens
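TwoStreamInterleaveTransformer keeps a latent stream (optionally seeded with projected image tokens) and a triplane stream, exchanging information through the fuse blocks in every round; the final projection and residual keep the output shape equal to the triplane input. A toy forward pass with deliberately small, made-up dimensions:

import torch
from sf3d.models.transformers.backbone import TwoStreamInterleaveTransformer

tfm = TwoStreamInterleaveTransformer(
    {
        "num_attention_heads": 4,
        "attention_head_dim": 16,  # latent_dim = 4 * 16 = 64
        "raw_triplane_channels": 64,
        "triplane_channels": 64,
        "raw_image_channels": 32,
        "num_latents": 8,
        "num_blocks": 1,
        "num_basic_blocks": 1,
        "norm_num_groups": 8,
        "cross_attention_dim": 32,
    }
)
triplane = torch.randn(1, 64, 3 * 8 * 8)  # (B, triplane_dim, N_triplane)
image_tokens = torch.randn(1, 5, 32)      # (B, N_image, image_dim)
print(tfm(triplane, image_tokens).shape)  # torch.Size([1, 64, 192]), same as the input
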
sf3d/models/utils.py ADDED
@@ -0,0 +1,292 @@
import dataclasses
import importlib
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Bool, Float, Int, Num
from omegaconf import DictConfig, OmegaConf
from torch import Tensor


class BaseModule(nn.Module):
    @dataclass
    class Config:
        pass

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        raise NotImplementedError


def find_class(cls_string):
    module_string = ".".join(cls_string.split(".")[:-1])
    cls_name = cls_string.split(".")[-1]
    module = importlib.import_module(module_string, package=None)
    cls = getattr(module, cls_name)
    return cls


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    # Check if the cfg keys are in fields
    cfg_ = cfg.copy()
    keys = list(cfg_.keys())

    field_names = {f.name for f in dataclasses.fields(fields)}
    for key in keys:
        # This is helpful when swapping out modules from the CLI
        if key not in field_names:
            print(f"Ignoring {key} as it's not supported by {fields}")
            cfg_.pop(key)
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_)
    return scfg


EPS_DTYPE = {
    torch.float16: 1e-4,
    torch.bfloat16: 1e-4,
    torch.float32: 1e-7,
    torch.float64: 1e-8,
}


def dot(x, y, dim=-1):
    return torch.sum(x * y, dim, keepdim=True)


def reflect(x, n):
    return x - 2 * dot(x, n) * n


def normalize(x, dim=-1, eps=None):
    if eps is None:
        eps = EPS_DTYPE[x.dtype]
    return F.normalize(x, dim=dim, p=2, eps=eps)


def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
    # One-pad for the determinant
    tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
    det_tri = torch.det(tri_sq)
    tri_rev = torch.cat(
        (tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
    )
    tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
    return tri_sq


def triangle_intersection_2d(
    t1: Float[Tensor, "*B 3 2"],
    t2: Float[Tensor, "*B 3 2"],
    eps=1e-12,
) -> Float[Tensor, "*B"]:  # noqa: F821
    """Returns True if triangles collide, False otherwise"""

    def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]:  # noqa: F821
        logdetx = torch.logdet(x.double())
        if eps is None:
            return ~torch.isfinite(logdetx)
        return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))

    t1s = tri_winding(t1)
    t2s = tri_winding(t2)

    # Assume the triangles do not collide in the beginning
    ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
    for i in range(3):
        edge = torch.roll(t1s, i, dims=1)[:, :2, :]
        # Check if all points of triangle 2 lie on the external side of edge E.
        # If this is the case, the triangles do not collide.
        upd = (
            chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
            & chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
            & chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
        )
        # Here "no collision" is still True due to the inversion
        ret = ret | upd

    for i in range(3):
        edge = torch.roll(t2s, i, dims=1)[:, :2, :]

        upd = (
            chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
            & chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
            & chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
        )
        # Here "no collision" is still True due to the inversion
        ret = ret | upd

    return ~ret  # Do the inversion


ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]


def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat


def dilate_fill(img, mask, iterations=10):
    oldMask = mask.float()
    oldImg = img

    mask_kernel = torch.ones(
        (1, 1, 3, 3),
        dtype=oldMask.dtype,
        device=oldMask.device,
    )

    for i in range(iterations):
        newMask = torch.nn.functional.max_pool2d(oldMask, 3, 1, 1)

        # Fill the extension with the mean color of the old valid regions
        img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1)
        mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1)
        new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1)

        # Average color of the valid region
        mean_color = (img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)).unsqueeze(
            2
        )
        # Extend it to the new region
        fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1)

        mask_conv = F.conv2d(
            newMask, mask_kernel, padding=1
        )  # Get the sum for each kernel patch
        newImg = F.fold(
            fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
        ) / mask_conv.clamp(1)

        diffMask = newMask - oldMask

        oldMask = newMask
        oldImg = torch.lerp(oldImg, newImg, diffMask)

    return oldImg


def float32_to_uint8_np(
    x: Float[np.ndarray, "*B H W C"],
    dither: bool = True,
    dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
    dither_strength: float = 1.0,
) -> Int[np.ndarray, "*B H W C"]:
    if dither:
        dither = (
            dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
        )
        if dither_mask is not None:
            dither = dither * dither_mask
        return np.clip(np.floor((256.0 * x + dither)), 0, 255).astype(np.uint8)
    return np.clip(np.floor((256.0 * x)), 0, 255).astype(np.uint8)


def convert_data(data):
    if data is None:
        return None
    elif isinstance(data, np.ndarray):
        return data
    elif isinstance(data, torch.Tensor):
        if data.dtype in [torch.float16, torch.bfloat16]:
            data = data.float()
        return data.detach().cpu().numpy()
    elif isinstance(data, list):
        return [convert_data(d) for d in data]
    elif isinstance(data, dict):
        return {k: convert_data(v) for k, v in data.items()}
    else:
        raise TypeError(
            "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
            type(data),
        )


class ImageProcessor:
    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, torch.Tensor):
            pass

        batched = image.ndim == 4

        if not batched:
            image = image[None, ...]
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
            image = self.convert_and_resize(image, size)
        else:
            if not isinstance(image, list):
                image = [image]
            image = [self.convert_and_resize(im, size) for im in image]
            image = torch.stack(image, dim=0)
        return image


def get_intrinsic_from_fov(fov, H, W, bs=-1):
    focal_length = 0.5 * H / np.tan(0.5 * fov)
    intrinsic = np.identity(3, dtype=np.float32)
    intrinsic[0, 0] = focal_length
    intrinsic[1, 1] = focal_length
    intrinsic[0, 2] = W / 2.0
    intrinsic[1, 2] = H / 2.0

    if bs > 0:
        intrinsic = intrinsic[None].repeat(bs, axis=0)

    return torch.from_numpy(intrinsic)
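The BaseModule/parse_structured pair is what lets every *_cls entry in the YAML config be instantiated from a plain dict: unknown keys are dropped with a console note and the rest is merged into the structured defaults. A minimal sketch with a hypothetical subclass (the class name and field are illustrative only):

from dataclasses import dataclass

from sf3d.models.utils import BaseModule


class Scaler(BaseModule):  # hypothetical module, for illustration
    @dataclass
    class Config(BaseModule.Config):
        factor: float = 2.0

    cfg: Config

    def configure(self) -> None:
        pass

    def forward(self, x):
        return x * self.cfg.factor


m = Scaler({"factor": 3.0, "unknown_key": 1})  # "unknown_key" is ignored with a note
print(m.cfg.factor)  # 3.0
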
sf3d/system.py ADDED
@@ -0,0 +1,482 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import trimesh
9
+ from einops import rearrange
10
+ from huggingface_hub import hf_hub_download
11
+ from jaxtyping import Float
12
+ from omegaconf import OmegaConf
13
+ from PIL import Image
14
+ from safetensors.torch import load_model
15
+ from torch import Tensor
16
+
17
+ from sf3d.models.isosurface import MarchingTetrahedraHelper
18
+ from sf3d.models.mesh import Mesh
19
+ from sf3d.models.utils import (
20
+ BaseModule,
21
+ ImageProcessor,
22
+ convert_data,
23
+ dilate_fill,
24
+ dot,
25
+ find_class,
26
+ float32_to_uint8_np,
27
+ normalize,
28
+ scale_tensor,
29
+ )
30
+ from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
31
+
32
+ from .texture_baker import TextureBaker
33
+
34
+
35
+ class SF3D(BaseModule):
36
+ @dataclass
37
+ class Config(BaseModule.Config):
38
+ cond_image_size: int
39
+ isosurface_resolution: int
40
+ isosurface_threshold: float = 10.0
41
+ radius: float = 1.0
42
+ background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
43
+ default_fovy_deg: float = 40.0
44
+ default_distance: float = 1.6
45
+
46
+ camera_embedder_cls: str = ""
47
+ camera_embedder: dict = field(default_factory=dict)
48
+
49
+ image_tokenizer_cls: str = ""
50
+ image_tokenizer: dict = field(default_factory=dict)
51
+
52
+ tokenizer_cls: str = ""
53
+ tokenizer: dict = field(default_factory=dict)
54
+
55
+ backbone_cls: str = ""
56
+ backbone: dict = field(default_factory=dict)
57
+
58
+ post_processor_cls: str = ""
59
+ post_processor: dict = field(default_factory=dict)
60
+
61
+ decoder_cls: str = ""
62
+ decoder: dict = field(default_factory=dict)
63
+
64
+ image_estimator_cls: str = ""
65
+ image_estimator: dict = field(default_factory=dict)
66
+
67
+ global_estimator_cls: str = ""
68
+ global_estimator: dict = field(default_factory=dict)
69
+
70
+ cfg: Config
71
+
72
+ @classmethod
73
+ def from_pretrained(
74
+ cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
75
+ ):
76
+ if os.path.isdir(pretrained_model_name_or_path):
77
+ config_path = os.path.join(pretrained_model_name_or_path, config_name)
78
+ weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
79
+ else:
80
+ config_path = hf_hub_download(
81
+ repo_id=pretrained_model_name_or_path, filename=config_name
82
+ )
83
+ weight_path = hf_hub_download(
84
+ repo_id=pretrained_model_name_or_path, filename=weight_name
85
+ )
86
+
87
+ cfg = OmegaConf.load(config_path)
88
+ OmegaConf.resolve(cfg)
89
+ model = cls(cfg)
90
+ load_model(model, weight_path)
91
+ return model
92
+
93
+ @property
94
+ def device(self):
95
+ return next(self.parameters()).device
96
+
97
+ def configure(self):
98
+ self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
99
+ self.cfg.image_tokenizer
100
+ )
101
+ self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
102
+ self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
103
+ self.cfg.camera_embedder
104
+ )
105
+ self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
106
+ self.post_processor = find_class(self.cfg.post_processor_cls)(
107
+ self.cfg.post_processor
108
+ )
109
+ self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
110
+ self.image_estimator = find_class(self.cfg.image_estimator_cls)(
111
+ self.cfg.image_estimator
112
+ )
113
+ self.global_estimator = find_class(self.cfg.global_estimator_cls)(
114
+ self.cfg.global_estimator
115
+ )
116
+
117
+ self.bbox: Float[Tensor, "2 3"]
118
+ self.register_buffer(
119
+ "bbox",
120
+ torch.as_tensor(
121
+ [
122
+ [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
123
+ [self.cfg.radius, self.cfg.radius, self.cfg.radius],
124
+ ],
125
+ dtype=torch.float32,
126
+ ),
127
+ )
128
+ self.isosurface_helper = MarchingTetrahedraHelper(
129
+ self.cfg.isosurface_resolution,
130
+ os.path.join(
131
+ os.path.dirname(__file__),
132
+ "..",
133
+ "load",
134
+ "tets",
135
+ f"{self.cfg.isosurface_resolution}_tets.npz",
136
+ ),
137
+ )
138
+
139
+ self.baker = TextureBaker()
140
+ self.image_processor = ImageProcessor()
141
+
142
+ def triplane_to_meshes(
143
+ self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
144
+ ) -> list[Mesh]:
145
+ meshes = []
146
+ for i in range(triplanes.shape[0]):
147
+ triplane = triplanes[i]
148
+ grid_vertices = scale_tensor(
149
+ self.isosurface_helper.grid_vertices.to(triplanes.device),
150
+ self.isosurface_helper.points_range,
151
+ self.bbox,
152
+ )
153
+
154
+ values = self.query_triplane(grid_vertices, triplane)
155
+ decoded = self.decoder(values, include=["vertex_offset", "density"])
156
+ sdf = decoded["density"] - self.cfg.isosurface_threshold
157
+
158
+ deform = decoded["vertex_offset"].squeeze(0)
159
+
160
+ mesh: Mesh = self.isosurface_helper(
161
+ sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
162
+ )
163
+ mesh.v_pos = scale_tensor(
164
+ mesh.v_pos, self.isosurface_helper.points_range, self.bbox
165
+ )
166
+
167
+ meshes.append(mesh)
168
+
169
+ return meshes
170
+
171
+ def query_triplane(
172
+ self,
173
+ positions: Float[Tensor, "*B N 3"],
174
+ triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
175
+ ) -> Float[Tensor, "*B N F"]:
176
+ batched = positions.ndim == 3
177
+ if not batched:
178
+ # no batch dimension
179
+ triplanes = triplanes[None, ...]
180
+ positions = positions[None, ...]
181
+ assert triplanes.ndim == 5 and positions.ndim == 3
182
+
183
+ positions = scale_tensor(
184
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
185
+ )
186
+
187
+ indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
188
+ (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
189
+ dim=-3,
190
+ ).to(triplanes.dtype)
191
+ out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
192
+ rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
193
+ rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
194
+ align_corners=True,
195
+ mode="bilinear",
196
+ )
197
+ out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
198
+
199
+ return out
200
+
201
+ def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
202
+ # if batch[rgb_cond] is only one view, add a view dimension
203
+ if len(batch["rgb_cond"].shape) == 4:
204
+ batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
205
+ batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
206
+ batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
207
+ batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
208
+ batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
209
+ batch_size, n_input_views = batch["rgb_cond"].shape[:2]
210
+
211
+ camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
212
+ camera_embeds = self.camera_embedder(**batch)
213
+
214
+ input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
215
+ rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
216
+ modulation_cond=camera_embeds,
217
+ )
218
+
219
+ input_image_tokens = rearrange(
220
+ input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
221
+ )
222
+
223
+ tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
224
+
225
+ tokens = self.backbone(
226
+ tokens,
227
+ encoder_hidden_states=input_image_tokens,
228
+ modulation_cond=None,
229
+ )
230
+
231
+ direct_codes = self.tokenizer.detokenize(tokens)
232
+ scene_codes = self.post_processor(direct_codes)
233
+ return scene_codes, direct_codes
234
+
235
+ def run_image(
236
+ self,
237
+ image: Image,
238
+ bake_resolution: int,
239
+ estimate_illumination: bool = False,
240
+ ) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
241
+ if image.mode != "RGBA":
242
+ raise ValueError("Image must be in RGBA mode")
243
+ img_cond = (
244
+ torch.from_numpy(
245
+ np.asarray(
246
+ image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
247
+ ).astype(np.float32)
248
+ / 255.0
249
+ )
250
+ .float()
251
+ .clip(0, 1)
252
+ .to(self.device)
253
+ )
254
+ mask_cond = img_cond[:, :, -1:]
255
+ rgb_cond = torch.lerp(
256
+ torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
257
+ img_cond[:, :, :3],
258
+ mask_cond,
259
+ )
260
+
261
+ c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
262
+ intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
263
+ self.cfg.default_fovy_deg,
264
+ self.cfg.cond_image_size,
265
+ self.cfg.cond_image_size,
266
+ )
267
+
268
+ batch = {
269
+ "rgb_cond": rgb_cond,
270
+ "mask_cond": mask_cond,
271
+ "c2w_cond": c2w_cond.unsqueeze(0),
272
+ "intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
273
+ "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
274
+ }
275
+
276
+ meshes, global_dict = self.generate_mesh(
277
+ batch, bake_resolution, estimate_illumination
278
+ )
279
+ return meshes[0], global_dict
280
+
281
+ def generate_mesh(
282
+ self,
283
+ batch,
284
+ bake_resolution: int,
285
+ estimate_illumination: bool = False,
286
+ ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
287
+ batch["rgb_cond"] = self.image_processor(
288
+ batch["rgb_cond"], self.cfg.cond_image_size
289
+ )
290
+ batch["mask_cond"] = self.image_processor(
291
+ batch["mask_cond"], self.cfg.cond_image_size
292
+ )
293
+ scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
294
+
295
+ global_dict = {}
296
+ if self.image_estimator is not None:
297
+ global_dict.update(
298
+ self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
299
+ )
300
+ if self.global_estimator is not None and estimate_illumination:
301
+ global_dict.update(self.global_estimator(non_postprocessed_codes))
302
+
303
+ with torch.no_grad():
304
+ with torch.autocast(device_type="cuda", enabled=False):
305
+ meshes = self.triplane_to_meshes(scene_codes)
306
+
307
+ rets = []
308
+ for i, mesh in enumerate(meshes):
309
+ # Check for empty mesh
310
+ if mesh.v_pos.shape[0] == 0:
311
+ rets.append(trimesh.Trimesh())
312
+ continue
313
+
314
+ mesh.unwrap_uv()
315
+
316
+ # Build textures
317
+ rast = self.baker.rasterize(
318
+ mesh.v_tex, mesh.t_pos_idx, bake_resolution
319
+ )
320
+ bake_mask = self.baker.get_mask(rast)
321
+
322
+ pos_bake = self.baker.interpolate(
323
+ mesh.v_pos,
324
+ rast,
325
+ mesh.t_pos_idx,
326
+ mesh.v_tex,
327
+ )
328
+ gb_pos = pos_bake[bake_mask]
329
+
330
+ tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
331
+ decoded = self.decoder(
332
+ tri_query, exclude=["density", "vertex_offset"]
333
+ )
334
+
335
+ nrm = self.baker.interpolate(
336
+ mesh.v_nrm,
337
+ rast,
338
+ mesh.t_pos_idx,
339
+ mesh.v_tex,
340
+ )
341
+ gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
342
+ decoded["normal"] = gb_nrm
343
+
344
+ # Check if any keys in global_dict start with decoded_
345
+ for k, v in global_dict.items():
346
+ if k.startswith("decoder_"):
347
+ decoded[k.replace("decoder_", "")] = v[i]
348
+
349
+ mat_out = {
350
+ "albedo": decoded["features"],
351
+ "roughness": decoded["roughness"],
352
+ "metallic": decoded["metallic"],
353
+ "normal": normalize(decoded["perturb_normal"]),
354
+ "bump": None,
355
+ }
356
+
357
+ for k, v in mat_out.items():
358
+ if v is None:
359
+ continue
360
+ if v.shape[0] == 1:
361
+ # Skip and directly add a single value
362
+ mat_out[k] = v[0]
363
+ else:
364
+ f = torch.zeros(
365
+ bake_resolution,
366
+ bake_resolution,
367
+ v.shape[-1],
368
+ dtype=v.dtype,
369
+ device=v.device,
370
+ )
371
+ if v.shape == f.shape:
372
+ continue
373
+ if k == "normal":
374
+ # Use un-normalized tangents here so that larger smaller tris
375
+ # Don't effect the tangents that much
376
+ tng = self.baker.interpolate(
377
+                                     mesh.v_tng,
+                                     rast,
+                                     mesh.t_pos_idx,
+                                     mesh.v_tex,
+                                 )
+                                 gb_tng = tng[bake_mask]
+                                 gb_tng = F.normalize(gb_tng, dim=-1)
+                                 gb_btng = F.normalize(
+                                     torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
+                                 )
+                                 normal = F.normalize(mat_out["normal"], dim=-1)
+
+                                 bump = torch.cat(
+                                     # Check if we have to flip some things
+                                     (
+                                         dot(normal, gb_tng),
+                                         dot(normal, gb_btng),
+                                         dot(normal, gb_nrm).clip(
+                                             0.3, 1
+                                         ),  # Never go below 0.3. This would indicate a flipped (or close to one) normal
+                                     ),
+                                     -1,
+                                 )
+                                 bump = (bump * 0.5 + 0.5).clamp(0, 1)
+
+                                 f[bake_mask] = bump.view(-1, 3)
+                                 mat_out["bump"] = f
+                             else:
+                                 f[bake_mask] = v.view(-1, v.shape[-1])
+                                 mat_out[k] = f
+
+                     def uv_padding(arr):
+                         if arr.ndim == 1:
+                             return arr
+                         return (
+                             dilate_fill(
+                                 arr.permute(2, 0, 1)[None, ...],
+                                 bake_mask.unsqueeze(0).unsqueeze(0),
+                                 iterations=bake_resolution // 150,
+                             )
+                             .squeeze(0)
+                             .permute(1, 2, 0)
+                         )
+
+                     verts_np = convert_data(mesh.v_pos)
+                     faces = convert_data(mesh.t_pos_idx)
+                     uvs = convert_data(mesh.v_tex)
+
+                     basecolor_tex = Image.fromarray(
+                         float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
+                     ).convert("RGB")
+                     basecolor_tex.format = "JPEG"
+
+                     metallic = mat_out["metallic"].squeeze().cpu().item()
+                     roughness = mat_out["roughness"].squeeze().cpu().item()
+
+                     if "bump" in mat_out and mat_out["bump"] is not None:
+                         bump_np = convert_data(uv_padding(mat_out["bump"]))
+                         bump_up = np.ones_like(bump_np)
+                         bump_up[..., :2] = 0.5
+                         bump_up[..., 2:] = 1
+                         bump_tex = Image.fromarray(
+                             float32_to_uint8_np(
+                                 bump_np,
+                                 dither=True,
+                                 # Do not dither if something is perfectly flat
+                                 dither_mask=np.all(
+                                     bump_np == bump_up, axis=-1, keepdims=True
+                                 ).astype(np.float32),
+                             )
+                         ).convert("RGB")
+                         bump_tex.format = (
+                             "JPEG"  # PNG would be better but the assets are larger
+                         )
+                     else:
+                         bump_tex = None
+
+                     material = trimesh.visual.material.PBRMaterial(
+                         baseColorTexture=basecolor_tex,
+                         roughnessFactor=roughness,
+                         metallicFactor=metallic,
+                         normalTexture=bump_tex,
+                     )
+
+                     tmesh = trimesh.Trimesh(
+                         vertices=verts_np,
+                         faces=faces,
+                         visual=trimesh.visual.texture.TextureVisuals(
+                             uv=uvs, material=material
+                         ),
+                     )
+                     rot = trimesh.transformations.rotation_matrix(
+                         np.radians(-90), [1, 0, 0]
+                     )
+                     tmesh.apply_transform(rot)
+                     tmesh.apply_transform(
+                         trimesh.transformations.rotation_matrix(
+                             np.radians(90), [0, 1, 0]
+                         )
+                     )
+
+                     tmesh.invert()
+
+                     rets.append(tmesh)
+
+         return rets, global_dict
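
The baking loop above always follows the same rasterize, mask, interpolate pattern before querying the triplane. A minimal sketch of that pattern in isolation (the standalone `baker` and `mesh` below are illustrative stand-ins for `self.baker` and one element of `meshes`; they are not new API):

```python
# Minimal sketch of the bake pattern used in generate_mesh (illustrative names).
import torch.nn.functional as F

def bake_attribute(baker, mesh, attr, bake_resolution):
    # One rasterization of the UV layout can be reused for every attribute.
    rast = baker.rasterize(mesh.v_tex, mesh.t_pos_idx, bake_resolution)
    mask = baker.get_mask(rast)  # texels actually covered by a triangle
    img = baker.interpolate(attr, rast, mesh.t_pos_idx, mesh.v_tex)
    return img, mask

# e.g. world positions, as used before query_triplane:
# pos_img, mask = bake_attribute(baker, mesh, mesh.v_pos, 1024)
# gb_pos = pos_img[mask]  # flat list of covered-texel positions
```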
sf3d/texture_baker.py ADDED
@@ -0,0 +1,87 @@
+ import os
+
+ import slangtorch
+ import torch
+ import torch.nn as nn
+ from jaxtyping import Bool, Float
+ from torch import Tensor
+
+
+ class TextureBaker(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.baker = slangtorch.loadModule(
+             os.path.join(os.path.dirname(__file__), "texture_baker.slang")
+         )
+
+     def rasterize(
+         self,
+         uv: Float[Tensor, "Nv 2"],
+         face_indices: Float[Tensor, "Nf 3"],
+         bake_resolution: int,
+     ) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
+         if not face_indices.is_cuda or not uv.is_cuda:
+             raise ValueError("All input tensors must be on cuda")
+
+         face_indices = face_indices.to(torch.int32)
+         uv = uv.to(torch.float32)
+
+         rast_result = torch.empty(
+             bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
+         )
+
+         block_size = 16
+         grid_size = bake_resolution // block_size
+         self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
+             blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
+         )
+
+         return rast_result
+
+     def get_mask(
+         self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
+     ) -> Bool[Tensor, "bake_resolution bake_resolution"]:
+         return rast[..., -1] >= 0
+
+     def interpolate(
+         self,
+         attr: Float[Tensor, "Nv 3"],
+         rast: Float[Tensor, "bake_resolution bake_resolution 4"],
+         face_indices: Float[Tensor, "Nf 3"],
+         uv: Float[Tensor, "Nv 2"],
+     ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
+         # Make sure all input tensors are on CUDA
+         if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
+             raise ValueError("All input tensors must be on cuda")
+
+         attr = attr.to(torch.float32)
+         face_indices = face_indices.to(torch.int32)
+         uv = uv.to(torch.float32)
+
+         pos_bake = torch.zeros(
+             rast.shape[0],
+             rast.shape[1],
+             3,
+             device=attr.device,
+             dtype=attr.dtype,
+         )
+
+         block_size = 16
+         grid_size = rast.shape[0] // block_size
+         self.baker.interpolate(
+             attr=attr, indices=face_indices, rast=rast, output=pos_bake
+         ).launchRaw(
+             blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
+         )
+
+         return pos_bake
+
+     def forward(
+         self,
+         attr: Float[Tensor, "Nv 3"],
+         uv: Float[Tensor, "Nv 2"],
+         face_indices: Float[Tensor, "Nf 3"],
+         bake_resolution: int,
+     ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
+         rast = self.rasterize(uv, face_indices, bake_resolution)
+         return self.interpolate(attr, rast, face_indices, uv)
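
A hedged usage sketch for `TextureBaker` (the random UVs and faces below are illustrative stand-ins for a real unwrapped mesh; a CUDA device and a working slangtorch install are assumed). Note that both `launchRaw` grid computations above use integer division, so `bake_resolution` should be a multiple of the 16-wide block for full texel coverage:

```python
# Illustrative only: random data stands in for a real unwrapped mesh.
import torch
from sf3d.texture_baker import TextureBaker

baker = TextureBaker()  # compiles texture_baker.slang via slangtorch

uv = torch.rand(1000, 2, device="cuda")                  # per-vertex UVs in [0, 1]
faces = torch.randint(0, 1000, (500, 3), device="cuda")  # triangle indices
colors = torch.rand(1000, 3, device="cuda")              # any per-vertex attribute

res = 512
assert res % 16 == 0  # the launchRaw grids above assume divisibility by block_size
texture = baker(colors, uv, faces, res)  # (512, 512, 3) baked attribute image
```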
sf3d/texture_baker.slang ADDED
@@ -0,0 +1,93 @@
+ // xy: 2D test position
+ // v1: vertex position 1
+ // v2: vertex position 2
+ // v3: vertex position 3
+ //
+ bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, out float u, out float v, out float w)
+ {
+     // Return true if the point (x,y) is inside the triangle defined by the vertices v1, v2, v3.
+     // If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
+     float2 v1v2 = v2 - v1;
+     float2 v1v3 = v3 - v1;
+     float2 xyv1 = xy - v1;
+
+     float d00 = dot(v1v2, v1v2);
+     float d01 = dot(v1v2, v1v3);
+     float d11 = dot(v1v3, v1v3);
+     float d20 = dot(xyv1, v1v2);
+     float d21 = dot(xyv1, v1v3);
+
+     float denom = d00 * d11 - d01 * d01;
+     v = (d11 * d20 - d01 * d21) / denom;
+     w = (d00 * d21 - d01 * d20) / denom;
+     u = 1.0 - v - w;
+
+     return (v >= 0.0) && (w >= 0.0) && (v + w <= 1.0);
+ }
+
+ [AutoPyBindCUDA]
+ [CUDAKernel]
+ void interpolate(
+     TensorView<float3> attr,
+     TensorView<int3> indices,
+     TensorView<float4> rast,
+     TensorView<float3> output)
+ {
+     // Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
+
+     uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
+
+     // Guard against out-of-bounds texels
+     if (dispatch_id.x >= output.size(0) || dispatch_id.y >= output.size(1))
+         return;
+
+     float4 barycentric = rast[dispatch_id.x, dispatch_id.y];
+     int triangle_idx = int(barycentric.w);
+
+     if (triangle_idx < 0) {
+         output[dispatch_id.x, dispatch_id.y] = float3(0.0, 0.0, 0.0);
+         return;
+     }
+
+     float3 v1 = attr[indices[triangle_idx].x];
+     float3 v2 = attr[indices[triangle_idx].y];
+     float3 v3 = attr[indices[triangle_idx].z];
+
+     output[dispatch_id.x, dispatch_id.y] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
+ }
+
+ [AutoPyBindCUDA]
+ [CUDAKernel]
+ void bake_uv(
+     TensorView<float2> uv,
+     TensorView<int3> indices,
+     TensorView<float4> output)
+ {
+     uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
+
+     if (dispatch_id.y >= output.size(0) || dispatch_id.x >= output.size(1))
+         return;
+
+     // We index x,y but the original coords are HW. So swap them
+     float2 pixel_coord = float2(dispatch_id.y, dispatch_id.x);
+     // Normalize to [0, 1]
+     pixel_coord /= float2(output.size(1), output.size(0));
+     pixel_coord = clamp(pixel_coord, 0.0, 1.0);
+     // Flip the v-axis (the UV origin is at the bottom)
+     pixel_coord.y = 1 - pixel_coord.y;
+
+     for (int i = 0; i < indices.size(0); i++) {
+         float2 v1 = float2(uv[indices[i].x].x, uv[indices[i].x].y);
+         float2 v2 = float2(uv[indices[i].y].x, uv[indices[i].y].y);
+         float2 v3 = float2(uv[indices[i].z].x, uv[indices[i].z].y);
+
+         float u, v, w;
+         bool hit = barycentric_coordinates(pixel_coord, v1, v2, v3, u, v, w);
+
+         if (hit) {
+             output[dispatch_id.x, dispatch_id.y] = float4(u, v, w, i);
+             return;
+         }
+     }
+
+     output[dispatch_id.x, dispatch_id.y] = float4(0.0, 0.0, 0.0, -1);
+ }
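
For reference, `barycentric_coordinates` above solves the standard two-equation system for the barycentric weights; this restates the code in math form, nothing new is assumed:

$$
\mathbf{e}_1 = v_2 - v_1,\qquad \mathbf{e}_2 = v_3 - v_1,\qquad \mathbf{p} = xy - v_1,
$$

$$
\begin{pmatrix} d_{00} & d_{01} \\ d_{01} & d_{11} \end{pmatrix}
\begin{pmatrix} v \\ w \end{pmatrix}
=
\begin{pmatrix} d_{20} \\ d_{21} \end{pmatrix},
\qquad
v = \frac{d_{11} d_{20} - d_{01} d_{21}}{d_{00} d_{11} - d_{01}^2},\quad
w = \frac{d_{00} d_{21} - d_{01} d_{20}}{d_{00} d_{11} - d_{01}^2},\quad
u = 1 - v - w,
$$

where $d_{00} = \mathbf{e}_1\cdot\mathbf{e}_1$, $d_{01} = \mathbf{e}_1\cdot\mathbf{e}_2$, $d_{11} = \mathbf{e}_2\cdot\mathbf{e}_2$, $d_{20} = \mathbf{p}\cdot\mathbf{e}_1$, $d_{21} = \mathbf{p}\cdot\mathbf{e}_2$. The point lies inside the triangle iff $v \ge 0$, $w \ge 0$, and $v + w \le 1$, which is exactly the return condition in the code.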
sf3d/utils.py ADDED
@@ -0,0 +1,91 @@
+ from typing import Any
+
+ import numpy as np
+ import rembg
+ import torch
+ from PIL import Image
+
+ import sf3d.models.utils as sf3d_utils
+
+
+ def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
+     intrinsic = sf3d_utils.get_intrinsic_from_fov(
+         np.deg2rad(fov_deg),
+         H=cond_height,
+         W=cond_width,
+     )
+     intrinsic_normed_cond = intrinsic.clone()
+     intrinsic_normed_cond[..., 0, 2] /= cond_width
+     intrinsic_normed_cond[..., 1, 2] /= cond_height
+     intrinsic_normed_cond[..., 0, 0] /= cond_width
+     intrinsic_normed_cond[..., 1, 1] /= cond_height
+
+     return intrinsic, intrinsic_normed_cond
+
+
+ def default_cond_c2w(distance: float):
+     c2w_cond = torch.as_tensor(
+         [
+             [0, 0, 1, distance],
+             [1, 0, 0, 0],
+             [0, 1, 0, 0],
+             [0, 0, 0, 1],
+         ]
+     ).float()
+     return c2w_cond
+
+
+ def remove_background(
+     image: Image,
+     rembg_session: Any = None,
+     force: bool = False,
+     **rembg_kwargs,
+ ) -> Image:
+     do_remove = True
+     if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+         do_remove = False
+     do_remove = do_remove or force
+     if do_remove:
+         image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+     return image
+
+
+ def resize_foreground(
+     image: Image,
+     ratio: float,
+ ) -> Image:
+     image = np.array(image)
+     assert image.shape[-1] == 4
+     alpha = np.where(image[..., 3] > 0)
+     y1, y2, x1, x2 = (
+         alpha[0].min(),
+         alpha[0].max(),
+         alpha[1].min(),
+         alpha[1].max(),
+     )
+     # crop the foreground
+     fg = image[y1:y2, x1:x2]
+     # pad to square
+     size = max(fg.shape[0], fg.shape[1])
+     ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
+     ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
+     new_image = np.pad(
+         fg,
+         ((ph0, ph1), (pw0, pw1), (0, 0)),
+         mode="constant",
+         constant_values=((0, 0), (0, 0), (0, 0)),
+     )
+
+     # compute padding according to the ratio
+     new_size = int(new_image.shape[0] / ratio)
+     # pad to size, double side
+     ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
+     ph1, pw1 = new_size - size - ph0, new_size - size - pw0
+     new_image = np.pad(
+         new_image,
+         ((ph0, ph1), (pw0, pw1), (0, 0)),
+         mode="constant",
+         constant_values=((0, 0), (0, 0), (0, 0)),
+     )
+     new_image = Image.fromarray(new_image, mode="RGBA")
+     return new_image
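
A quick numeric illustration of what `create_intrinsic_from_fov_deg` produces for the demo's 512x512, 40 degree setup. This is a hedged sketch: it assumes `get_intrinsic_from_fov` (defined in sf3d/models/utils.py, not shown in this section) uses the standard pinhole relation `f = H / (2 * tan(fov / 2))`:

```python
# Hedged sketch: expected focal length under the standard pinhole assumption.
import numpy as np

H = W = 512
fov = np.deg2rad(40.0)
f = H / (2.0 * np.tan(fov / 2.0))  # ~703.4 px
print(f, f / H)  # normalized focal ~1.374; principal point 256 px -> 0.5 normalized
```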
stable_fast.py ADDED
@@ -0,0 +1,355 @@
+ import os
+ import tempfile
+ import time
+ from functools import lru_cache
+ from typing import Any
+
+ import gradio as gr
+ import numpy as np
+ import rembg
+ import torch
+ from gradio_litmodel3d import LitModel3D
+ from PIL import Image
+
+ import sf3d.utils as sf3d_utils
+ from sf3d.system import SF3D
+
+ rembg_session = rembg.new_session()
+
+ COND_WIDTH = 512
+ COND_HEIGHT = 512
+ COND_DISTANCE = 1.6
+ COND_FOVY_DEG = 40
+ BACKGROUND_COLOR = [0.5, 0.5, 0.5]
+
+ # Cached. Doesn't change
+ c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
+ intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
+     COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
+ )
+
+
+ model = SF3D.from_pretrained(
+     "stabilityai/stable-fast-3d",
+     config_name="config.yaml",
+     weight_name="model.safetensors",
+ )
+ model.eval().cuda()
+
+ example_files = [
+     os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
+ ]
+
+
+ def run_model(input_image):
+     start = time.time()
+     with torch.no_grad():
+         with torch.autocast(device_type="cuda", dtype=torch.float16):
+             model_batch = create_batch(input_image)
+             model_batch = {k: v.cuda() for k, v in model_batch.items()}
+             trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
+             trimesh_mesh = trimesh_mesh[0]
+
+     # Create new tmp file
+     tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
+
+     trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
+
+     print("Generation took:", time.time() - start, "s")
+
+     return tmp_file.name
+
+
+ def create_batch(input_image: Image) -> dict[str, Any]:
+     img_cond = (
+         torch.from_numpy(
+             np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
+             / 255.0
+         )
+         .float()
+         .clip(0, 1)
+     )
+     mask_cond = img_cond[:, :, -1:]
+     rgb_cond = torch.lerp(
+         torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
+     )
+
+     batch_elem = {
+         "rgb_cond": rgb_cond,
+         "mask_cond": mask_cond,
+         "c2w_cond": c2w_cond.unsqueeze(0),
+         "intrinsic_cond": intrinsic.unsqueeze(0),
+         "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
+     }
+     # Add batch dim
+     batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
+     return batched
+
+
+ @lru_cache
+ def checkerboard(squares: int, size: int, min_value: float = 0.5):
+     base = np.zeros((squares, squares)) + min_value
+     base[1::2, ::2] = 1
+     base[::2, 1::2] = 1
+
+     repeat_mult = size // squares
+     return (
+         base.repeat(repeat_mult, axis=0)
+         .repeat(repeat_mult, axis=1)[:, :, None]
+         .repeat(3, axis=-1)
+     )
+
+
+ def remove_background(input_image: Image) -> Image:
+     return rembg.remove(input_image, session=rembg_session)
+
+
+ def resize_foreground(
+     image: Image,
+     ratio: float,
+ ) -> Image:
+     image = np.array(image)
+     assert image.shape[-1] == 4
+     alpha = np.where(image[..., 3] > 0)
+     y1, y2, x1, x2 = (
+         alpha[0].min(),
+         alpha[0].max(),
+         alpha[1].min(),
+         alpha[1].max(),
+     )
+     # crop the foreground
+     fg = image[y1:y2, x1:x2]
+     # pad to square
+     size = max(fg.shape[0], fg.shape[1])
+     ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
+     ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
+     new_image = np.pad(
+         fg,
+         ((ph0, ph1), (pw0, pw1), (0, 0)),
+         mode="constant",
+         constant_values=((0, 0), (0, 0), (0, 0)),
+     )
+
+     # compute padding according to the ratio
+     new_size = int(new_image.shape[0] / ratio)
+     # pad to size, double side
+     ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
+     ph1, pw1 = new_size - size - ph0, new_size - size - pw0
+     new_image = np.pad(
+         new_image,
+         ((ph0, ph1), (pw0, pw1), (0, 0)),
+         mode="constant",
+         constant_values=((0, 0), (0, 0), (0, 0)),
+     )
+     new_image = Image.fromarray(new_image, mode="RGBA").resize(
+         (COND_WIDTH, COND_HEIGHT)
+     )
+     return new_image
+
+
+ def square_crop(input_image: Image) -> Image:
+     # Perform a center square crop
+     min_size = min(input_image.size)
+     left = (input_image.size[0] - min_size) // 2
+     top = (input_image.size[1] - min_size) // 2
+     right = (input_image.size[0] + min_size) // 2
+     bottom = (input_image.size[1] + min_size) // 2
+     return input_image.crop((left, top, right, bottom)).resize(
+         (COND_WIDTH, COND_HEIGHT)
+     )
+
+
+ def show_mask_img(input_image: Image) -> Image:
+     img_numpy = np.array(input_image)
+     alpha = img_numpy[:, :, 3] / 255.0
+     chkb = checkerboard(32, 512) * 255
+     new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
+     return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
+
+
+ def run_button(run_btn, input_image, background_state, foreground_ratio):
+     if run_btn == "Run":
+         glb_file: str = run_model(background_state)
+
+         return (
+             gr.update(),
+             gr.update(),
+             gr.update(),
+             gr.update(),
+             gr.update(value=glb_file, visible=True),
+             gr.update(visible=True),
+         )
+     elif run_btn == "Remove Background":
+         rem_removed = remove_background(input_image)
+
+         sqr_crop = square_crop(rem_removed)
+         fr_res = resize_foreground(sqr_crop, foreground_ratio)
+
+         return (
+             gr.update(value="Run", visible=True),
+             sqr_crop,
+             fr_res,
+             gr.update(value=show_mask_img(fr_res), visible=True),
+             gr.update(value=None, visible=False),
+             gr.update(visible=False),
+         )
+
+
+ def requires_bg_remove(image, fr):
+     if image is None:
+         return (
+             gr.update(visible=False, value="Run"),
+             None,
+             None,
+             gr.update(value=None, visible=False),
+             gr.update(visible=False),
+             gr.update(visible=False),
+         )
+     alpha_channel = np.array(image.getchannel("A"))
+     min_alpha = alpha_channel.min()
+
+     if min_alpha == 0:
+         print("Already has alpha")
+         sqr_crop = square_crop(image)
+         fr_res = resize_foreground(sqr_crop, fr)
+         return (
+             gr.update(value="Run", visible=True),
+             sqr_crop,
+             fr_res,
+             gr.update(value=show_mask_img(fr_res), visible=True),
+             gr.update(visible=False),
+             gr.update(visible=False),
+         )
+     return (
+         gr.update(value="Remove Background", visible=True),
+         None,
+         None,
+         gr.update(value=None, visible=False),
+         gr.update(visible=False),
+         gr.update(visible=False),
+     )
+
+
+ def update_foreground_ratio(img_proc, fr):
+     foreground_res = resize_foreground(img_proc, fr)
+     return (
+         foreground_res,
+         gr.update(value=show_mask_img(foreground_res)),
+     )
+
+
+ with gr.Blocks() as demo:
+     img_proc_state = gr.State()
+     background_remove_state = gr.State()
+     gr.Markdown("""
+     # SF3D: Stable Fast 3D Mesh Reconstruction with UV-unwrapping and Illumination Disentanglement
+
+     **SF3D** is a state-of-the-art method for 3D mesh reconstruction from a single image.
+     This demo allows you to upload an image and generate a 3D mesh model from it.
+
+     **Tips**
+     1. If the image already has an alpha channel, you can skip the background removal step.
+     2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the reconstructed shape.
+     3. You can upload your own HDR environment map to light the 3D model.
+     """)
+     with gr.Row(variant="panel"):
+         with gr.Column():
+             with gr.Row():
+                 input_img = gr.Image(
+                     type="pil", label="Input Image", sources="upload", image_mode="RGBA"
+                 )
+                 preview_removal = gr.Image(
+                     label="Preview Background Removal",
+                     type="pil",
+                     image_mode="RGB",
+                     interactive=False,
+                     visible=False,
+                 )
+
+             foreground_ratio = gr.Slider(
+                 label="Foreground Ratio",
+                 minimum=0.5,
+                 maximum=1.0,
+                 value=0.85,
+                 step=0.05,
+             )
+
+             foreground_ratio.change(
+                 update_foreground_ratio,
+                 inputs=[img_proc_state, foreground_ratio],
+                 outputs=[background_remove_state, preview_removal],
+             )
+
+             run_btn = gr.Button("Run", variant="primary", visible=False)
+
+         with gr.Column():
+             output_3d = LitModel3D(
+                 label="3D Model",
+                 visible=False,
+                 clear_color=[0.0, 0.0, 0.0, 0.0],
+                 tonemapping="aces",
+                 contrast=1.0,
+                 scale=1.0,
+             )
+             with gr.Column(visible=False, scale=1.0) as hdr_row:
+                 gr.Markdown("""## HDR Environment Map
+
+                 Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
+                 """)
+
+                 with gr.Row():
+                     hdr_illumination_file = gr.File(
+                         label="HDR Env Map", file_types=[".hdr"], file_count="single"
+                     )
+                     example_hdris = [
+                         os.path.join("demo_files/hdri", f)
+                         for f in os.listdir("demo_files/hdri")
+                     ]
+                     hdr_illumination_example = gr.Examples(
+                         examples=example_hdris,
+                         inputs=hdr_illumination_file,
+                     )
+
+                     hdr_illumination_file.change(
+                         lambda x: gr.update(env_map=x.name if x is not None else None),
+                         inputs=hdr_illumination_file,
+                         outputs=[output_3d],
+                     )
+
+     examples = gr.Examples(
+         examples=example_files,
+         inputs=input_img,
+     )
+
+     input_img.change(
+         requires_bg_remove,
+         inputs=[input_img, foreground_ratio],
+         outputs=[
+             run_btn,
+             img_proc_state,
+             background_remove_state,
+             preview_removal,
+             output_3d,
+             hdr_row,
+         ],
+     )
+
+     run_btn.click(
+         run_button,
+         inputs=[
+             run_btn,
+             input_img,
+             background_remove_state,
+             foreground_ratio,
+         ],
+         outputs=[
+             run_btn,
+             img_proc_state,
+             background_remove_state,
+             preview_removal,
+             output_3d,
+             hdr_row,
+         ],
+     )
+
+ demo.launch()
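
One pattern worth noting in the Blocks wiring above: `run_button` and `requires_bg_remove` each return a 6-tuple that must line up positionally with their `outputs=` lists. A toy sketch of that Gradio convention (the component names below are illustrative, not from this app):

```python
# Toy sketch of positional outputs in Gradio Blocks (illustrative names).
import gradio as gr

with gr.Blocks() as toy:
    btn = gr.Button("Go")
    txt = gr.Textbox()

    def on_click(label):
        # Return one value/update per entry in outputs=, in order.
        return gr.update(value="Done"), gr.update(value=f"was: {label}")

    btn.click(on_click, inputs=[btn], outputs=[btn, txt])
```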