joselobenitezg commited on
Commit
46a60b0
1 Parent(s): fa22dae

add inference script

Browse files
Files changed (6) hide show
  1. app.py +19 -71
  2. inference/depth.py +0 -0
  3. inference/normal.py +0 -0
  4. inference/pose.py +0 -0
  5. inference/seg.py +67 -0
  6. load_and_test.ipynb +1018 -18
app.py CHANGED
@@ -1,95 +1,44 @@
1
- # Part of the code is from: fashn-ai/sapiens-body-part-segmentation
2
  import os
3
-
4
  import gradio as gr
5
  import numpy as np
6
- import spaces
7
- import torch
8
- from gradio.themes.utils import sizes
9
  from PIL import Image
10
- from torchvision import transforms
11
- from utils.vis_utils import get_palette, visualize_mask_with_overlay
12
- from config import SAPIENS_LITE_MODELS_PATH
13
-
14
- if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
15
- torch.backends.cuda.matmul.allow_tf32 = True
16
- torch.backends.cudnn.allow_tf32 = True
17
-
18
- CHECKPOINTS_DIR = "checkpoints"
19
-
20
- def load_model(checkpoint_name: str):
21
- checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
22
- model = torch.jit.load(checkpoint_path)
23
- model.eval()
24
- model.to("cuda")
25
- return model
26
-
27
-
28
- #MODELS = {name: load_model(name) for name in CHECKPOINTS.keys()}
29
-
30
- @torch.inference_mode()
31
- def run_model(model, input_tensor, height, width):
32
- output = model(input_tensor)
33
- output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
34
- _, preds = torch.max(output, 1)
35
- return preds
36
-
37
-
38
- transform_fn = transforms.Compose(
39
- [
40
- transforms.Resize((1024, 768)),
41
- transforms.ToTensor(),
42
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
43
- ]
44
- )
45
-
46
- @spaces.GPU
47
- def segment(image: Image.Image, model_name: str) -> Image.Image:
48
- input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
49
- model = MODELS[model_name]
50
- preds = run_model(model, input_tensor, height=image.height, width=image.width)
51
- mask = preds.squeeze(0).cpu().numpy()
52
- mask_image = Image.fromarray(mask.astype("uint8"))
53
- blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
54
- return blended_image
55
 
 
 
56
 
57
  def update_model_choices(task):
58
  model_choices = list(SAPIENS_LITE_MODELS_PATH[task.lower()].keys())
59
  return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)
60
 
 
 
 
 
 
 
 
 
61
  with gr.Blocks() as demo:
62
  gr.Markdown("# Sapiens Arena 🤸🏽‍♂️ - WIP devmode- Not yet available")
63
  with gr.Tabs():
64
  with gr.TabItem('Image'):
65
  with gr.Row():
66
  with gr.Column():
67
- input_image = gr.Image(label="Input Image", type="pil", format="png")
68
  select_task = gr.Radio(
69
- ["Seg", "Pose", "Depth", "Normal"],
70
  label="Task",
71
- info="Choose the task to perfom",
72
- choices=list(SAPIENS_LITE_MODELS_PATH.keys())
73
  )
74
  model_name = gr.Dropdown(
75
  label="Model Version",
76
  choices=list(SAPIENS_LITE_MODELS_PATH["seg"].keys()),
77
- value="0.3B",
78
  )
79
-
80
- # example_model = gr.Examples(
81
- # inputs=input_image,
82
- # examples_per_page=10,
83
- # examples=[
84
- # os.path.join(ASSETS_DIR, "examples", img)
85
- # for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
86
- # ],
87
- # )
88
  with gr.Column():
89
- result_image = gr.Image(label="Segmentation Result", format="png")
90
  run_button = gr.Button("Run")
91
-
92
- #gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")
93
 
94
  with gr.TabItem('Video'):
95
  gr.Markdown("In construction")
@@ -97,11 +46,10 @@ with gr.Blocks() as demo:
97
  select_task.change(fn=update_model_choices, inputs=select_task, outputs=model_name)
98
 
99
  run_button.click(
100
- fn=segment,
101
- inputs=[input_image, model_name],
102
  outputs=[result_image],
103
  )
104
 
105
-
106
  if __name__ == "__main__":
107
- demo.launch(share=False)
 
 
1
  import os
 
2
  import gradio as gr
3
  import numpy as np
 
 
 
4
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ from inference.seg import process_image_or_video
7
+ from config import SAPIENS_LITE_MODELS_PATH
8
 
9
  def update_model_choices(task):
10
  model_choices = list(SAPIENS_LITE_MODELS_PATH[task.lower()].keys())
11
  return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)
12
 
13
+ def gradio_wrapper(input_image, task, version):
14
+ if isinstance(input_image, np.ndarray):
15
+ input_image = Image.fromarray(input_image)
16
+
17
+ result = process_image_or_video(input_image, task=task.lower(), version=version)
18
+
19
+ return result
20
+
21
  with gr.Blocks() as demo:
22
  gr.Markdown("# Sapiens Arena 🤸🏽‍♂️ - WIP devmode- Not yet available")
23
  with gr.Tabs():
24
  with gr.TabItem('Image'):
25
  with gr.Row():
26
  with gr.Column():
27
+ input_image = gr.Image(label="Input Image", type="pil")
28
  select_task = gr.Radio(
29
+ ["seg", "pose", "depth", "normal"],
30
  label="Task",
31
+ info="Choose the task to perform",
32
+ value="seg"
33
  )
34
  model_name = gr.Dropdown(
35
  label="Model Version",
36
  choices=list(SAPIENS_LITE_MODELS_PATH["seg"].keys()),
37
+ value="sapiens_0.3b",
38
  )
 
 
 
 
 
 
 
 
 
39
  with gr.Column():
40
+ result_image = gr.Image(label="Result")
41
  run_button = gr.Button("Run")
 
 
42
 
43
  with gr.TabItem('Video'):
44
  gr.Markdown("In construction")
 
46
  select_task.change(fn=update_model_choices, inputs=select_task, outputs=model_name)
47
 
48
  run_button.click(
49
+ fn=gradio_wrapper,
50
+ inputs=[input_image, select_task, model_name],
51
  outputs=[result_image],
52
  )
53
 
 
54
  if __name__ == "__main__":
