CDM-Demo / modules /visual_matcher.py
wufan's picture
Upload 47 files
685cc58 verified
raw
history blame
6.74 kB
import time
import numpy as np
from PIL import Image
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
class SimpleAffineTransform:
    """
    Simple affine transform, only translation and (uniform) scale.

    Applying the transform scales the points about the centroid of the point
    set that is passed in (not about the origin) and then shifts them by
    ``translation``.
    """

    def __init__(self, translation=(0, 0), scale=1.0):
        """
        Args:
            translation: offset added to every transformed point.
            scale: uniform scale factor applied about the input centroid.
        """
        self.translation = np.array(translation)
        self.scale = scale

    def estimate(self, src, dst):
        """
        Fit the transform so that ``self(src)`` approximates ``dst``.

        The translation is the difference between the two centroids and the
        scale is the ratio of the mean point-to-centroid distances.

        Args:
            src: (N, D) source points.
            dst: (N, D) destination points.
        """
        src_center = np.mean(src, axis=0)
        dst_center = np.mean(dst, axis=0)
        self.translation = dst_center - src_center

        src_spread = np.mean(np.linalg.norm(src - src_center, axis=1))
        dst_spread = np.mean(np.linalg.norm(dst - dst_center, axis=1))
        # The epsilon avoids a division by zero when all source points coincide.
        self.scale = dst_spread / (src_spread + 1e-10)

    def inverse(self):
        """
        Return a new transform that undoes this one.

        Because scaling is done about the centroid, negating the translation
        and inverting the scale gives ``inverse()(self(x)) == x``.
        """
        # The original code instantiated an undefined name (``AffineTransform``)
        # here, which raised NameError on every call.
        return SimpleAffineTransform(-self.translation, 1.0 / self.scale)

    def __call__(self, coords):
        """Apply the transform to an (N, D) set of points and return the result."""
        center = np.mean(coords, axis=0)
        return self.scale * (coords - center) + center + self.translation

    def residuals(self, src, dst):
        """Return the per-point Euclidean distance between ``self(src)`` and ``dst``."""
        return np.sqrt(np.sum((self(src) - dst) ** 2, axis=1))
def norm_coords(x, left, right):
    """
    Clamp ``x`` to the interval ``[left, right]``.

    The lower bound is checked first, so ``left`` wins when ``x`` is below it
    even if the bounds are given in the wrong order.
    """
    return left if x < left else (right if x > right else x)
# Direct token replacements applied before any prefix/suffix rule.
_SPECIAL_TOKEN_MAP = {
    "\\cdot": ".",
    "\\mid": "|",
    "\\to": "\\rightarrow",
    "\\top": "T",
    "\\Tilde": "\\tilde",
    "\\cdots": "\\dots",
    "\\prime": "'",
    "\\ast": "*",
    "\\left<": "\\langle",
    "\\right>": "\\rangle",
}

# Longer prefixes first so that "\bigg(" is not read as "\big" + "g(".
_SIZE_PREFIXES = ("\\bigg", "\\Bigg", "\\big", "\\Big")


def _has_delimiter_prefix(token, prefixes, suffix_letters=""):
    r"""
    Tell whether ``token`` is one of ``prefixes`` applied to a delimiter.

    A prefix only counts when it is not followed by further command-name
    letters, so ``\left(`` and ``\bigl\{`` qualify while ``\leftarrow`` and
    ``\bigcup`` (different commands that merely share the prefix) do not.
    One optional letter from ``suffix_letters`` (e.g. the ``l``/``r``/``m`` of
    ``\bigl``) is allowed directly after the prefix. A bare prefix qualifies.
    """
    for prefix in prefixes:
        if token.startswith(prefix):
            rest = token[len(prefix):]
            if rest[:1] and rest[:1] in suffix_letters:
                rest = rest[1:]
            return not rest[:1].isalpha()
    return False


