Spaces:
Running
Running
move back to streamlit 1.2.0
Browse files
- image2text.py +8 -14
- utils.py +2 -3
image2text.py
CHANGED
@@ -3,6 +3,7 @@ from text2image import get_model, get_tokenizer, get_image_transform
|
|
3 |
from utils import text_encoder, image_encoder
|
4 |
from PIL import Image
|
5 |
from jax import numpy as jnp
|
|
|
6 |
import pandas as pd
|
7 |
import requests
|
8 |
import jax
|
@@ -29,7 +30,7 @@ def app():
|
|
29 |
|
30 |
image_url = st.text_input(
|
31 |
"You can input the URL of an image",
|
32 |
-
value="https://upload.wikimedia.org/wikipedia/commons/
|
33 |
)
|
34 |
|
35 |
MAX_CAP = 4
|
@@ -59,17 +60,13 @@ def app():
|
|
59 |
|
60 |
text_embeds = list()
|
61 |
for i, c in enumerate(captions):
|
62 |
-
text_embeds.extend(text_encoder(c, model, tokenizer))
|
63 |
|
64 |
text_embeds = jnp.array(text_embeds)
|
65 |
-
|
66 |
-
|
67 |
-
stream=True,
|
68 |
-
).raw
|
69 |
-
|
70 |
-
image = Image.open(image_raw).convert("RGB")
|
71 |
transform = get_image_transform(model.config.vision_config.image_size)
|
72 |
-
image_embed = image_encoder(transform(image), model)
|
73 |
|
74 |
# we could have a softmax here
|
75 |
cos_similarities = jax.nn.softmax(
|
@@ -87,9 +84,6 @@ def app():
|
|
87 |
gc.collect()
|
88 |
|
89 |
elif image_url:
|
90 |
-
|
91 |
-
|
92 |
-
stream=True,
|
93 |
-
).raw
|
94 |
-
image = Image.open(image_raw).convert("RGB")
|
95 |
st.image(image)
|
|
|
3 |
from utils import text_encoder, image_encoder
|
4 |
from PIL import Image
|
5 |
from jax import numpy as jnp
|
6 |
+
from io import BytesIO
|
7 |
import pandas as pd
|
8 |
import requests
|
9 |
import jax
|
|
|
30 |
|
31 |
image_url = st.text_input(
|
32 |
"You can input the URL of an image",
|
33 |
+
value="https://upload.wikimedia.org/wikipedia/commons/thumb/8/88/Ragdoll%2C_blue_mitted.JPG/1280px-Ragdoll%2C_blue_mitted.JPG",
|
34 |
)
|
35 |
|
36 |
MAX_CAP = 4
|
|
|
60 |
|
61 |
text_embeds = list()
|
62 |
for i, c in enumerate(captions):
|
63 |
+
text_embeds.extend(text_encoder(c, model, tokenizer)[0])
|
64 |
|
65 |
text_embeds = jnp.array(text_embeds)
|
66 |
+
response = requests.get(image_url)
|
67 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
|
|
|
|
|
|
|
|
68 |
transform = get_image_transform(model.config.vision_config.image_size)
|
69 |
+
image_embed, _ = image_encoder(transform(image), model)
|
70 |
|
71 |
# we could have a softmax here
|
72 |
cos_similarities = jax.nn.softmax(
|
|
|
84 |
gc.collect()
|
85 |
|
86 |
elif image_url:
|
87 |
+
response = requests.get(image_url)
|
88 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
|
|
|
|
|
|
89 |
st.image(image)
|
utils.py
CHANGED
@@ -48,7 +48,7 @@ def image_encoder(image, model):
|
|
48 |
features = model.get_image_features(image,)
|
49 |
norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
|
50 |
features = features / norms
|
51 |
-
return features
|
52 |
|
53 |
|
54 |
def precompute_image_features(model, loader):
|
@@ -62,8 +62,7 @@ def precompute_image_features(model, loader):
|
|
62 |
|
63 |
|
64 |
def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
|
65 |
-
zeroshot_weights = text_encoder(text_query, model, tokenizer)
|
66 |
-
zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
|
67 |
distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
|
68 |
file_paths = []
|
69 |
for i in range(1, n + 1):
|
|
|
48 |
features = model.get_image_features(image,)
|
49 |
norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
|
50 |
features = features / norms
|
51 |
+
return features, norms
|
52 |
|
53 |
|
54 |
def precompute_image_features(model, loader):
|
|
|
62 |
|
63 |
|
64 |
def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
|
65 |
+
zeroshot_weights, _ = text_encoder(text_query, model, tokenizer)
|
|
|
66 |
distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
|
67 |
file_paths = []
|
68 |
for i in range(1, n + 1):
|