55
+ demo.launch(share=True)
inference/depth.py ADDED
File without changes
inference/normal.py ADDED
File without changes
inference/pose.py ADDED
File without changes
inference/seg.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ from config import LABELS_TO_IDS
6
+ from utils.vis_utils import visualize_mask_with_overlay
7
+
8
+ def load_model(task, version):
9
+ from config import SAPIENS_LITE_MODELS_PATH
10
+ import os
11
+
12
+ try:
13
+ model_path = SAPIENS_LITE_MODELS_PATH[task][version]
14
+ if not os.path.exists(model_path):
15
+ print(f"Advertencia: El archivo del modelo no existe en {model_path}")
16
+ return None, None
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model = torch.jit.load(model_path)
20
+ model.eval()
21
+ model.to(device)
22
+ return model, device
23
+ except KeyError as e:
24
+ print(f"Error: Tarea o versión inválida. {e}")
25
+ return None, None
26
+
27
+ def process_image_or_video(input_data, task='seg', version='sapiens_0.3b'):
28
+ # Configurar el modelo
29
+ model, device = load_model(task, version)
30
+ if model is None or device is None:
31
+ return None
32
+
33
+ # Configurar la transformación de entrada
34
+ transform_fn = transforms.Compose([
35
+ transforms.Resize((1024, 768)),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
38
+ ])
39
+
40
+ # Función para procesar un solo frame
41
+ def process_frame(frame):
42
+ if isinstance(frame, np.ndarray):
43
+ frame = Image.fromarray(frame)
44
+
45
+ if frame.mode == 'RGBA':
46
+ frame = frame.convert('RGB')
47
+
48
+ input_tensor = transform_fn(frame).unsqueeze(0).to(device)
49
+
50
+ with torch.inference_mode():
51
+ output = model(input_tensor)
52
+ output = torch.nn.functional.interpolate(output, size=(frame.height, frame.width), mode="bilinear", align_corners=False)
53
+ _, preds = torch.max(output, 1)
54
+
55
+ mask = preds.squeeze(0).cpu().numpy()
56
+ mask_image = Image.fromarray(mask.astype("uint8"))
57
+ blended_image = visualize_mask_with_overlay(frame, mask_image, LABELS_TO_IDS, alpha=0.5)
58
+ return blended_image
59
+
60
+ # Procesar imagen o video
61
+ if isinstance(input_data, np.ndarray): # Video frame
62
+ return process_frame(input_data)
63
+ elif isinstance(input_data, Image.Image): # Imagen
64
+ return process_frame(input_data)
65
+ else:
66
+ print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
67
+ return None
load_and_test.ipynb CHANGED
@@ -3146,7 +3146,7 @@
3146
  },
3147
  {
3148
  "cell_type": "code",
3149
- "execution_count": 83,
3150
  "metadata": {},
3151
  "outputs": [],
3152
  "source": [
@@ -3155,6 +3155,104 @@
3155
  "import numpy as np\n",
3156
  "import cv2\n",
3157
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3158
  "def get_depth(image, depth_model, input_shape=(3, 1024, 768), device=\"cuda\"):\n",
3159
  " # Preprocess the image\n",
3160
  " img = preprocess_image(image, input_shape)\n",
@@ -3202,18 +3300,38 @@
3202
  "def visualize_depth(depth_map):\n",
3203
  " # Normalize the depth map\n",
3204
  " min_val, max_val = np.nanmin(depth_map), np.nanmax(depth_map)\n",
3205
- " depth_normalized = (depth_map - min_val) / (max_val - min_val)\n",
3206
- " depth_normalized = (depth_normalized * 255.0).astype(np.uint8)\n",
3207
  " \n",
3208
- " # Apply color map\n",
 
 
 
3209
  " depth_colored = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_INFERNO)\n",
3210
  " \n",
3211
- " return depth_colored"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3212
  ]
3213
  },
3214
  {
3215
  "cell_type": "code",
3216
- "execution_count": 84,
3217
  "metadata": {},
3218
  "outputs": [],
3219
  "source": [
@@ -3222,31 +3340,913 @@
3222
  "pil_image = Image.open('/home/user/app/assets/image.webp')\n",
3223
  "\n",
3224
  "# Load and process an image\n",
3225
- "image = cv2.imread('/home/user/app/assets/image.webp')\n",
3226
  "depth_image, depth_map = get_depth(image, model)\n",
3227
  "\n",
 
 
3228
  "# Save the results\n",
3229
- "output_im = cv2.imwrite(\"output_depth_image.jpg\", depth_image)"
 
 
 
 
 
 
 
3230
  ]
3231
  },
3232
  {
3233
  "cell_type": "code",
3234
- "execution_count": 85,
3235
  "metadata": {},
3236
  "outputs": [
3237
  {
3238
- "data": {
3239
- "text/plain": [
3240
- "True"
3241
- ]
3242
- },
3243
- "execution_count": 85,
3244
- "metadata": {},
3245
- "output_type": "execute_result"
3246
  }
3247
  ],
3248
  "source": [
3249
- "output_im"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3250
  ]
3251
  },
3252
  {
 
3146
  },