def norm_same_token(token):
    r"""
    Map a LaTeX token to a canonical form so that visually equivalent tokens
    compare equal (e.g. ``\left(`` -> ``(``, ``\leq`` -> ``\le``,
    ``\longrightarrow`` -> ``\rightarrow``).

    Args:
        token: a single LaTeX token string.

    Returns:
        The normalized token string.
    """
    token = _SPECIAL_TOKEN_MAP.get(token, token)

    # Drop \left / \right only when they wrap a delimiter. Without the guard
    # "\leftarrow" and "\rightarrow" both collapsed to "arrow" and the arrow
    # rules below could never fire for them.
    if _has_delimiter_prefix(token, ("\\left", "\\right")):
        token = token.replace("\\left", "").replace("\\right", "")

    # Drop \big-style size modifiers, keeping only the delimiter. Big
    # operators such as "\bigcup" / "\bigcap" are left alone; previously they
    # were both reduced to their last letter ("p").
    if _has_delimiter_prefix(token, _SIZE_PREFIXES, "lrm"):
        tail = token[4:]
        if "\\" in tail:
            token = "\\" + tail.split("\\")[-1]
        else:
            token = token[-1]

    if token in ('\\leq', '\\geq'):
        return token[0:-1]
    if token in ('\\lVert', '\\rVert', '\\Vert'):
        return '\\|'
    if token in ('\\lvert', '\\rvert', '\\vert'):
        return '|'
    # All right-pointing (resp. left-pointing) arrow variants share one form.
    if token.endswith("rightarrow"):
        return "\\rightarrow"
    if token.endswith("leftarrow"):
        return "\\leftarrow"
    if token.startswith('\\wide'):
        return token.replace("wide", "")
    if token.startswith('\\var'):
        # NOTE(review): this also removes the backslash ("\varphi" -> "phi"),
        # so the result does not equal the normalization of "\phi". Kept as in
        # the original; confirm whether "\\" was the intended replacement.
        return token.replace("\\var", "")
    return token
