Spaces:
Running
Running
sashavor
commited on
Commit
•
0d69242
1
Parent(s):
c45624f
adding new index and changing app a bit
Browse files- app.py +8 -8
- index_768.pickle +3 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -8,23 +8,23 @@ from transformers import AutoModel, AutoFeatureExtractor
|
|
8 |
seed = 42
|
9 |
|
10 |
# Only runs once when the script is first run.
|
11 |
-
with open("
|
12 |
index = pickle.load(handle)
|
13 |
|
14 |
# Load model for computing embeddings.
|
15 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained("
|
16 |
-
model = AutoModel.from_pretrained("
|
17 |
|
18 |
# Candidate images.
|
19 |
-
dataset = load_dataset("sasha/
|
20 |
ds = dataset["train"]
|
21 |
|
22 |
|
23 |
-
def query(image, top_k):
|
24 |
inputs = feature_extractor(image, return_tensors="pt")
|
25 |
model_output = model(**inputs)
|
26 |
embedding = model_output.pooler_output.detach()
|
27 |
-
results = index.query(embedding)
|
28 |
inx = results[0][0].tolist()
|
29 |
images = ds.select(inx)["image"]
|
30 |
return images
|
@@ -37,8 +37,8 @@ description = "This Space demos an image similarity system. You can refer to [th
|
|
37 |
# Not sure what the best for this demo is.
|
38 |
gr.Interface(
|
39 |
query,
|
40 |
-
inputs=[gr.Image(type="pil")
|
41 |
-
outputs=gr.Gallery().style(grid=[
|
42 |
# Filenames denote the integer labels. Know here: https://hf.co/datasets/beans
|
43 |
title=title,
|
44 |
description=description,
|
|
|
8 |
seed = 42
|
9 |
|
10 |
# Only runs once when the script is first run.
|
11 |
+
with open("index_768.pickle", "rb") as handle:
|
12 |
index = pickle.load(handle)
|
13 |
|
14 |
# Load model for computing embeddings.
|
15 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained("sasha/autotrain-butterfly-similarity-2490576840")
|
16 |
+
model = AutoModel.from_pretrained("sasha/autotrain-butterfly-similarity-2490576840")
|
17 |
|
18 |
# Candidate images.
|
19 |
+
dataset = load_dataset("sasha/butterflies_10k_names_multiple")
|
20 |
ds = dataset["train"]
|
21 |
|
22 |
|
23 |
+
def query(image, top_k=4):
|
24 |
inputs = feature_extractor(image, return_tensors="pt")
|
25 |
model_output = model(**inputs)
|
26 |
embedding = model_output.pooler_output.detach()
|
27 |
+
results = index.query(embedding, k=top_k)
|
28 |
inx = results[0][0].tolist()
|
29 |
images = ds.select(inx)["image"]
|
30 |
return images
|
|
|
37 |
# Not sure what the best for this demo is.
|
38 |
gr.Interface(
|
39 |
query,
|
40 |
+
inputs=[gr.Image(type="pil")],
|
41 |
+
outputs=gr.Gallery().style(grid=[2], height="auto"),
|
42 |
# Filenames denote the integer labels. Know here: https://hf.co/datasets/beans
|
43 |
title=title,
|
44 |
description=description,
|
index_768.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eccd83bb743f6de5eaf05886f948d95daceeede60805ad01cdba0baddd1a60cc
|
3 |
+
size 53317256
|
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ transformers==4.25.1
|
|
2 |
datasets==2.7.1
|
3 |
numpy==1.21.6
|
4 |
torch==1.12.1
|
5 |
-
torchvision
|
|
|
|
2 |
datasets==2.7.1
|
3 |
numpy==1.21.6
|
4 |
torch==1.12.1
|
5 |
+
torchvision
|
6 |
+
pynndescent
|