File size: 7,488 Bytes
74e4bcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1945055
 
 
74e4bcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1945055
74e4bcd
 
 
 
1945055
74e4bcd
 
 
 
 
 
 
1945055
74e4bcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1945055
 
 
74e4bcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1945055
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import os
import urllib.request
from collections import OrderedDict
from html import escape

import pandas as pd
import numpy as np

import torch
import torchvision.transforms as transforms

from transformers import CLIPProcessor, CLIPModel
import tokenizers
import regex

import streamlit as st

import models
from tokenizer import SimpleTokenizer

# True when PyTorch reports a usable CUDA device; read below to decide whether
# the SLIP model and its tokenized inputs are moved to the GPU.
cuda_available = torch.cuda.is_available()

# Remote location of the SLIP checkpoint and the local file name used for it.
model_url = "https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt"
model_filename = "slip_large_100ep.pt"


def get_model(model):
    """Return the module wrapped by a (Distributed)DataParallel container.

    Objects that are not one of those two wrappers are returned unchanged.
    """
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    return model.module if isinstance(model, parallel_wrappers) else model


@st.cache(
    show_spinner=False,
    # The hash functions return a constant for CLIP objects and dicts.
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the SLIP and CLIP models plus the image metadata and embeddings.

    The SLIP checkpoint is downloaded from ``model_url`` to ``model_filename``
    when that file does not exist yet. The CSV and ``.npy`` files are read
    through relative paths, i.e. from the current working directory.

    Returns:
        A 7-tuple ``(model, processor, slip_model, tokenizer, df, embeddings,
        slip_embeddings)`` where ``model``/``processor`` are the CLIP model and
        processor, ``slip_model``/``tokenizer`` are the SLIP model and its
        ``SimpleTokenizer``, and the last three are dicts keyed by ``0`` and
        ``1`` holding the metadata tables, the CLIP image embeddings (rows
        scaled to unit L2 norm) and the SLIP image embeddings (as stored).
    """
    # Load SLIP model from Facebook AI Research
    if not os.path.exists(model_filename):
        # Download under a temporary name and rename afterwards, so that an
        # interrupted download is never mistaken for a complete checkpoint.
        partial_filename = model_filename + ".part"
        urllib.request.urlretrieve(model_url, partial_filename)
        os.replace(partial_filename, model_filename)
    # NOTE(review): torch.load unpickles the file (including the pickled
    # ``args`` object read below), so the checkpoint source must be trusted.
    ckpt = torch.load(model_filename, map_location="cpu")
    # Drop the "module." fragment from the parameter names so that they match
    # the bare (unwrapped) model built below.
    state_dict = OrderedDict(
        (k.replace("module.", ""), v) for k, v in ckpt["state_dict"].items()
    )
    old_args = ckpt["args"]
    slip_model = models.SLIP_VITL16(
        rand_embed=False,
        ssl_mlp_dim=old_args.ssl_mlp_dim,
        ssl_emb_dim=old_args.ssl_emb_dim,
    )
    if cuda_available:
        slip_model.cuda()
    slip_model.load_state_dict(state_dict, strict=True)
    slip_model = get_model(slip_model)
    tokenizer = SimpleTokenizer()
    # Release the checkpoint before loading the next model.
    del ckpt
    del state_dict
    # Load CLIP model from HuggingFace
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Load images' descriptions and embeddings
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
    slip_embeddings = {
        0: np.load("embeddings_slip_large.npy"),
        1: np.load("embeddings2_slip_large.npy"),
    }
    # Scale every CLIP embedding row to unit L2 norm. The SLIP embeddings are
    # returned exactly as stored on disk.
    for k in [0, 1]:
        embeddings[k] = np.divide(
            embeddings[k], np.sqrt(np.sum(embeddings[k] ** 2, axis=1, keepdims=True))
        )
    return model, processor, slip_model, tokenizer, df, embeddings, slip_embeddings


# Load the models, tokenizer, metadata tables and image embeddings once at
# import time; the functions below read these module-level names.
model, processor, slip_model, tokenizer, df, embeddings, slip_embeddings = load()

# Attribution suffix appended to each image tooltip, keyed by corpus index
# (0 is used for "Unsplash", 1 for every other corpus choice).
source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}


def get_html(url_list, url_list_slip, height=150):
    """Build the side-by-side HTML gallery comparing CLIP and SLIP results.

    Args:
        url_list: iterable of ``(url, title, link)`` triples for the CLIP
            column (left).
        url_list_slip: iterable of ``(url, title, link)`` triples for the SLIP
            column (right).
        height: image height in pixels.

    Returns:
        A single HTML string. ``url``, ``title`` and ``link`` are HTML-escaped
        before being interpolated. An image is wrapped in an anchor only when
        its ``link`` is a non-empty string.
    """
    return "".join(
        [
            "<div style='display: flex; flex-wrap: wrap; justify-content: space-evenly;'>",
            # Left half: CLIP results.
            "<span style='margin-top: 20px; max-width: 1200px; display: flex; align-content: flex-start; flex-wrap: wrap; justify-content: space-evenly; width: 50%'>",
            "<div style='width: 100%; text-align: center;'><b>CLIP</b> (<a href='https://arxiv.org/abs/2103.00020'>Arxiv</a>, <a href='https://github.com/openai/CLIP'>GitHub</a>) from OpenAI</div>",
            _render_images(url_list, height),
            "</span>",
            # Right half: SLIP results, separated by a thin left border.
            "<span style='margin-top: 20px; max-width: 1200px; display: flex; align-content: flex-start; flex-wrap: wrap; justify-content: space-evenly; width: 50%; border-left: solid; border-color: #ffc423; border-width: thin;'>",
            "<div style='width: 100%; text-align: center;'><b>SLIP</b> (<a href='https://arxiv.org/abs/2112.12750'>Arxiv</a>, <a href='https://github.com/facebookresearch/SLIP'>GitHub</a>) from Meta AI</div>",
            _render_images(url_list_slip, height),
            "</span></div>",
        ]
    )


def _render_images(entries, height):
    """Return the concatenated <img> tags (linked where possible) for entries."""
    parts = []
    for url, title, link in entries:
        img = (
            f"<img title='{escape(title)}' style='height: {height}px; "
            f"margin: 5px' src='{escape(url)}'>"
        )
        # A missing link may show up as "", None or NaN (e.g. an empty CSV
        # cell read by pandas); only a non-empty string is a usable href.
        if isinstance(link, str) and link:
            img = f"<a href='{escape(link)}' target='_blank'>{img}</a>"
        parts.append(img)
    return "".join(parts)


def compute_text_embeddings(list_of_strings):
    """Encode texts with the CLIP model.

    Args:
        list_of_strings: list of texts handed to the CLIP processor.

    Returns:
        Whatever ``model.get_text_features`` returns for the tokenized batch
        (presumably one embedding row per input string; confirm against the
        transformers documentation).
    """
    # truncation=True cuts over-long queries down to the tokenizer's maximum
    # length instead of letting them exceed what the text model accepts.
    inputs = processor(
        text=list_of_strings, return_tensors="pt", padding=True, truncation=True
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model.get_text_features(**inputs)


def compute_text_embeddings_slip(list_of_strings):
    """Encode texts with the SLIP model.

    Args:
        list_of_strings: list of texts handed to the SLIP tokenizer.

    Returns:
        Whatever ``slip_model.encode_text`` returns for the token batch. The
        result lives on the GPU when ``cuda_available`` is true.
    """
    texts = tokenizer(list_of_strings)
    if cuda_available:
        texts = texts.cuda(non_blocking=True)
    # Force a 2-D batch with 77 token ids per row (presumably the tokenizer's
    # context length; confirm against the tokenizer module).
    texts = texts.view(-1, 77).contiguous()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return slip_model.encode_text(texts)


def image_search(query, corpus, n_results=24):
    """Return the best-matching images for a text query under CLIP and SLIP.

    Args:
        query: free-text search query.
        corpus: ``"Unsplash"`` selects corpus 0; any other value selects
            corpus 1.
        n_results: maximum number of results per model.

    Returns:
        A pair ``(clip_results, slip_results)``. Each element is a list of
        ``(path, tooltip, link)`` tuples ordered from best to worst match,
        with the corpus attribution from ``source`` appended to the tooltip.
    """
    k = 0 if corpus == "Unsplash" else 1
    # .cpu() is required before .numpy(): the SLIP encoder returns a CUDA
    # tensor when a GPU is available, and it is a no-op for CPU tensors.
    text_embeddings = compute_text_embeddings([query]).detach().cpu().numpy()
    text_embeddings_slip = (
        compute_text_embeddings_slip([query]).detach().cpu().numpy()
    )
    return (
        _describe_rows(k, _top_matches(embeddings[k], text_embeddings, n_results)),
        _describe_rows(
            k, _top_matches(slip_embeddings[k], text_embeddings_slip, n_results)
        ),
    )


def _top_matches(image_embeddings, text_embeddings, n_results):
    """Return the indices of the n_results highest dot products, best first."""
    scores = (image_embeddings @ text_embeddings.T)[:, 0]
    # argsort is ascending, so walk backwards from the end of the ordering.
    return np.argsort(scores)[-1 : -n_results - 1 : -1]


def _describe_rows(k, indices):
    """Return (path, tooltip + attribution, link) tuples for rows of df[k]."""
    table = df[k]
    entries = []
    for i in indices:
        row = table.iloc[i]
        entries.append((row["path"], row["tooltip"] + source[k], row["link"]))
    return entries


# Markdown text rendered in the sidebar of the app.
description = """
# Comparing CLIP and SLIP side by side

**Enter your query and hit enter**

CLIP and SLIP are ML models that encode images and texts as vectors so that the vectors of an image and its caption are similar. They can notably be used for zero-shot image classification, text-based image retrieval or image generation.

Cf. this Twitter [thread](https://twitter.com/vivien000000/status/1475829936443334660) with some surprising differences between CLIP and SLIP.

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, Meta AI's [SLIP](https://github.com/facebookresearch/SLIP) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
"""


# Inject custom CSS into the page. The rules cap the block container at
# 1200px, lay the radio options out in a centered row with 5px side margins,
# adjust the top padding of the main and non-main sections, cap the first
# section of the report view at 320px (presumably the sidebar; confirm in the
# rendered page) and hide the #MainMenu element and the footer.
# unsafe_allow_html=True is passed so that the raw <style> tag is emitted.
st.markdown(
    """
          <style>
          .block-container{
            max-width: 1200px;
          }
          div.row-widget.stRadio > div{
            flex-direction:row;
            display: flex;
            justify-content: center;
          }
          div.row-widget.stRadio > div > label{
            margin-left: 5px;
            margin-right: 5px;
          }
          section.main>div:first-child {
            padding-top: 0px;
          }
          section:not(.main)>div:first-child {
            padding-top: 30px;
          }
          div.reportview-container > section:first-child{
            max-width: 320px;
          }
          #MainMenu {
            visibility: hidden;
          }
          footer {
            visibility: hidden;
          }
          </style>""",
    unsafe_allow_html=True,
)
# Sidebar: project description.
st.sidebar.markdown(description)
# Three columns specified as (1, 3, 1); only the middle one is used, so the
# query box sits centered on the page.
_, c, _ = st.columns((1, 3, 1))
query = c.text_input("", value="clouds at sunset")
# Corpus selector; image_search treats any value other than "Unsplash" as the
# second corpus.
corpus = st.radio("", ["Unsplash", "Movies"])
# Skip the search entirely when the query box is empty.
if len(query) > 0:
    results, results_slip = image_search(query, corpus)
    # get_html returns raw HTML, hence unsafe_allow_html=True.
    st.markdown(get_html(results, results_slip), unsafe_allow_html=True)