g8a9 commited on
Commit
6668c84
1 Parent(s): 01c6c72

Add gc.collect

Browse files
Files changed (2) hide show
  1. image2text.py +9 -3
  2. text2image.py +6 -3
image2text.py CHANGED
@@ -6,6 +6,8 @@ from jax import numpy as jnp
6
  import pandas as pd
7
  import requests
8
  import jax
 
 
9
 
10
  def app():
11
  st.title("From Image to Text")
@@ -27,8 +29,8 @@ def app():
27
 
28
  image_url = st.text_input(
29
  "You can input the URL of an image",
30
- value="https://www.petdetective.it/wp-content/uploads/2016/04/gatto-toilette.jpg")
31
-
32
 
33
  MAX_CAP = 4
34
 
@@ -67,7 +69,9 @@ def app():
67
  image_embed = image_encoder(transform(image), model)
68
 
69
  # we could have a softmax here
70
- cos_similarities = jax.nn.softmax(jnp.matmul(image_embed, text_embeds.T))
 
 
71
 
72
  chart_data = pd.Series(cos_similarities[0], index=captions)
73
 
@@ -77,3 +81,5 @@ def app():
77
 
78
  with col2:
79
  st.image(image)
 
 
 
6
  import pandas as pd
7
  import requests
8
  import jax
9
+ import gc
10
+
11
 
12
  def app():
13
  st.title("From Image to Text")
 
29
 
30
  image_url = st.text_input(
31
  "You can input the URL of an image",
32
+ value="https://www.petdetective.it/wp-content/uploads/2016/04/gatto-toilette.jpg",
33
+ )
34
 
35
  MAX_CAP = 4
36
 
 
69
  image_embed = image_encoder(transform(image), model)
70
 
71
  # we could have a softmax here
72
+ cos_similarities = jax.nn.softmax(
73
+ jnp.matmul(image_embed, text_embeds.T)
74
+ )
75
 
76
  chart_data = pd.Series(cos_similarities[0], index=captions)
77
 
 
81
 
82
  with col2:
83
  st.image(image)
84
+
85
+ gc.collect()
text2image.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import requests
4
  import zipfile
5
  import natsort
 
6
 
7
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
  from stqdm import stqdm
@@ -144,13 +145,11 @@ def app():
144
  with col2:
145
  dataset_name = st.selectbox("IR dataset", ["Unsplash", "CC"])
146
 
147
- query = query if query else suggestions[sugg_idx] if sugg_idx > -1 else ""
148
 
149
  if query:
150
  with st.spinner("Computing..."):
151
 
152
- model = get_model()
153
-
154
  if dataset_name == "Unsplash":
155
  download_images()
156
 
@@ -173,3 +172,7 @@ def app():
173
  )
174
 
175
  st.image(image_paths)
 
 
 
 
 
3
  import requests
4
  import zipfile
5
  import natsort
6
+ import gc
7
 
8
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
  from stqdm import stqdm
 
145
  with col2:
146
  dataset_name = st.selectbox("IR dataset", ["Unsplash", "CC"])
147
 
148
+ query = suggestions[sugg_idx] if sugg_idx > -1 else query if query else ""
149
 
150
  if query:
151
  with st.spinner("Computing..."):
152
 
 
 
153
  if dataset_name == "Unsplash":
154
  download_images()
155
 
 
172
  )
173
 
174
  st.image(image_paths)
175
+
176
+ gc.collect()
177
+
178
+ sugg_idx = -1