Spaces:
Running
Running
Add gc.collect
Browse files- image2text.py +9 -3
- text2image.py +6 -3
image2text.py
CHANGED
@@ -6,6 +6,8 @@ from jax import numpy as jnp
|
|
6 |
import pandas as pd
|
7 |
import requests
|
8 |
import jax
|
|
|
|
|
9 |
|
10 |
def app():
|
11 |
st.title("From Image to Text")
|
@@ -27,8 +29,8 @@ def app():
|
|
27 |
|
28 |
image_url = st.text_input(
|
29 |
"You can input the URL of an image",
|
30 |
-
value="https://www.petdetective.it/wp-content/uploads/2016/04/gatto-toilette.jpg"
|
31 |
-
|
32 |
|
33 |
MAX_CAP = 4
|
34 |
|
@@ -67,7 +69,9 @@ def app():
|
|
67 |
image_embed = image_encoder(transform(image), model)
|
68 |
|
69 |
# we could have a softmax here
|
70 |
-
cos_similarities = jax.nn.softmax(
|
|
|
|
|
71 |
|
72 |
chart_data = pd.Series(cos_similarities[0], index=captions)
|
73 |
|
@@ -77,3 +81,5 @@ def app():
|
|
77 |
|
78 |
with col2:
|
79 |
st.image(image)
|
|
|
|
|
|
6 |
import pandas as pd
|
7 |
import requests
|
8 |
import jax
|
9 |
+
import gc
|
10 |
+
|
11 |
|
12 |
def app():
|
13 |
st.title("From Image to Text")
|
|
|
29 |
|
30 |
image_url = st.text_input(
|
31 |
"You can input the URL of an image",
|
32 |
+
value="https://www.petdetective.it/wp-content/uploads/2016/04/gatto-toilette.jpg",
|
33 |
+
)
|
34 |
|
35 |
MAX_CAP = 4
|
36 |
|
|
|
69 |
image_embed = image_encoder(transform(image), model)
|
70 |
|
71 |
# we could have a softmax here
|
72 |
+
cos_similarities = jax.nn.softmax(
|
73 |
+
jnp.matmul(image_embed, text_embeds.T)
|
74 |
+
)
|
75 |
|
76 |
chart_data = pd.Series(cos_similarities[0], index=captions)
|
77 |
|
|
|
81 |
|
82 |
with col2:
|
83 |
st.image(image)
|
84 |
+
|
85 |
+
gc.collect()
|
text2image.py
CHANGED
@@ -3,6 +3,7 @@ import os
|
|
3 |
import requests
|
4 |
import zipfile
|
5 |
import natsort
|
|
|
6 |
|
7 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
8 |
from stqdm import stqdm
|
@@ -144,13 +145,11 @@ def app():
|
|
144 |
with col2:
|
145 |
dataset_name = st.selectbox("IR dataset", ["Unsplash", "CC"])
|
146 |
|
147 |
-
query =
|
148 |
|
149 |
if query:
|
150 |
with st.spinner("Computing..."):
|
151 |
|
152 |
-
model = get_model()
|
153 |
-
|
154 |
if dataset_name == "Unsplash":
|
155 |
download_images()
|
156 |
|
@@ -173,3 +172,7 @@ def app():
|
|
173 |
)
|
174 |
|
175 |
st.image(image_paths)
|
|
|
|
|
|
|
|
|
|
3 |
import requests
|
4 |
import zipfile
|
5 |
import natsort
|
6 |
+
import gc
|
7 |
|
8 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
9 |
from stqdm import stqdm
|
|
|
145 |
with col2:
|
146 |
dataset_name = st.selectbox("IR dataset", ["Unsplash", "CC"])
|
147 |
|
148 |
+
query = suggestions[sugg_idx] if sugg_idx > -1 else query if query else ""
|
149 |
|
150 |
if query:
|
151 |
with st.spinner("Computing..."):
|
152 |
|
|
|
|
|
153 |
if dataset_name == "Unsplash":
|
154 |
download_images()
|
155 |
|
|
|
172 |
)
|
173 |
|
174 |
st.image(image_paths)
|
175 |
+
|
176 |
+
gc.collect()
|
177 |
+
|
178 |
+
sugg_idx = -1
|