3147
  {
3148
  "cell_type": "code",
3149
+ "execution_count": 90,
3150
  "metadata": {},
3151
  "outputs": [],
3152
  "source": [
 
3155
  "import numpy as np\n",
3156
  "import cv2\n",
3157
  "\n",
3158
+ "def get_depth(image, depth_model, input_shape=(3, 1024, 768), device=\"cuda\"):\n",
3159
+ " # Preprocess the image\n",
3160
+ " img = preprocess_image(image, input_shape)\n",
3161
+ " \n",
3162
+ " # Run the model\n",
3163
+ " with torch.no_grad():\n",
3164
+ " result = depth_model(img.to(device))\n",
3165
+ " \n",
3166
+ " # Post-process the output\n",
3167
+ " depth_map = post_process_depth(result, (image.shape[0], image.shape[1]))\n",
3168
+ " \n",
3169
+ " # Visualize the depth map\n",
3170
+ " depth_image = visualize_depth(depth_map)\n",
3171
+ " \n",
3172
+ " return depth_image, depth_map\n",
3173
+ "\n",
3174
+ "def preprocess_image(image, input_shape):\n",
3175
+ " img = cv2.resize(image, (input_shape[2], input_shape[1]), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)\n",
3176
+ " img = torch.from_numpy(img)\n",
3177
+ " img = img[[2, 1, 0], ...].float()\n",
3178
+ " mean = torch.tensor([123.5, 116.5, 103.5]).view(-1, 1, 1)\n",
3179
+ " std = torch.tensor([58.5, 57.0, 57.5]).view(-1, 1, 1)\n",
3180
+ " img = (img - mean) / std\n",
3181
+ " return img.unsqueeze(0)\n",
3182
+ "\n",
3183
+ "def post_process_depth(result, original_shape):\n",
3184
+ " # Check the dimensionality of the result\n",
3185
+ " if result.dim() == 3:\n",
3186
+ " result = result.unsqueeze(0)\n",
3187
+ " elif result.dim() == 4:\n",
3188
+ " pass\n",
3189
+ " else:\n",
3190
+ " raise ValueError(f\"Unexpected result dimension: {result.dim()}\")\n",
3191
+ " \n",
3192
+ " # Ensure we're interpolating to the correct dimensions\n",
3193
+ " seg_logits = F.interpolate(result, size=original_shape, mode=\"bilinear\", align_corners=False).squeeze(0)\n",
3194
+ " depth_map = seg_logits.data.float().cpu().numpy()\n",
3195
+ " \n",
3196
+ " # If depth_map has an extra dimension, squeeze it\n",
3197
+ " if depth_map.ndim == 3 and depth_map.shape[0] == 1:\n",
3198
+ " depth_map = depth_map.squeeze(0)\n",
3199
+ " \n",
3200
+ " return depth_map\n",
3201
+ "\n",
3202
+ "# def visualize_depth(depth_map):\n",
3203
+ "# # Normalize the depth map\n",
3204
+ "# min_val, max_val = np.nanmin(depth_map), np.nanmax(depth_map)\n",
3205
+ "# depth_normalized = (depth_map - min_val) / (max_val - min_val)\n",
3206
+ "# depth_normalized = (depth_normalized * 255.0).astype(np.uint8)\n",
3207
+ " \n",
3208
+ "# # Apply color map\n",
3209
+ "# depth_colored = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_INFERNO)\n",
3210
+ " \n",
3211
+ "# return depth_colored\n",
3212
+ "\n",
3213
+ "# def post_process_depth(result, original_shape):\n",
3214
+ "# seg_logits = F.interpolate(result.unsqueeze(0), size=original_shape, mode=\"bilinear\").squeeze(0)\n",
3215
+ "# depth_map = seg_logits.data.float().cpu().numpy()[0] # H x W\n",
3216
+ "# return depth_map\n",
3217
+ "\n",
3218
+ "def visualize_depth(depth_map):\n",
3219
+ " # Normalize the depth map\n",
3220
+ " min_val, max_val = np.nanmin(depth_map), np.nanmax(depth_map)\n",
3221
+ " depth_normalized = 1 - ((depth_map - min_val) / (max_val - min_val))\n",
3222
+ " \n",
3223
+ " # Convert to uint8\n",
3224
+ " depth_normalized = (depth_normalized * 255).astype(np.uint8)\n",
3225
+ " \n",
3226
+ " # Apply colormap\n",
3227
+ " depth_colored = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_INFERNO)\n",
3228
+ " \n",
3229
+ " return depth_colored\n",
3230
+ "\n",
3231
+ "# You can add the surface normal calculation if needed\n",
3232
+ "def calculate_surface_normal(depth_map):\n",
3233
+ " kernel_size = 7\n",
3234
+ " grad_x = cv2.Sobel(depth_map.astype(np.float32), cv2.CV_32F, 1, 0, ksize=kernel_size)\n",
3235
+ " grad_y = cv2.Sobel(depth_map.astype(np.float32), cv2.CV_32F, 0, 1, ksize=kernel_size)\n",
3236
+ " z = np.full(grad_x.shape, -1)\n",
3237
+ " normals = np.dstack((-grad_x, -grad_y, z))\n",
3238
+ "\n",
3239
+ " normals_mag = np.linalg.norm(normals, axis=2, keepdims=True)\n",
3240
+ " with np.errstate(divide=\"ignore\", invalid=\"ignore\"):\n",
3241
+ " normals_normalized = normals / (normals_mag + 1e-5)\n",
3242
+ "\n",
3243
+ " normals_normalized = np.nan_to_num(normals_normalized, nan=-1, posinf=-1, neginf=-1)\n",
3244
+ " normal_from_depth = ((normals_normalized + 1) / 2 * 255).astype(np.uint8)\n",
3245
+ " normal_from_depth = normal_from_depth[:, :, ::-1] # RGB to BGR for cv2\n",
3246
+ "\n",
3247
+ " return normal_from_depth"
3248
+ ]
3249
+ },
3250
+ {
3251
+ "cell_type": "code",
3252
+ "execution_count": 94,
3253
+ "metadata": {},
3254
+ "outputs": [],
3255
+ "source": [
3256
  "def get_depth(image, depth_model, input_shape=(3, 1024, 768), device=\"cuda\"):\n",
3257
  " # Preprocess the image\n",
3258
  " img = preprocess_image(image, input_shape)\n",
 
3300
  "def visualize_depth(depth_map):\n",
3301
  " # Normalize the depth map\n",
3302
  " min_val, max_val = np.nanmin(depth_map), np.nanmax(depth_map)\n",
3303
+ " depth_normalized = 1 - ((depth_map - min_val) / (max_val - min_val))\n",
 
3304
  " \n",
3305
+ " # Convert to uint8\n",
3306
+ " depth_normalized = (depth_normalized * 255).astype(np.uint8)\n",
3307
+ " \n",
3308
+ " # Apply colormap\n",
3309
  " depth_colored = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_INFERNO)\n",
3310
  " \n",
3311
+ " return depth_colored\n",
3312
+ "\n",
3313
+ "# You can add the surface normal calculation if needed\n",
3314
+ "def calculate_surface_normal(depth_map):\n",
3315
+ " kernel_size = 7\n",
3316
+ " grad_x = cv2.Sobel(depth_map.astype(np.float32), cv2.CV_32F, 1, 0, ksize=kernel_size)\n",
3317
+ " grad_y = cv2.Sobel(depth_map.astype(np.float32), cv2.CV_32F, 0, 1, ksize=kernel_size)\n",
3318
+ " z = np.full(grad_x.shape, -1)\n",
3319
+ " normals = np.dstack((-grad_x, -grad_y, z))\n",
3320
+ "\n",
3321
+ " normals_mag = np.linalg.norm(normals, axis=2, keepdims=True)\n",
3322
+ " with np.errstate(divide=\"ignore\", invalid=\"ignore\"):\n",
3323
+ " normals_normalized = normals / (normals_mag + 1e-5)\n",
3324
+ "\n",
3325
+ " normals_normalized = np.nan_to_num(normals_normalized, nan=-1, posinf=-1, neginf=-1)\n",
3326
+ " normal_from_depth = ((normals_normalized + 1) / 2 * 255).astype(np.uint8)\n",
3327
+ " normal_from_depth = normal_from_depth[:, :, ::-1] # RGB to BGR for cv2\n",
3328
+ "\n",
3329
+ " return normal_from_depth"
3330
  ]
3331
  },
