Sujit Pal commited on
Commit
f9d31ee
1 Parent(s): 6c0a88f

fix: removed commented code

Browse files
Files changed (3) hide show
  1. dashboard_image2image.py +4 -26
  2. dashboard_text2image.py +3 -26
  3. utils.py +2 -1
dashboard_image2image.py CHANGED
@@ -7,6 +7,7 @@ import streamlit as st
7
  from PIL import Image
8
  from transformers import CLIPProcessor, FlaxCLIPModel
9
 
 
10
 
11
  BASELINE_MODEL = "openai/clip-vit-base-patch32"
12
  # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
@@ -20,30 +21,6 @@ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
20
  IMAGES_DIR = "./images"
21
 
22
 
23
- @st.cache(allow_output_mutation=True)
24
- def load_index():
25
- filenames, image_vecs = [], []
26
- fvec = open(IMAGE_VECTOR_FILE, "r")
27
- for line in fvec:
28
- cols = line.strip().split('\t')
29
- filename = cols[0]
30
- image_vec = np.array([float(x) for x in cols[1].split(',')])
31
- filenames.append(filename)
32
- image_vecs.append(image_vec)
33
- V = np.array(image_vecs)
34
- index = nmslib.init(method='hnsw', space='cosinesimil')
35
- index.addDataPointBatch(V)
36
- index.createIndex({'post': 2}, print_progress=True)
37
- return filenames, index
38
-
39
-
40
- @st.cache(allow_output_mutation=True)
41
- def load_model():
42
- model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
43
- processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
44
- return model, processor
45
-
46
-
47
  @st.cache(allow_output_mutation=True)
48
  def load_example_images():
49
  example_images = {}
@@ -60,8 +37,9 @@ def load_example_images():
60
 
61
 
62
  def app():
63
- filenames, index = load_index()
64
- model, processor = load_model()
 
65
  example_images = load_example_images()
66
  example_image_list = sorted([v[np.random.randint(0, len(v))]
67
  for k, v in example_images.items()][0:10])
 
7
  from PIL import Image
8
  from transformers import CLIPProcessor, FlaxCLIPModel
9
 
10
+ import utils
11
 
12
  BASELINE_MODEL = "openai/clip-vit-base-patch32"
13
  # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
 
21
  IMAGES_DIR = "./images"
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  @st.cache(allow_output_mutation=True)
25
  def load_example_images():
26
  example_images = {}
 
37
 
38
 
39
  def app():
40
+ filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
41
+ model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
42
+
43
  example_images = load_example_images()
44
  example_image_list = sorted([v[np.random.randint(0, len(v))]
45
  for k, v in example_images.items()][0:10])
dashboard_text2image.py CHANGED
@@ -6,6 +6,7 @@ import streamlit as st
6
 
7
  from transformers import CLIPProcessor, FlaxCLIPModel
8
 
 
9
 
10
  BASELINE_MODEL = "openai/clip-vit-base-patch32"
11
  # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
@@ -19,33 +20,9 @@ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
19
  IMAGES_DIR = "./images"
20
 
21
 
22
- @st.cache(allow_output_mutation=True)
23
- def load_index():
24
- filenames, image_vecs = [], []
25
- fvec = open(IMAGE_VECTOR_FILE, "r")
26
- for line in fvec:
27
- cols = line.strip().split('\t')
28
- filename = cols[0]
29
- image_vec = np.array([float(x) for x in cols[1].split(',')])
30
- filenames.append(filename)
31
- image_vecs.append(image_vec)
32
- V = np.array(image_vecs)
33
- index = nmslib.init(method='hnsw', space='cosinesimil')
34
- index.addDataPointBatch(V)
35
- index.createIndex({'post': 2}, print_progress=True)
36
- return filenames, index
37
-
38
-
39
- @st.cache(allow_output_mutation=True)
40
- def load_model():
41
- model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
42
- processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
43
- return model, processor
44
-
45
-
46
  def app():
47
- filenames, index = load_index()
48
- model, processor = load_model()
49
 
50
  st.title("Text to Image Retrieval")
51
  st.markdown("""
 
6
 
7
  from transformers import CLIPProcessor, FlaxCLIPModel
8
 
9
+ import utils
10
 
11
  BASELINE_MODEL = "openai/clip-vit-base-patch32"
12
  # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
 
20
  IMAGES_DIR = "./images"
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def app():
24
+ filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
25
+ model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
26
 
27
  st.title("Text to Image Retrieval")
28
  st.markdown("""
utils.py CHANGED
@@ -28,5 +28,6 @@ def load_index(image_vector_file):
28
  @st.cache(allow_output_mutation=True)
29
  def load_model(model_path, baseline_model):
30
  model = FlaxCLIPModel.from_pretrained(model_path)
31
- processor = CLIPProcessor.from_pretrained(baseline_model)
 
32
  return model, processor
 
28
  @st.cache(allow_output_mutation=True)
29
  def load_model(model_path, baseline_model):
30
  model = FlaxCLIPModel.from_pretrained(model_path)
31
+ # processor = CLIPProcessor.from_pretrained(baseline_model)
32
+ processor = CLIPProcessor.from_pretrained(model_path)
33
  return model, processor