File size: 7,112 Bytes
f981a9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix
from .embed_utils import PackedCseAnnotations
from .mask import extract_data_for_mask_loss_from_matches
def _create_pixel_dist_matrix(grid_size: int) -> torch.Tensor:
    """
    Build pairwise squared Euclidean distances between the cells of a
    `grid_size` x `grid_size` pixel grid.

    Cells are enumerated in row-major order: flat index `i` refers to the cell
    at row `i // grid_size` and column `i % grid_size`.

    Args:
        grid_size (int): side length S of the square grid
    Returns:
        Tensor produced by `squared_euclidean_distance_matrix` for the S * S
        cell coordinates (indexed as [S * S, S * S] by the cycle loss).
    """
    axis = torch.arange(grid_size)
    # [S * S, 2] (row, col) pairs in row-major order, like itertools.product
    pix_coords = torch.cartesian_prod(axis, axis).float()
    return squared_euclidean_distance_matrix(pix_coords, pix_coords)
def _sample_fg_pixels_randperm(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
fg_mask_flattened = fg_mask.reshape((-1,))
num_pixels = int(fg_mask_flattened.sum().item())
fg_pixel_indices = fg_mask_flattened.nonzero(as_tuple=True)[0]
if (sample_size <= 0) or (num_pixels <= sample_size):
return fg_pixel_indices
sample_indices = torch.randperm(num_pixels, device=fg_mask.device)[:sample_size]
return fg_pixel_indices[sample_indices]
def _sample_fg_pixels_multinomial(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
fg_mask_flattened = fg_mask.reshape((-1,))
num_pixels = int(fg_mask_flattened.sum().item())
if (sample_size <= 0) or (num_pixels <= sample_size):
return fg_mask_flattened.nonzero(as_tuple=True)[0]
return fg_mask_flattened.float().multinomial(sample_size, replacement=False)
class PixToShapeCycleLoss(nn.Module):
    """
    Cycle loss for pixel-vertex correspondence.

    For every instance and every selected mesh, a subset of foreground pixels
    is sampled. Their embeddings are matched to the mesh vertex embeddings
    (pixel -> vertex) and back (vertex -> pixel) with temperature-scaled
    softmaxes. The resulting pixel -> pixel cycle matrix is weighted by the
    precomputed pixel-grid distances and reduced with a p-norm, so cycles that
    land far from their starting pixel are penalized.
    """

    def __init__(self, cfg: CfgNode):
        super().__init__()
        self.shape_names = list(cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.keys())
        self.embed_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
        self.norm_p = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P
        # True: compute the loss for every configured mesh;
        # False: only for meshes that occur in the GT annotations
        self.use_all_meshes_not_gt_only = (
            cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY
        )
        # non-positive value disables pixel subsampling
        self.num_pixels_to_sample = (
            cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE
        )
        # NOTE(review): read from the config but not used by forward()
        self.pix_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA
        self.temperature_pix_to_vertex = (
            cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX
        )
        self.temperature_vertex_to_pix = (
            cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL
        )
        # Pairwise distances between cells of the HEATMAP_SIZE x HEATMAP_SIZE grid.
        # Stored as a plain attribute (not a registered buffer), so Module.to()
        # does not move it; forward() moves it lazily instead.
        self.pixel_dists = _create_pixel_dist_matrix(cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE)

    def forward(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        embedder: nn.Module,
    ):
        """
        Compute the pixel -> vertex -> pixel cycle loss averaged over all
        (instance, mesh) pairs.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
                * coarse_segm - passed to `extract_data_for_mask_loss_from_matches`
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Returns:
            Scalar tensor with the mean cycle loss. If there is nothing to
            compare (no instances, no GT masks or no selected meshes), a zero
            tensor that still depends on the predicted embeddings.
        """
        pix_embeds = densepose_predictor_outputs.embedding
        if self.pixel_dists.device != pix_embeds.device:
            # should normally be done only once
            self.pixel_dists = self.pixel_dists.to(device=pix_embeds.device)
        with torch.no_grad():
            mask_loss_data = extract_data_for_mask_loss_from_matches(
                proposals_with_gt, densepose_predictor_outputs.coarse_segm
            )
        # `masks_gt` is presumably Optional (None when no masks were extracted);
        # treat None as "zero masks" instead of failing on `.long()`
        masks_gt = mask_loss_data.masks_gt
        num_masks = 0 if masks_gt is None else len(masks_gt)
        assert len(pix_embeds) == num_masks, (
            f"Number of instances with embeddings {len(pix_embeds)} != "
            f"number of instances with GT masks {num_masks}"
        )
        mesh_names = self._select_mesh_names(packed_annotations)
        if masks_gt is None or num_masks == 0 or not mesh_names:
            return pix_embeds.sum() * 0
        # GT masks - tensor of shape [N, S, S] of int64
        masks_gt = masks_gt.long()
        # Vertex embeddings depend only on the mesh, so compute them once per
        # forward pass rather than once per (instance, mesh) pair.
        # NOTE(review): assumes `embedder(name)` returns the same values on every
        # call (no RNG use such as dropout) - confirm for all configured embedders.
        vertex_embeddings_by_mesh = {name: embedder(name) for name in mesh_names}
        losses = [
            self._cycle_loss(pixel_embeddings, mask_gt, vertex_embeddings_by_mesh[mesh_name])
            for pixel_embeddings, mask_gt in zip(pix_embeds, masks_gt)
            for mesh_name in mesh_names
        ]
        return torch.stack(losses, dim=0).mean()

    def _select_mesh_names(self, packed_annotations: PackedCseAnnotations) -> List[str]:
        """Return the names of the meshes to compute the loss for."""
        if self.use_all_meshes_not_gt_only:
            return self.shape_names
        return [
            MeshCatalog.get_mesh_name(mesh_id.item())
            for mesh_id in packed_annotations.vertex_mesh_ids_gt.unique()
        ]

    def _cycle_loss(
        self,
        pixel_embeddings: torch.Tensor,
        mask_gt: torch.Tensor,
        mesh_vertex_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the cycle loss for one instance and one mesh.

        Args:
            pixel_embeddings (tensor [D, S, S]): embeddings predicted for one instance
            mask_gt (tensor [S, S]): GT foreground mask of that instance
            mesh_vertex_embeddings (tensor [K, D]): embeddings of the K mesh vertices
        Returns:
            Scalar tensor: p-norm of the distance-weighted cycle matrix.
            Note: an empty mask yields M = 0 and therefore a zero loss.
        """
        # pixel indices [M]; a fresh sample is drawn for every (instance, mesh) pair
        pixel_indices_flattened = _sample_fg_pixels_randperm(
            mask_gt, self.num_pixels_to_sample
        )
        # pixel distances [M, M]. Broadcasting the index pair selects the same
        # sub-matrix as torch.meshgrid(idx, idx) without meshgrid's `indexing`
        # warning. self.pixel_dists is already on the embeddings' device
        # (moved at the start of forward).
        pixel_dists = self.pixel_dists[
            pixel_indices_flattened[:, None], pixel_indices_flattened[None, :]
        ]
        # pixel embeddings [M, D]
        pixel_embeddings_sampled = normalize_embeddings(
            pixel_embeddings.reshape((self.embed_size, -1))[:, pixel_indices_flattened].T
        )
        # pixel-vertex similarity [M, K]
        sim_matrix = pixel_embeddings_sampled.mm(mesh_vertex_embeddings.T)
        c_pix_vertex = F.softmax(sim_matrix / self.temperature_pix_to_vertex, dim=1)
        c_vertex_pix = F.softmax(sim_matrix.T / self.temperature_vertex_to_pix, dim=1)
        # pixel -> vertex -> pixel assignment [M, M]
        c_cycle = c_pix_vertex.mm(c_vertex_pix)
        return torch.norm(pixel_dists * c_cycle, p=self.norm_p)

    def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module):
        """
        Return a zero-valued loss that still depends on every mesh's vertex
        embeddings and on the predicted pixel embeddings, keeping them in the
        autograd graph.
        """
        losses = [embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names]
        losses.append(densepose_predictor_outputs.embedding.sum() * 0)
        return torch.mean(torch.stack(losses))
|