import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Literal
from functools import partial
import itertools
# LRM
from .embedder import CameraEmbedder
from .transformer import TransformerDecoder
# from accelerate.logging import get_logger
# logger = get_logger(__name__)
class LRM_VSD_Mesh_Net(nn.Module):
"""
predict VSD using transformer
"""
def __init__(self, camera_embed_dim: int,
transformer_dim: int, transformer_layers: int, transformer_heads: int,
triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
encoder_freeze: bool = True, encoder_type: str = 'dino',
encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768, app_dim = 27, density_dim = 8, app_n_comp=24,
density_n_comp=8):
super().__init__()
# attributes
self.encoder_feat_dim = encoder_feat_dim
self.camera_embed_dim = camera_embed_dim
self.triplane_low_res = triplane_low_res
self.triplane_high_res = triplane_high_res
self.triplane_dim = triplane_dim
self.transformer_dim=transformer_dim
# modules
self.encoder = self._encoder_fn(encoder_type)(
model_name=encoder_model_name,
modulation_dim=self.camera_embed_dim, #mod camera vector
freeze=encoder_freeze,
)
self.camera_embedder = CameraEmbedder(
raw_dim=12+4, embed_dim=camera_embed_dim,
)
self.n_comp=app_n_comp+density_n_comp
self.app_dim=app_dim
self.density_dim=density_dim
self.app_n_comp=app_n_comp
self.density_n_comp=density_n_comp
self.pos_embed = nn.Parameter(torch.randn(1, 3*(triplane_low_res**2)+3*triplane_low_res, transformer_dim) * (1. / transformer_dim) ** 0.5)
self.transformer = TransformerDecoder(
block_type='cond',
num_layers=transformer_layers, num_heads=transformer_heads,
inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=None,
)
# for plane
self.upsampler = nn.ConvTranspose2d(transformer_dim, self.n_comp, kernel_size=2, stride=2, padding=0)
self.dim_map = nn.Linear(transformer_dim,self.n_comp)
self.up_line = nn.Linear(triplane_low_res,triplane_low_res*2)
    @staticmethod
    def _encoder_fn(encoder_type: str):
        encoder_type = encoder_type.lower()
        assert encoder_type in ['dino', 'dinov2'], "Unsupported encoder type"
        if encoder_type == 'dino':
            from .encoders.dino_wrapper import DinoWrapper
            # logger.info("Using DINO as the encoder")
            return DinoWrapper
        elif encoder_type == 'dinov2':
            from .encoders.dinov2_wrapper import Dinov2Wrapper
            # logger.info("Using DINOv2 as the encoder")
            return Dinov2Wrapper
    def forward_transformer(self, image_feats, camera_embeddings=None):
        N = image_feats.shape[0]
        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        x = self.transformer(
            x,
            cond=image_feats,
            mod=camera_embeddings,
        )
        return x
    def reshape_upsample(self, tokens):
        # tokens: [N, 3*H*W + 3*H, transformer_dim]
        N = tokens.shape[0]
        H = W = self.triplane_low_res
        P = self.n_comp
        offset = 3*H*W  # start index of the line tokens
        # planes
        plane_tokens = tokens[:, :3*H*W, :].view(N, H, W, 3, self.transformer_dim)
        plane_tokens = torch.einsum('nhwip->inphw', plane_tokens)  # [3, N, D, H, W]
        plane_tokens = plane_tokens.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        plane_tokens = self.upsampler(plane_tokens)  # [3*N, P, H', W']
        plane_tokens = plane_tokens.view(3, N, *plane_tokens.shape[-3:])  # [3, N, P, H', W']
        plane_tokens = torch.einsum('inphw->niphw', plane_tokens)  # [N, 3, P, H', W']
        plane_tokens = plane_tokens.reshape(N, 3*P, *plane_tokens.shape[-2:])  # [N, 3*P, H', W']
        plane_tokens = plane_tokens.contiguous()
        # lines
        line_tokens = tokens[:, offset:offset+3*H, :].view(N, H, 3, self.transformer_dim)
        line_tokens = self.dim_map(line_tokens)
        line_tokens = torch.einsum('nhip->npih', line_tokens)  # [N, P, 3, H]
        line_tokens = self.up_line(line_tokens)
        line_tokens = torch.einsum('npih->niph', line_tokens)  # [N, 3, P, H']
        line_tokens = line_tokens.reshape(N, 3*P, line_tokens.shape[-1], 1)
        line_tokens = line_tokens.contiguous()
        mat_tokens = None
        d_mat_tokens = None
        # split into appearance and density components
        return (plane_tokens[:, :self.app_n_comp*3, :, :], line_tokens[:, :self.app_n_comp*3, :, :],
                mat_tokens, d_mat_tokens,
                plane_tokens[:, self.app_n_comp*3:, :, :], line_tokens[:, self.app_n_comp*3:, :, :])
    def forward_planes(self, image, camera):
        # image: [N, V, C_img, H_img, W_img]
        # camera: [N, V, D_cam_raw]
        N, V, _, H, W = image.shape
        image = image.reshape(N*V, 3, H, W)
        camera = camera.reshape(N*V, -1)
        # embed camera
        camera_embeddings = self.camera_embedder(camera)
        assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
            f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
        # encode image
        image_feats = self.encoder(image, camera_embeddings)
        assert image_feats.shape[-1] == self.encoder_feat_dim, \
            f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"
        # concatenate the tokens of all views per sample
        image_feats = image_feats.reshape(N, V*image_feats.shape[-2], image_feats.shape[-1])
        # transformer generating planes
        tokens = self.forward_transformer(image_feats)
        app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines = self.reshape_upsample(tokens)
        return app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines
    def forward(self, image, source_camera):
        # image: [N, V, C_img, H_img, W_img]
        # source_camera: [N, V, D_cam_raw]
        assert image.shape[0] == source_camera.shape[0], "Batch size mismatch for image and source_camera"
        planes = self.forward_planes(image, source_camera)
        # tuple of (app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines)
        return planes
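

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# The constructor hyperparameters and tensor shapes below are assumptions
# chosen to satisfy the module's own constraints (encoder_feat_dim must match
# the DINO backbone, camera raw_dim is 12+4); the real values come from the
# project's config. Run as a module (python -m package.this_module) so the
# relative imports resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = LRM_VSD_Mesh_Net(
        camera_embed_dim=1024,                                            # assumed
        transformer_dim=512, transformer_layers=12, transformer_heads=8,  # assumed
        triplane_low_res=32, triplane_high_res=64, triplane_dim=40,       # assumed
        encoder_freeze=True, encoder_type='dino',
        encoder_model_name='facebook/dino-vitb16', encoder_feat_dim=768,
    )
    image = torch.randn(1, 4, 3, 224, 224)  # [N, V, 3, H, W] posed input views
    camera = torch.randn(1, 4, 16)          # [N, V, 12+4] extrinsics + intrinsics
    app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines = model(image, camera)
    # with the assumed defaults (app_n_comp=24, density_n_comp=8, low_res=32):
    # app_planes:  [1, 3*24, 64, 64]   density_planes: [1, 3*8, 64, 64]
    # app_lines:   [1, 3*24, 64, 1]    density_lines:  [1, 3*8, 64, 1]
    print(app_planes.shape, app_lines.shape, density_planes.shape, density_lines.shape)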