# FiT3D demo (Hugging Face Space, running on ZeroGPU)
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import spaces
import timm
import torch
import types

import albumentations as A
from PIL import Image
from tqdm import tqdm
from sklearn.decomposition import PCA
from torch_kmeans import KMeans, CosineSimilarity
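
# Color map for cluster visualization and standard ImageNet normalization
# constants (rescaled to [0, 1]) for the albumentations preprocessing below.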
cmap = plt.get_cmap("tab20")

MEAN = np.array([123.675, 116.280, 103.530]) / 255
STD = np.array([58.395, 57.120, 57.375]) / 255

transforms = A.Compose([
    A.Normalize(mean=list(MEAN), std=list(STD)),
])
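
# Re-implementation of timm's VisionTransformer.get_intermediate_layers that
# derives the patch-grid size from the actual input resolution, so it works
# with dynamic input sizes. It is bound to both the baseline and the
# fine-tuned models in load_model() via types.MethodType.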
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n=1,
    reshape: bool = False,
    return_prefix_tokens: bool = False,
    return_class_token: bool = False,
    norm: bool = True,
):
    outputs = self._intermediate_layers(x, n)
    if norm:
        outputs = [self.norm(out) for out in outputs]
    if return_class_token:
        prefix_tokens = [out[:, 0] for out in outputs]
    else:
        prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
    outputs = [out[:, self.num_prefix_tokens :] for out in outputs]

    if reshape:
        B, C, H, W = x.shape
        grid_size = (
            (H - self.patch_embed.patch_size[0]) // self.patch_embed.proj.stride[0] + 1,
            (W - self.patch_embed.patch_size[1]) // self.patch_embed.proj.stride[1] + 1,
        )
        outputs = [
            out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
            .permute(0, 3, 1, 2)
            .contiguous()
            for out in outputs
        ]

    if return_prefix_tokens or return_class_token:
        return tuple(zip(outputs, prefix_tokens))
    return tuple(outputs)
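
# Visualize a feature map by projecting the per-patch features onto their
# first three PCA components and rendering them as an RGB image.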
def viz_feat(feat):
    _, _, h, w = feat.shape
    feat = feat.squeeze(0).permute((1, 2, 0))
    projected_featmap = feat.reshape(-1, feat.shape[-1]).cpu()

    pca = PCA(n_components=3)
    pca.fit(projected_featmap)
    pca_features = pca.transform(projected_featmap)
    pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())
    pca_features = pca_features * 255

    res_pred = Image.fromarray(pca_features.reshape(h, w, 3).astype(np.uint8))
    return res_pred
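
# Build the 2x2 comparison figure: PCA feature maps on the top row and
# K-Means cluster maps on the bottom row, original model vs. fine-tuned.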
def plot_feats(model_option, ori_feats, fine_feats, ori_labels=None, fine_labels=None):
    ori_feats_map = viz_feat(ori_feats)
    fine_feats_map = viz_feat(fine_feats)

    fig, ax = plt.subplots(2, 2, figsize=(6, 5))
    ax[0][0].imshow(ori_feats_map)
    ax[0][0].set_title("Original " + model_option, fontsize=15)
    ax[0][1].imshow(fine_feats_map)
    ax[0][1].set_title("Fine-tuned", fontsize=15)
    ax[1][0].imshow(ori_labels)
    ax[1][1].imshow(fine_labels)

    for xx in ax:
        for x in xx:
            x.xaxis.set_major_formatter(plt.NullFormatter())
            x.yaxis.set_major_formatter(plt.NullFormatter())
            x.set_xticks([])
            x.set_yticks([])
            x.axis('off')

    plt.tight_layout()
    plt.close(fig)
    return fig
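
# Fetch an example image over HTTP and write it to disk.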
def download_image(url, save_path):
    response = requests.get(url)
    response.raise_for_status()
    with open(save_path, 'wb') as file:
        file.write(response.content)
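
# Normalize the input image and resize it down to the nearest multiple of the
# patch stride so the ViT sees an integer number of patches.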
def process_image(image, stride, transforms):
    transformed = transforms(image=np.array(image))
    image_tensor = torch.tensor(transformed['image'])
    image_tensor = image_tensor.permute(2, 0, 1)
    image_tensor = image_tensor.unsqueeze(0).to(device)

    h, w = image_tensor.shape[2:]
    height_int = (h // stride) * stride
    width_int = (w // stride) * stride
    image_resized = torch.nn.functional.interpolate(image_tensor, size=(height_int, width_int), mode='bilinear')
    return image_resized
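
# Cluster the flattened patch features with cosine-similarity K-Means and
# render the cluster assignments as a tab20 color-coded label map.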
def kmeans_clustering(feats_map, n_clusters=20):
    if n_clusters is None:
        n_clusters = 20
    n_clusters = int(n_clusters)  # gr.Number may deliver a float
    print('num clusters: ', n_clusters)

    B, D, h, w = feats_map.shape
    feats_map_flattened = feats_map.permute((0, 2, 3, 1)).reshape(B, -1, D)

    kmeans_engine = KMeans(n_clusters=n_clusters, distance=CosineSimilarity)
    kmeans_engine.fit(feats_map_flattened)
    labels = kmeans_engine.predict(feats_map_flattened)
    labels = labels.reshape(B, h, w).float()
    labels = labels[0].cpu().numpy()

    label_map = cmap(labels / n_clusters)[..., :3]
    label_map = np.uint8(label_map * 255)
    label_map = Image.fromarray(label_map)
    return label_map
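
# Load the pretrained timm baselines and the corresponding FiT3D fine-tuned
# checkpoints (fetched via torch.hub from the ywyue/FiT3D repo).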
def load_model(options):
    original_models = {}
    fine_models = {}

    for option in tqdm(options):
        print('Please wait ...')
        print('loading weights of ', option)
        original_models[option] = timm.create_model(
            timm_model_card[option],
            pretrained=True,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=False,
        ).to(device)
        original_models[option].get_intermediate_layers = types.MethodType(
            get_intermediate_layers,
            original_models[option]
        )

        fine_models[option] = torch.hub.load("ywyue/FiT3D", our_model_card[option]).to(device)
        fine_models[option].get_intermediate_layers = types.MethodType(
            get_intermediate_layers,
            fine_models[option]
        )

    print('Done! Now play the demo :)')
    return original_models, fine_models
if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"device: {device}")
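
    # Download the example images once into /tmp so Gradio can serve them.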
    example_urls = {
        "library.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/library.jpg",
        "livingroom.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/livingroom.jpg",
        "airplane.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/airplane.jpg",
        "ship.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/ship.jpg",
        "chair.jpg": "https://n.ethz.ch/~yuayue/assets/fit3d/demo_images/chair.jpg",
    }

    example_dir = "/tmp/examples"
    os.makedirs(example_dir, exist_ok=True)

    for name, url in example_urls.items():
        save_path = os.path.join(example_dir, name)
        if not os.path.exists(save_path):
            print(f"Downloading to {save_path}...")
            download_image(url, save_path)
        else:
            print(f"{save_path} already exists.")
    image_input = gr.Image(label="Choose an image:",
                           height=500,
                           type="pil",
                           image_mode='RGB',
                           sources=['upload', 'webcam', 'clipboard'])

    options = ['DINOv2', 'DINOv2-reg', 'CLIP', 'MAE', 'DeiT-III']
    model_option = gr.Radio(options, value="DINOv2", label='Choose a 2D foundation model')
    kmeans_num = gr.Number(
        label="Number of K-Means clusters", value=20, precision=0
    )
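
    # Display name -> timm model id, and display name -> FiT3D torch.hub entry.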
    timm_model_card = {
        "DINOv2": "vit_small_patch14_dinov2.lvd142m",
        "DINOv2-reg": "vit_small_patch14_reg4_dinov2.lvd142m",
        "CLIP": "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k",
        "MAE": "vit_base_patch16_224.mae",
        "DeiT-III": "deit3_base_patch16_224.fb_in1k"
    }
    our_model_card = {
        "DINOv2": "dinov2_small_fine",
        "DINOv2-reg": "dinov2_reg_small_fine",
        "CLIP": "clip_base_fine",
        "MAE": "mae_base_fine",
        "DeiT-III": "deit3_base_fine"
    }
    os.environ['TORCH_HOME'] = '/tmp/.cache'
    # os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'

    # Pre-load all models
    original_models, fine_models = load_model(options)
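
    # Inference entry point: extract last-block features from both models,
    # cluster them, and return the side-by-side comparison figure.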
    @spaces.GPU  # request ZeroGPU hardware for the duration of this call
    def fit3d(image, model_option, kmeans_num):
        # Select model
        original_model = original_models[model_option]
        fine_model = fine_models[model_option]

        # Data preprocessing: resize to a multiple of the model's patch size
        p = original_model.patch_embed.patch_size
        stride = p if isinstance(p, int) else p[0]
        image_resized = process_image(image, stride, transforms)

        with torch.no_grad():
            ori_feats = original_model.get_intermediate_layers(
                image_resized, n=[8, 9, 10, 11], reshape=True,
                return_prefix_tokens=False, return_class_token=False, norm=True)
            fine_feats = fine_model.get_intermediate_layers(
                image_resized, n=[8, 9, 10, 11], reshape=True,
                return_prefix_tokens=False, return_class_token=False, norm=True)

        # Keep only the last of the requested blocks
        ori_feats = ori_feats[-1]
        fine_feats = fine_feats[-1]

        ori_labels = kmeans_clustering(ori_feats, kmeans_num)
        fine_labels = kmeans_clustering(fine_feats, kmeans_num)

        return plot_feats(model_option, ori_feats, fine_feats, ori_labels, fine_labels)
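
    # Assemble the Gradio app; example outputs are precomputed via cache_examples.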
    demo = gr.Interface(
        title="<div> \
                <h1>FiT3D</h1> \
                <h2>Improving 2D Feature Representations by 3D-Aware Fine-Tuning</h2> \
                <h2>ECCV 2024</h2> \
                </div>",
        description="<div style='display: flex; justify-content: center; align-items: center; text-align: center;'> \
                <a href='https://arxiv.org/abs/2407.20229'><img src='https://img.shields.io/badge/arXiv-2407.20229-red'></a> \
                \
                <a href='https://ywyue.github.io/FiT3D'><img src='https://img.shields.io/badge/Project_Page-FiT3D-green' alt='Project Page'></a> \
                \
                <a href='https://github.com/ywyue/FiT3D'><img src='https://img.shields.io/badge/Github-Code-blue'></a> \
                </div>",
        fn=fit3d,
        inputs=[image_input, model_option, kmeans_num],
        outputs="plot",
        examples=[
            ["/tmp/examples/library.jpg", "DINOv2", 20],
            ["/tmp/examples/livingroom.jpg", "DINOv2", 20],
            ["/tmp/examples/airplane.jpg", "DINOv2", 20],
            ["/tmp/examples/ship.jpg", "DINOv2", 20],
            ["/tmp/examples/chair.jpg", "DINOv2", 20],
        ],
        cache_examples=True)

    demo.launch()