class HungarianMatcher:
    """
    One-to-one matching of ground-truth boxes to predicted boxes.

    The pairwise cost is a weighted sum of three terms (token mismatch,
    normalized bounding-box L1 distance, normalized reading-order distance)
    and the assignment is solved with ``scipy.optimize.linear_sum_assignment``.

    Every box is a dict providing a ``'token'`` entry and a ``'bbox'`` entry
    of the form ``(x_min, y_min, x_max, y_max)``.
    """

    def __init__(
        self,
        cost_token: float = 1,
        cost_position: float = 0.05,
        cost_order: float = 0.15,
    ):
        """
        Args:
            cost_token: weight of the token-mismatch term.
            cost_position: weight of the bounding-box distance term.
            cost_order: weight of the sequence-order distance term.
        """
        self.cost_token = cost_token
        self.cost_position = cost_position
        self.cost_order = cost_order
        # Individual cost matrices of the most recent __call__, keyed by
        # "token", "position" and "order".
        self.cost = {}

    def calculate_token_cost_old(self, box_gt, box_pred):
        """
        Reference (loop-based) version of ``calculate_token_cost``.

        Returns a (len(box_gt), len(box_pred)) matrix holding 0 for identical
        tokens, 0.05 for tokens that are equal after ``norm_same_token`` and
        1 otherwise.
        """
        token_cost = np.ones((len(box_gt), len(box_pred)))
        for i, gt_box in enumerate(box_gt):
            for j, pred_box in enumerate(box_pred):
                if gt_box['token'] == pred_box['token']:
                    token_cost[i, j] = 0
                elif norm_same_token(gt_box['token']) == norm_same_token(pred_box['token']):
                    token_cost[i, j] = 0.05
        return np.array(token_cost)

    @staticmethod
    def _encode_tokens(gt_tokens, pred_tokens):
        """Map both token lists to integer id arrays using one shared vocabulary."""
        vocab = {}

        def encode(tokens):
            # setdefault evaluates len(vocab) before inserting, so each new
            # token receives the next free id.
            return np.array([vocab.setdefault(tok, len(vocab)) for tok in tokens], dtype=int)

        return encode(gt_tokens), encode(pred_tokens)

    def calculate_token_cost(self, box_gt, box_pred):
        """
        Vectorized token-mismatch cost.

        Returns:
            (len(box_gt), len(box_pred)) float matrix: 0 where the tokens are
            identical, 0.05 where they are only equal after
            ``norm_same_token``, 1 elsewhere.
        """
        gt_tokens = [box['token'] for box in box_gt]
        pred_tokens = [box['token'] for box in box_pred]

        gt_ids, pred_ids = self._encode_tokens(gt_tokens, pred_tokens)
        gt_norm_ids, pred_norm_ids = self._encode_tokens(
            [norm_same_token(tok) for tok in gt_tokens],
            [norm_same_token(tok) for tok in pred_tokens],
        )

        # Comparing id vectors by broadcasting gives the same result as the
        # former one-hot lookup without building n x vocabulary matrices.
        exact = gt_ids[:, None] == pred_ids[None, :]
        equivalent = gt_norm_ids[:, None] == pred_norm_ids[None, :]

        token_cost = np.ones(exact.shape)
        token_cost[equivalent] = 0.05
        token_cost[exact] = 0.0  # exact matches override the 0.05 of "equivalent"
        return token_cost

    def box2array(self, box_list, size):
        """
        Convert boxes to an (N, 4) array of coordinates divided by the image size.

        Args:
            box_list: boxes with a ``'bbox'`` entry ``(x_min, y_min, x_max, y_max)``.
            size: ``(W, H)``; x values are divided by W and y values by H.
        """
        W, H = size
        rows = []
        for box in box_list:
            x_min, y_min, x_max, y_max = box['bbox']
            rows.append([x_min / W, y_min / H, x_max / W, y_max / H])
        # reshape keeps the result 2-D when box_list is empty.
        return np.array(rows).reshape(-1, 4)

    def order2array(self, box_list):
        """Return an (N, 1) array holding each box's index divided by N."""
        count = len(box_list)
        return np.array([[idx / count] for idx in range(count)]).reshape(-1, 1)

    def calculate_l1_cost(self, gt_array, pred_array):
        """
        Pairwise L1 distance between the rows of two 2-D arrays, divided by
        the number of columns.
        """
        scale = gt_array.shape[-1]
        l1_cost = cdist(gt_array, pred_array, 'minkowski', p=1)
        return l1_cost / scale

    def __call__(self, box_gt, box_pred, gt_size, pred_size):
        """
        Match ground-truth boxes to predicted boxes.

        Args:
            box_gt: ground-truth boxes.
            box_pred: predicted boxes.
            gt_size: ``(W, H)`` used to normalize the ground-truth boxes.
            pred_size: ``(W, H)`` used to normalize the predicted boxes.

        Returns:
            List of ``(gt_index, pred_index)`` pairs; empty when either input
            is empty. The individual cost matrices are stored in ``self.cost``.
        """
        if len(box_gt) == 0 or len(box_pred) == 0:
            # Nothing can be matched; previously this raised IndexError.
            shape = (len(box_gt), len(box_pred))
            self.cost["token"] = np.zeros(shape)
            self.cost["position"] = np.zeros(shape)
            self.cost["order"] = np.zeros(shape)
            return []

        token_cost = self.calculate_token_cost(box_gt, box_pred)
        position_cost = self.calculate_l1_cost(
            self.box2array(box_gt, gt_size),
            self.box2array(box_pred, pred_size),
        )
        order_cost = self.calculate_l1_cost(
            self.order2array(box_gt),
            self.order2array(box_pred),
        )

        self.cost["token"] = token_cost
        self.cost["position"] = position_cost
        self.cost["order"] = order_cost

        cost = (
            self.cost_token * token_cost
            + self.cost_position * position_cost
            + self.cost_order * order_cost
        )
        # linear_sum_assignment cannot work with NaN entries; replace every
        # non-finite cost by a large constant penalty.
        cost[~np.isfinite(cost)] = 100

        gt_idx, pred_idx = linear_sum_assignment(cost)
        return list(zip(gt_idx, pred_idx))