Spaces: Build error
Sujit Pal committed • Commit f9d31ee • Parent(s): 6c0a88f

fix: removed commented code
Files changed:
- dashboard_image2image.py +4 -26
- dashboard_text2image.py +3 -26
- utils.py +2 -1
dashboard_image2image.py
CHANGED

@@ -7,6 +7,7 @@ import streamlit as st
 from PIL import Image
 from transformers import CLIPProcessor, FlaxCLIPModel
 
+import utils
 
 BASELINE_MODEL = "openai/clip-vit-base-patch32"
 # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
@@ -20,30 +21,6 @@ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
 IMAGES_DIR = "./images"
 
 
-@st.cache(allow_output_mutation=True)
-def load_index():
-    filenames, image_vecs = [], []
-    fvec = open(IMAGE_VECTOR_FILE, "r")
-    for line in fvec:
-        cols = line.strip().split('\t')
-        filename = cols[0]
-        image_vec = np.array([float(x) for x in cols[1].split(',')])
-        filenames.append(filename)
-        image_vecs.append(image_vec)
-    V = np.array(image_vecs)
-    index = nmslib.init(method='hnsw', space='cosinesimil')
-    index.addDataPointBatch(V)
-    index.createIndex({'post': 2}, print_progress=True)
-    return filenames, index
-
-
-@st.cache(allow_output_mutation=True)
-def load_model():
-    model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
-    processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
-    return model, processor
-
-
 @st.cache(allow_output_mutation=True)
 def load_example_images():
     example_images = {}
@@ -60,8 +37,9 @@ def load_example_images():
 
 
 def app():
-    filenames, index = load_index()
-    model, processor = load_model()
+    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
+    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
+
     example_images = load_example_images()
     example_image_list = sorted([v[np.random.randint(0, len(v))]
                                  for k, v in example_images.items()][0:10])
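For orientation, here is a minimal sketch (not part of this commit) of how the refactored app() presumably uses the shared helpers for image-to-image retrieval. The MODEL_PATH value and the query image path are placeholders; the other constants and calls follow what is visible in the diff.

# Sketch only: image-to-image retrieval via the shared helpers in utils.py.
import numpy as np
from PIL import Image

import utils

BASELINE_MODEL = "openai/clip-vit-base-patch32"
MODEL_PATH = BASELINE_MODEL  # placeholder; the real fine-tuned checkpoint path is not shown in the diff
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"

filenames, index = utils.load_index(IMAGE_VECTOR_FILE)          # filenames + nmslib HNSW index
model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)  # FlaxCLIPModel + CLIPProcessor

query_image = Image.open("./images/example.jpg")                 # placeholder query image
inputs = processor(images=query_image, return_tensors="np")
query_vec = np.asarray(model.get_image_features(**inputs))[0]    # CLIP image embedding
ids, dists = index.knnQuery(query_vec, k=10)                     # nearest neighbours by cosine distance
results = [(filenames[i], d) for i, d in zip(ids, dists)]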
dashboard_text2image.py
CHANGED

@@ -6,6 +6,7 @@ import streamlit as st
 
 from transformers import CLIPProcessor, FlaxCLIPModel
 
+import utils
 
 BASELINE_MODEL = "openai/clip-vit-base-patch32"
 # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
@@ -19,33 +20,9 @@ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
 IMAGES_DIR = "./images"
 
 
-@st.cache(allow_output_mutation=True)
-def load_index():
-    filenames, image_vecs = [], []
-    fvec = open(IMAGE_VECTOR_FILE, "r")
-    for line in fvec:
-        cols = line.strip().split('\t')
-        filename = cols[0]
-        image_vec = np.array([float(x) for x in cols[1].split(',')])
-        filenames.append(filename)
-        image_vecs.append(image_vec)
-    V = np.array(image_vecs)
-    index = nmslib.init(method='hnsw', space='cosinesimil')
-    index.addDataPointBatch(V)
-    index.createIndex({'post': 2}, print_progress=True)
-    return filenames, index
-
-
-@st.cache(allow_output_mutation=True)
-def load_model():
-    model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
-    processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
-    return model, processor
-
-
 def app():
-    filenames, index = load_index()
-    model, processor = load_model()
+    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
+    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
 
     st.title("Text to Image Retrieval")
     st.markdown("""
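The text-to-image dashboard presumably follows the same pattern, with the query encoded by CLIP's text tower instead of the image tower. Again a sketch under the same placeholder assumptions, not code from the commit.

# Sketch only: text-to-image retrieval via the shared helpers in utils.py.
import numpy as np

import utils

BASELINE_MODEL = "openai/clip-vit-base-patch32"
MODEL_PATH = BASELINE_MODEL  # placeholder for the fine-tuned checkpoint path
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"

filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)

query = "two swimming pools next to a hotel"                     # placeholder query text
inputs = processor(text=[query], return_tensors="np", padding=True)
query_vec = np.asarray(model.get_text_features(**inputs))[0]     # CLIP text embedding
ids, dists = index.knnQuery(query_vec, k=10)                     # same nmslib lookup as above
results = [(filenames[i], d) for i, d in zip(ids, dists)]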
utils.py
CHANGED

@@ -28,5 +28,6 @@ def load_index(image_vector_file):
 @st.cache(allow_output_mutation=True)
 def load_model(model_path, baseline_model):
     model = FlaxCLIPModel.from_pretrained(model_path)
-    processor = CLIPProcessor.from_pretrained(baseline_model)
+    # processor = CLIPProcessor.from_pretrained(baseline_model)
+    processor = CLIPProcessor.from_pretrained(model_path)
     return model, processor
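For reference, pieced together from the helper code removed from the two dashboards and the hunk above, the shared utils.py helpers presumably now read roughly as follows. This is a reconstruction, not content shown in the diff.

# Reconstruction only: shared Streamlit-cached helpers assumed to live in utils.py.
import nmslib
import numpy as np
import streamlit as st
from transformers import CLIPProcessor, FlaxCLIPModel


@st.cache(allow_output_mutation=True)
def load_index(image_vector_file):
    # Read "filename<TAB>comma-separated-floats" rows and build an HNSW index over the image vectors.
    filenames, image_vecs = [], []
    with open(image_vector_file, "r") as fvec:
        for line in fvec:
            cols = line.strip().split("\t")
            filenames.append(cols[0])
            image_vecs.append(np.array([float(x) for x in cols[1].split(",")]))
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(np.array(image_vecs))
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index


@st.cache(allow_output_mutation=True)
def load_model(model_path, baseline_model):
    # Load the fine-tuned CLIP checkpoint; the processor now also comes from the checkpoint path.
    model = FlaxCLIPModel.from_pretrained(model_path)
    # processor = CLIPProcessor.from_pretrained(baseline_model)
    processor = CLIPProcessor.from_pretrained(model_path)
    return model, processor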