3332
  {
3333
  "cell_type": "code",
3334
+ "execution_count": 99,
3335
  "metadata": {},
3336
  "outputs": [],
3337
  "source": [
 
3340
  "pil_image = Image.open('/home/user/app/assets/image.webp')\n",
3341
  "\n",
3342
  "# Load and process an image\n",
3343
+ "image = cv2.imread('/home/user/app/assets/frame.png')\n",
3344
  "depth_image, depth_map = get_depth(image, model)\n",
3345
  "\n",
3346
+ "surface_normal = calculate_surface_normal(depth_map)\n",
3347
+ "cv2.imwrite(\"output_surface_normal.jpg\", surface_normal)\n",
3348
  "# Save the results\n",
3349
+ "output_im = cv2.imwrite(\"output_depth_image2.jpg\", depth_image)"
3350
+ ]
3351
+ },
3352
+ {
3353
+ "cell_type": "markdown",
3354
+ "metadata": {},
3355
+ "source": [
3356
+ "# Normal"
3357
  ]
3358
  },
3359
  {
3360
  "cell_type": "code",
3361
+ "execution_count": 100,
3362
  "metadata": {},
3363
  "outputs": [
3364
  {
3365
+ "name": "stdout",
3366
+ "output_type": "stream",
3367
+ "text": [
3368
+ "checkpoints/normal/sapiens_0.3b_torchscript.pt2\n"
3369
+ ]
 
 
 
3370
  }
3371
  ],
3372
  "source": [
3373
+ "# Example usage\n",
3374
+ "TASK = 'normal'\n",
3375
+ "VERSION = 'sapiens_0.3b'\n",
3376
+ "\n",
3377
+ "model_path = get_model_path(TASK, VERSION)\n",
3378
+ "print(model_path)"
3379
+ ]
3380
+ },
3381
+ {
3382
+ "cell_type": "code",
3383
+ "execution_count": 101,
3384
+ "metadata": {},
3385
+ "outputs": [
3386
+ {
3387
+ "data": {
3388
+ "text/plain": [
3389
+ "RecursiveScriptModule(\n",
3390
+ " original_name=DepthEstimator\n",
3391
+ " (data_preprocessor): RecursiveScriptModule(original_name=SegDataPreProcessor)\n",
3392
+ " (backbone): RecursiveScriptModule(\n",
3393
+ " original_name=VisionTransformer\n",
3394
+ " (patch_embed): RecursiveScriptModule(\n",
3395
+ " original_name=PatchEmbed\n",
3396
+ " (projection): RecursiveScriptModule(original_name=Conv2d)\n",
3397
+ " )\n",
3398
+ " (drop_after_pos): RecursiveScriptModule(original_name=Dropout)\n",
3399
+ " (layers): RecursiveScriptModule(\n",
3400
+ " original_name=ModuleList\n",
3401
+ " (0): RecursiveScriptModule(\n",
3402
+ " original_name=TransformerEncoderLayer\n",
3403
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3404
+ " (attn): RecursiveScriptModule(\n",
3405
+ " original_name=MultiheadAttention\n",
3406
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3407
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3408
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3409
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3410
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3411
+ " )\n",
3412
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3413
+ " (ffn): RecursiveScriptModule(\n",
3414
+ " original_name=FFN\n",
3415
+ " (layers): RecursiveScriptModule(\n",
3416
+ " original_name=Sequential\n",
3417
+ " (0): RecursiveScriptModule(\n",
3418
+ " original_name=Sequential\n",
3419
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3420
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3421
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3422
+ " )\n",
3423
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3424
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3425
+ " )\n",
3426
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3427
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3428
+ " )\n",
3429
+ " )\n",
3430
+ " (1): RecursiveScriptModule(\n",
3431
+ " original_name=TransformerEncoderLayer\n",
3432
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3433
+ " (attn): RecursiveScriptModule(\n",
3434
+ " original_name=MultiheadAttention\n",
3435
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3436
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3437
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3438
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3439
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3440
+ " )\n",
3441
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3442
+ " (ffn): RecursiveScriptModule(\n",
3443
+ " original_name=FFN\n",
3444
+ " (layers): RecursiveScriptModule(\n",
3445
+ " original_name=Sequential\n",
3446
+ " (0): RecursiveScriptModule(\n",
3447
+ " original_name=Sequential\n",
3448
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3449
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3450
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3451
+ " )\n",
3452
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3453
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3454
+ " )\n",
3455
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3456
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3457
+ " )\n",
3458
+ " )\n",
3459
+ " (2): RecursiveScriptModule(\n",
3460
+ " original_name=TransformerEncoderLayer\n",
3461
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3462
+ " (attn): RecursiveScriptModule(\n",
3463
+ " original_name=MultiheadAttention\n",
3464
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3465
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3466
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3467
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3468
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3469
+ " )\n",
3470
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3471
+ " (ffn): RecursiveScriptModule(\n",
3472
+ " original_name=FFN\n",
3473
+ " (layers): RecursiveScriptModule(\n",
3474
+ " original_name=Sequential\n",
3475
+ " (0): RecursiveScriptModule(\n",
3476
+ " original_name=Sequential\n",
3477
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3478
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3479
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3480
+ " )\n",
3481
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3482
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3483
+ " )\n",
3484
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3485
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3486
+ " )\n",
3487
+ " )\n",
3488
+ " (3): RecursiveScriptModule(\n",
3489
+ " original_name=TransformerEncoderLayer\n",
3490
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3491
+ " (attn): RecursiveScriptModule(\n",
3492
+ " original_name=MultiheadAttention\n",
3493
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3494
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3495
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3496
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3497
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3498
+ " )\n",
3499
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3500
+ " (ffn): RecursiveScriptModule(\n",
3501
+ " original_name=FFN\n",
3502
+ " (layers): RecursiveScriptModule(\n",
3503
+ " original_name=Sequential\n",
3504
+ " (0): RecursiveScriptModule(\n",
3505
+ " original_name=Sequential\n",
3506
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3507
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3508
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3509
+ " )\n",
3510
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3511
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3512
+ " )\n",
3513
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3514
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3515
+ " )\n",
3516
+ " )\n",
3517
+ " (4): RecursiveScriptModule(\n",
3518
+ " original_name=TransformerEncoderLayer\n",
3519
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3520
+ " (attn): RecursiveScriptModule(\n",
3521
+ " original_name=MultiheadAttention\n",
3522
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3523
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3524
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3525
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3526
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3527
+ " )\n",
3528
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3529
+ " (ffn): RecursiveScriptModule(\n",
3530
+ " original_name=FFN\n",
3531
+ " (layers): RecursiveScriptModule(\n",
3532
+ " original_name=Sequential\n",
3533
+ " (0): RecursiveScriptModule(\n",
3534
+ " original_name=Sequential\n",
3535
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3536
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3537
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3538
+ " )\n",
3539
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3540
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3541
+ " )\n",
3542
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3543
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3544
+ " )\n",
3545
+ " )\n",
3546
+ " (5): RecursiveScriptModule(\n",
3547
+ " original_name=TransformerEncoderLayer\n",
3548
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3549
+ " (attn): RecursiveScriptModule(\n",
3550
+ " original_name=MultiheadAttention\n",
3551
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3552
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3553
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3554
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3555
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3556
+ " )\n",
3557
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3558
+ " (ffn): RecursiveScriptModule(\n",
3559
+ " original_name=FFN\n",
3560
+ " (layers): RecursiveScriptModule(\n",
3561
+ " original_name=Sequential\n",
3562
+ " (0): RecursiveScriptModule(\n",
3563
+ " original_name=Sequential\n",
3564
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3565
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3566
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3567
+ " )\n",
3568
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3569
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3570
+ " )\n",
3571
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3572
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3573
+ " )\n",
3574
+ " )\n",
3575
+ " (6): RecursiveScriptModule(\n",
3576
+ " original_name=TransformerEncoderLayer\n",
3577
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3578
+ " (attn): RecursiveScriptModule(\n",
3579
+ " original_name=MultiheadAttention\n",
3580
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3581
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3582
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3583
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3584
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3585
+ " )\n",
3586
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3587
+ " (ffn): RecursiveScriptModule(\n",
3588
+ " original_name=FFN\n",
3589
+ " (layers): RecursiveScriptModule(\n",
3590
+ " original_name=Sequential\n",
3591
+ " (0): RecursiveScriptModule(\n",
3592
+ " original_name=Sequential\n",
3593
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3594
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3595
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3596
+ " )\n",
3597
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3598
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3599
+ " )\n",
3600
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3601
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3602
+ " )\n",
3603
+ " )\n",
3604
+ " (7): RecursiveScriptModule(\n",
3605
+ " original_name=TransformerEncoderLayer\n",
3606
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3607
+ " (attn): RecursiveScriptModule(\n",
3608
+ " original_name=MultiheadAttention\n",
3609
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3610
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3611
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3612
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3613
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3614
+ " )\n",
3615
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3616
+ " (ffn): RecursiveScriptModule(\n",
3617
+ " original_name=FFN\n",
3618
+ " (layers): RecursiveScriptModule(\n",
3619
+ " original_name=Sequential\n",
3620
+ " (0): RecursiveScriptModule(\n",
3621
+ " original_name=Sequential\n",
3622
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3623
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3624
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3625
+ " )\n",
3626
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3627
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3628
+ " )\n",
3629
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3630
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3631
+ " )\n",
3632
+ " )\n",
3633
+ " (8): RecursiveScriptModule(\n",
3634
+ " original_name=TransformerEncoderLayer\n",
3635
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3636
+ " (attn): RecursiveScriptModule(\n",
3637
+ " original_name=MultiheadAttention\n",
3638
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3639
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3640
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3641
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3642
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3643
+ " )\n",
3644
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3645
+ " (ffn): RecursiveScriptModule(\n",
3646
+ " original_name=FFN\n",
3647
+ " (layers): RecursiveScriptModule(\n",
3648
+ " original_name=Sequential\n",
3649
+ " (0): RecursiveScriptModule(\n",
3650
+ " original_name=Sequential\n",
3651
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3652
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3653
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3654
+ " )\n",
3655
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3656
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3657
+ " )\n",
3658
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3659
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3660
+ " )\n",
3661
+ " )\n",
3662
+ " (9): RecursiveScriptModule(\n",
3663
+ " original_name=TransformerEncoderLayer\n",
3664
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3665
+ " (attn): RecursiveScriptModule(\n",
3666
+ " original_name=MultiheadAttention\n",
3667
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3668
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3669
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3670
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3671
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3672
+ " )\n",
3673
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3674
+ " (ffn): RecursiveScriptModule(\n",
3675
+ " original_name=FFN\n",
3676
+ " (layers): RecursiveScriptModule(\n",
3677
+ " original_name=Sequential\n",
3678
+ " (0): RecursiveScriptModule(\n",
3679
+ " original_name=Sequential\n",
3680
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3681
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3682
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3683
+ " )\n",
3684
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3685
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3686
+ " )\n",
3687
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3688
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3689
+ " )\n",
3690
+ " )\n",
3691
+ " (10): RecursiveScriptModule(\n",
3692
+ " original_name=TransformerEncoderLayer\n",
3693
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3694
+ " (attn): RecursiveScriptModule(\n",
3695
+ " original_name=MultiheadAttention\n",
3696
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3697
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3698
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3699
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3700
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3701
+ " )\n",
3702
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3703
+ " (ffn): RecursiveScriptModule(\n",
3704
+ " original_name=FFN\n",
3705
+ " (layers): RecursiveScriptModule(\n",
3706
+ " original_name=Sequential\n",
3707
+ " (0): RecursiveScriptModule(\n",
3708
+ " original_name=Sequential\n",
3709
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3710
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3711
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3712
+ " )\n",
3713
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3714
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3715
+ " )\n",
3716
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3717
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3718
+ " )\n",
3719
+ " )\n",
3720
+ " (11): RecursiveScriptModule(\n",
3721
+ " original_name=TransformerEncoderLayer\n",
3722
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3723
+ " (attn): RecursiveScriptModule(\n",
3724
+ " original_name=MultiheadAttention\n",
3725
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3726
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3727
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3728
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3729
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3730
+ " )\n",
3731
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3732
+ " (ffn): RecursiveScriptModule(\n",
3733
+ " original_name=FFN\n",
3734
+ " (layers): RecursiveScriptModule(\n",
3735
+ " original_name=Sequential\n",
3736
+ " (0): RecursiveScriptModule(\n",
3737
+ " original_name=Sequential\n",
3738
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3739
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3740
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3741
+ " )\n",
3742
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3743
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3744
+ " )\n",
3745
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3746
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3747
+ " )\n",
3748
+ " )\n",
3749
+ " (12): RecursiveScriptModule(\n",
3750
+ " original_name=TransformerEncoderLayer\n",
3751
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3752
+ " (attn): RecursiveScriptModule(\n",
3753
+ " original_name=MultiheadAttention\n",
3754
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3755
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3756
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3757
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3758
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3759
+ " )\n",
3760
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3761
+ " (ffn): RecursiveScriptModule(\n",
3762
+ " original_name=FFN\n",
3763
+ " (layers): RecursiveScriptModule(\n",
3764
+ " original_name=Sequential\n",
3765
+ " (0): RecursiveScriptModule(\n",
3766
+ " original_name=Sequential\n",
3767
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3768
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3769
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3770
+ " )\n",
3771
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3772
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3773
+ " )\n",
3774
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3775
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3776
+ " )\n",
3777
+ " )\n",
3778
+ " (13): RecursiveScriptModule(\n",
3779
+ " original_name=TransformerEncoderLayer\n",
3780
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3781
+ " (attn): RecursiveScriptModule(\n",
3782
+ " original_name=MultiheadAttention\n",
3783
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3784
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3785
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3786
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3787
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3788
+ " )\n",
3789
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3790
+ " (ffn): RecursiveScriptModule(\n",
3791
+ " original_name=FFN\n",
3792
+ " (layers): RecursiveScriptModule(\n",
3793
+ " original_name=Sequential\n",
3794
+ " (0): RecursiveScriptModule(\n",
3795
+ " original_name=Sequential\n",
3796
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3797
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3798
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3799
+ " )\n",
3800
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3801
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3802
+ " )\n",
3803
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3804
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3805
+ " )\n",
3806
+ " )\n",
3807
+ " (14): RecursiveScriptModule(\n",
3808
+ " original_name=TransformerEncoderLayer\n",
3809
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3810
+ " (attn): RecursiveScriptModule(\n",
3811
+ " original_name=MultiheadAttention\n",
3812
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3813
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3814
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3815
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3816
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3817
+ " )\n",
3818
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3819
+ " (ffn): RecursiveScriptModule(\n",
3820
+ " original_name=FFN\n",
3821
+ " (layers): RecursiveScriptModule(\n",
3822
+ " original_name=Sequential\n",
3823
+ " (0): RecursiveScriptModule(\n",
3824
+ " original_name=Sequential\n",
3825
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3826
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3827
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3828
+ " )\n",
3829
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3830
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3831
+ " )\n",
3832
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3833
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3834
+ " )\n",
3835
+ " )\n",
3836
+ " (15): RecursiveScriptModule(\n",
3837
+ " original_name=TransformerEncoderLayer\n",
3838
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3839
+ " (attn): RecursiveScriptModule(\n",
3840
+ " original_name=MultiheadAttention\n",
3841
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3842
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3843
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3844
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3845
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3846
+ " )\n",
3847
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3848
+ " (ffn): RecursiveScriptModule(\n",
3849
+ " original_name=FFN\n",
3850
+ " (layers): RecursiveScriptModule(\n",
3851
+ " original_name=Sequential\n",
3852
+ " (0): RecursiveScriptModule(\n",
3853
+ " original_name=Sequential\n",
3854
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3855
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3856
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3857
+ " )\n",
3858
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3859
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3860
+ " )\n",
3861
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3862
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3863
+ " )\n",
3864
+ " )\n",
3865
+ " (16): RecursiveScriptModule(\n",
3866
+ " original_name=TransformerEncoderLayer\n",
3867
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3868
+ " (attn): RecursiveScriptModule(\n",
3869
+ " original_name=MultiheadAttention\n",
3870
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3871
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3872
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3873
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3874
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3875
+ " )\n",
3876
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3877
+ " (ffn): RecursiveScriptModule(\n",
3878
+ " original_name=FFN\n",
3879
+ " (layers): RecursiveScriptModule(\n",
3880
+ " original_name=Sequential\n",
3881
+ " (0): RecursiveScriptModule(\n",
3882
+ " original_name=Sequential\n",
3883
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3884
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3885
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3886
+ " )\n",
3887
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3888
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3889
+ " )\n",
3890
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3891
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3892
+ " )\n",
3893
+ " )\n",
3894
+ " (17): RecursiveScriptModule(\n",
3895
+ " original_name=TransformerEncoderLayer\n",
3896
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3897
+ " (attn): RecursiveScriptModule(\n",
3898
+ " original_name=MultiheadAttention\n",
3899
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3900
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3901
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3902
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3903
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3904
+ " )\n",
3905
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3906
+ " (ffn): RecursiveScriptModule(\n",
3907
+ " original_name=FFN\n",
3908
+ " (layers): RecursiveScriptModule(\n",
3909
+ " original_name=Sequential\n",
3910
+ " (0): RecursiveScriptModule(\n",
3911
+ " original_name=Sequential\n",
3912
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3913
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3914
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3915
+ " )\n",
3916
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3917
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3918
+ " )\n",
3919
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3920
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3921
+ " )\n",
3922
+ " )\n",
3923
+ " (18): RecursiveScriptModule(\n",
3924
+ " original_name=TransformerEncoderLayer\n",
3925
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3926
+ " (attn): RecursiveScriptModule(\n",
3927
+ " original_name=MultiheadAttention\n",
3928
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3929
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3930
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3931
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3932
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3933
+ " )\n",
3934
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3935
+ " (ffn): RecursiveScriptModule(\n",
3936
+ " original_name=FFN\n",
3937
+ " (layers): RecursiveScriptModule(\n",
3938
+ " original_name=Sequential\n",
3939
+ " (0): RecursiveScriptModule(\n",
3940
+ " original_name=Sequential\n",
3941
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3942
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3943
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3944
+ " )\n",
3945
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3946
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3947
+ " )\n",
3948
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3949
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3950
+ " )\n",
3951
+ " )\n",
3952
+ " (19): RecursiveScriptModule(\n",
3953
+ " original_name=TransformerEncoderLayer\n",
3954
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3955
+ " (attn): RecursiveScriptModule(\n",
3956
+ " original_name=MultiheadAttention\n",
3957
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3958
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3959
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3960
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3961
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3962
+ " )\n",
3963
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3964
+ " (ffn): RecursiveScriptModule(\n",
3965
+ " original_name=FFN\n",
3966
+ " (layers): RecursiveScriptModule(\n",
3967
+ " original_name=Sequential\n",
3968
+ " (0): RecursiveScriptModule(\n",
3969
+ " original_name=Sequential\n",
3970
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
3971
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
3972
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3973
+ " )\n",
3974
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
3975
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
3976
+ " )\n",
3977
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
3978
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
3979
+ " )\n",
3980
+ " )\n",
3981
+ " (20): RecursiveScriptModule(\n",
3982
+ " original_name=TransformerEncoderLayer\n",
3983
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
3984
+ " (attn): RecursiveScriptModule(\n",
3985
+ " original_name=MultiheadAttention\n",
3986
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
3987
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
3988
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
3989
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
3990
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
3991
+ " )\n",
3992
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
3993
+ " (ffn): RecursiveScriptModule(\n",
3994
+ " original_name=FFN\n",
3995
+ " (layers): RecursiveScriptModule(\n",
3996
+ " original_name=Sequential\n",
3997
+ " (0): RecursiveScriptModule(\n",
3998
+ " original_name=Sequential\n",
3999
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
4000
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
4001
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
4002
+ " )\n",
4003
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
4004
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
4005
+ " )\n",
4006
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
4007
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
4008
+ " )\n",
4009
+ " )\n",
4010
+ " (21): RecursiveScriptModule(\n",
4011
+ " original_name=TransformerEncoderLayer\n",
4012
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
4013
+ " (attn): RecursiveScriptModule(\n",
4014
+ " original_name=MultiheadAttention\n",
4015
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
4016
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
4017
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
4018
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
4019
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
4020
+ " )\n",
4021
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
4022
+ " (ffn): RecursiveScriptModule(\n",
4023
+ " original_name=FFN\n",
4024
+ " (layers): RecursiveScriptModule(\n",
4025
+ " original_name=Sequential\n",
4026
+ " (0): RecursiveScriptModule(\n",
4027
+ " original_name=Sequential\n",
4028
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
4029
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
4030
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
4031
+ " )\n",
4032
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
4033
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
4034
+ " )\n",
4035
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
4036
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
4037
+ " )\n",
4038
+ " )\n",
4039
+ " (22): RecursiveScriptModule(\n",
4040
+ " original_name=TransformerEncoderLayer\n",
4041
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
4042
+ " (attn): RecursiveScriptModule(\n",
4043
+ " original_name=MultiheadAttention\n",
4044
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
4045
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
4046
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
4047
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
4048
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
4049
+ " )\n",
4050
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
4051
+ " (ffn): RecursiveScriptModule(\n",
4052
+ " original_name=FFN\n",
4053
+ " (layers): RecursiveScriptModule(\n",
4054
+ " original_name=Sequential\n",
4055
+ " (0): RecursiveScriptModule(\n",
4056
+ " original_name=Sequential\n",
4057
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
4058
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
4059
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
4060
+ " )\n",
4061
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
4062
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
4063
+ " )\n",
4064
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
4065
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
4066
+ " )\n",
4067
+ " )\n",
4068
+ " (23): RecursiveScriptModule(\n",
4069
+ " original_name=TransformerEncoderLayer\n",
4070
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
4071
+ " (attn): RecursiveScriptModule(\n",
4072
+ " original_name=MultiheadAttention\n",
4073
+ " (qkv): RecursiveScriptModule(original_name=Linear)\n",
4074
+ " (proj): RecursiveScriptModule(original_name=Linear)\n",
4075
+ " (proj_drop): RecursiveScriptModule(original_name=Dropout)\n",
4076
+ " (out_drop): RecursiveScriptModule(original_name=DropPath)\n",
4077
+ " (gamma1): RecursiveScriptModule(original_name=Identity)\n",
4078
+ " )\n",
4079
+ " (ln2): RecursiveScriptModule(original_name=LayerNorm)\n",
4080
+ " (ffn): RecursiveScriptModule(\n",
4081
+ " original_name=FFN\n",
4082
+ " (layers): RecursiveScriptModule(\n",
4083
+ " original_name=Sequential\n",
4084
+ " (0): RecursiveScriptModule(\n",
4085
+ " original_name=Sequential\n",
4086
+ " (0): RecursiveScriptModule(original_name=Linear)\n",
4087
+ " (1): RecursiveScriptModule(original_name=GELU)\n",
4088
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
4089
+ " )\n",
4090
+ " (1): RecursiveScriptModule(original_name=Linear)\n",
4091
+ " (2): RecursiveScriptModule(original_name=Dropout)\n",
4092
+ " )\n",
4093
+ " (dropout_layer): RecursiveScriptModule(original_name=DropPath)\n",
4094
+ " (gamma2): RecursiveScriptModule(original_name=Identity)\n",
4095
+ " )\n",
4096
+ " )\n",
4097
+ " )\n",
4098
+ " (pre_norm): RecursiveScriptModule(original_name=Identity)\n",
4099
+ " (ln1): RecursiveScriptModule(original_name=LayerNorm)\n",
4100
+ " )\n",
4101
+ " (decode_head): RecursiveScriptModule(\n",
4102
+ " original_name=VitNormalHead\n",
4103
+ " (loss_decode): RecursiveScriptModule(\n",
4104
+ " original_name=ModuleList\n",
4105
+ " (0): RecursiveScriptModule(original_name=CosineSimilarityLoss)\n",
4106
+ " (1): RecursiveScriptModule(original_name=L1Loss)\n",
4107
+ " )\n",
4108
+ " (conv_seg): RecursiveScriptModule(original_name=Conv2d)\n",
4109
+ " (dropout): RecursiveScriptModule(original_name=Dropout2d)\n",
4110
+ " (deconv_layers): RecursiveScriptModule(\n",
4111
+ " original_name=Sequential\n",
4112
+ " (0): RecursiveScriptModule(original_name=ConvTranspose2d)\n",
4113
+ " (1): RecursiveScriptModule(original_name=InstanceNorm2d)\n",
4114
+ " (2): RecursiveScriptModule(original_name=SiLU)\n",
4115
+ " (3): RecursiveScriptModule(original_name=ConvTranspose2d)\n",
4116
+ " (4): RecursiveScriptModule(original_name=InstanceNorm2d)\n",
4117
+ " (5): RecursiveScriptModule(original_name=SiLU)\n",
4118
+ " (6): RecursiveScriptModule(original_name=ConvTranspose2d)\n",
4119
+ " (7): RecursiveScriptModule(original_name=InstanceNorm2d)\n",
4120
+ " (8): RecursiveScriptModule(original_name=SiLU)\n",
4121
+ " )\n",
4122
+ " (conv_layers): RecursiveScriptModule(\n",
4123
+ " original_name=Sequential\n",
4124
+ " (0): RecursiveScriptModule(original_name=Conv2d)\n",
4125
+ " (1): RecursiveScriptModule(original_name=InstanceNorm2d)\n",
4126
+ " (2): RecursiveScriptModule(original_name=SiLU)\n",
4127
+ " (3): RecursiveScriptModule(original_name=Conv2d)\n",
4128
+ " (4): RecursiveScriptModule(original_name=InstanceNorm2d)\n",
4129
+ " (5): RecursiveScriptModule(original_name=SiLU)\n",
4130
+ " (6): RecursiveScriptModule(original_name=Conv2d)\n",
4131
+ " (7): RecursiveScriptModule(original_name=InstanceNorm2d)\n",
4132
+ " (8): RecursiveScriptModule(original_name=SiLU)\n",
4133
+ " )\n",
4134
+ " )\n",
4135
+ ")"
4136
+ ]
4137
+ },
4138
+ "execution_count": 101,
4139
+ "metadata": {},
4140
+ "output_type": "execute_result"
4141
+ }
4142
+ ],
4143
+ "source": [
4144
+ "model = torch.jit.load(model_path)\n",
4145
+ "model.eval()\n",
4146
+ "model.to(\"cuda\")"
4147
+ ]
4148
+ },
4149
+ {
4150
+ "cell_type": "code",
4151
+ "execution_count": 105,
4152
+ "metadata": {},
4153
+ "outputs": [],
4154
+ "source": [
4155
+ "import torch\n",
4156
+ "import torch.nn.functional as F\n",
4157
+ "import numpy as np\n",
4158
+ "import cv2\n",
4159
+ "\n",
4160
+ "def get_normal(image, normal_model, input_shape=(3, 1024, 768), device=\"cuda\"):\n",
4161
+ " # Preprocess the image\n",
4162
+ " img = preprocess_image(image, input_shape)\n",
4163
+ " \n",
4164
+ " # Run the model\n",
4165
+ " with torch.no_grad():\n",
4166
+ " result = normal_model(img.to(device))\n",
4167
+ " \n",
4168
+ " # Post-process the output\n",
4169
+ " normal_map = post_process_normal(result, (image.shape[0], image.shape[1]))\n",
4170
+ " \n",
4171
+ " # Visualize the normal map\n",
4172
+ " normal_image = visualize_normal(normal_map)\n",
4173
+ " \n",
4174
+ " return normal_image, normal_map\n",
4175
+ "\n",
4176
+ "def preprocess_image(image, input_shape):\n",
4177
+ " img = cv2.resize(image, (input_shape[2], input_shape[1]), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)\n",
4178
+ " img = torch.from_numpy(img)\n",
4179
+ " img = img[[2, 1, 0], ...].float()\n",
4180
+ " mean = torch.tensor([123.5, 116.5, 103.5]).view(-1, 1, 1)\n",
4181
+ " std = torch.tensor([58.5, 57.0, 57.5]).view(-1, 1, 1)\n",
4182
+ " img = (img - mean) / std\n",
4183
+ " return img.unsqueeze(0)\n",
4184
+ "\n",
4185
+ "def post_process_normal(result, original_shape):\n",
4186
+ " # Check the dimensionality of the result\n",
4187
+ " if result.dim() == 3:\n",
4188
+ " result = result.unsqueeze(0)\n",
4189
+ " elif result.dim() == 4:\n",
4190
+ " pass\n",
4191
+ " else:\n",
4192
+ " raise ValueError(f\"Unexpected result dimension: {result.dim()}\")\n",
4193
+ " \n",
4194
+ " # Ensure we're interpolating to the correct dimensions\n",
4195
+ " seg_logits = F.interpolate(result, size=original_shape, mode=\"bilinear\", align_corners=False).squeeze(0)\n",
4196
+ " normal_map = seg_logits.float().cpu().numpy().transpose(1, 2, 0) # H x W x 3\n",
4197
+ " return normal_map\n",
4198
+ "\n",
4199
+ "def visualize_normal(normal_map):\n",
4200
+ " normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)\n",
4201
+ " normal_map_normalized = normal_map / (normal_map_norm + 1e-5) # Add a small epsilon to avoid division by zero\n",
4202
+ " \n",
4203
+ " # Convert to 0-255 range and BGR format for visualization\n",
4204
+ " normal_map_vis = ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)\n",
4205
+ " normal_map_vis = normal_map_vis[:, :, ::-1] # RGB to BGR\n",
4206
+ " \n",
4207
+ " return normal_map_vis\n",
4208
+ "\n",
4209
+ "def load_normal_model(checkpoint, use_torchscript=False):\n",
4210
+ " if use_torchscript:\n",
4211
+ " return torch.jit.load(checkpoint)\n",
4212
+ " else:\n",
4213
+ " model = torch.export.load(checkpoint).module()\n",
4214
+ " model = model.to(\"cuda\")\n",
4215
+ " model = torch.compile(model, mode=\"max-autotune\", fullgraph=True)\n",
4216
+ " return model"
4217
+ ]
4218
+ },
4219
+ {
4220
+ "cell_type": "code",
4221
+ "execution_count": 107,
4222
+ "metadata": {},
4223
+ "outputs": [
4224
+ {
4225
+ "data": {
4226
+ "text/plain": [
4227
+ "True"
4228
+ ]
4229
+ },
4230
+ "execution_count": 107,
4231
+ "metadata": {},
4232
+ "output_type": "execute_result"
4233
+ }
4234
+ ],
4235
+ "source": [
4236
+ "import cv2\n",
4237
+ "import numpy as np\n",
4238
+ "\n",
4239
+ "# Load the model\n",
4240
+ "normal_model = load_normal_model(model_path, use_torchscript='_torchscript')\n",
4241
+ "\n",
4242
+ "# Load the image\n",
4243
+ "image = cv2.imread(\"/home/user/app/assets/image.webp\")\n",
4244
+ "\n",
4245
+ "# Get the normal map and visualization\n",
4246
+ "normal_image, normal_map = get_normal(image, normal_model)\n",
4247
+ "\n",
4248
+ "# Save the results\n",
4249
+ "cv2.imwrite(\"output_normal_image.png\", normal_image)"
4250
  ]
4251
  },
4252
  {