4rtemi5 committed on
Commit
90de990
1 Parent(s): 753d26f

move back to streamlit 1.2.0

Browse files
Files changed (2) hide show
  1. image2text.py +8 -14
  2. utils.py +2 -3
image2text.py CHANGED
@@ -3,6 +3,7 @@ from text2image import get_model, get_tokenizer, get_image_transform
3
  from utils import text_encoder, image_encoder
4
  from PIL import Image
5
  from jax import numpy as jnp
 
6
  import pandas as pd
7
  import requests
8
  import jax
@@ -29,7 +30,7 @@ def app():
29
 
30
  image_url = st.text_input(
31
  "You can input the URL of an image",
32
- value="https://upload.wikimedia.org/wikipedia/commons/b/bc/Juvenile_Ragdoll.jpg",
33
  )
34
 
35
  MAX_CAP = 4
@@ -59,17 +60,13 @@ def app():
59
 
60
  text_embeds = list()
61
  for i, c in enumerate(captions):
62
- text_embeds.extend(text_encoder(c, model, tokenizer))
63
 
64
  text_embeds = jnp.array(text_embeds)
65
- image_raw = requests.get(
66
- image_url,
67
- stream=True,
68
- ).raw
69
-
70
- image = Image.open(image_raw).convert("RGB")
71
  transform = get_image_transform(model.config.vision_config.image_size)
72
- image_embed = image_encoder(transform(image), model)
73
 
74
  # we could have a softmax here
75
  cos_similarities = jax.nn.softmax(
@@ -87,9 +84,6 @@ def app():
87
  gc.collect()
88
 
89
  elif image_url:
90
- image_raw = requests.get(
91
- image_url,
92
- stream=True,
93
- ).raw
94
- image = Image.open(image_raw).convert("RGB")
95
  st.image(image)
 
3
  from utils import text_encoder, image_encoder
4
  from PIL import Image
5
  from jax import numpy as jnp
6
+ from io import BytesIO
7
  import pandas as pd
8
  import requests
9
  import jax
 
30
 
31
  image_url = st.text_input(
32
  "You can input the URL of an image",
33
+ value="https://upload.wikimedia.org/wikipedia/commons/thumb/8/88/Ragdoll%2C_blue_mitted.JPG/1280px-Ragdoll%2C_blue_mitted.JPG",
34
  )
35
 
36
  MAX_CAP = 4
 
60
 
61
  text_embeds = list()
62
  for i, c in enumerate(captions):
63
+ text_embeds.extend(text_encoder(c, model, tokenizer)[0])
64
 
65
  text_embeds = jnp.array(text_embeds)
66
+ response = requests.get(image_url)
67
+ image = Image.open(BytesIO(response.content)).convert("RGB")
 
 
 
 
68
  transform = get_image_transform(model.config.vision_config.image_size)
69
+ image_embed, _ = image_encoder(transform(image), model)
70
 
71
  # we could have a softmax here
72
  cos_similarities = jax.nn.softmax(
 
84
  gc.collect()
85
 
86
  elif image_url:
87
+ response = requests.get(image_url)
88
+ image = Image.open(BytesIO(response.content)).convert("RGB")
 
 
 
89
  st.image(image)
utils.py CHANGED
@@ -48,7 +48,7 @@ def image_encoder(image, model):
48
  features = model.get_image_features(image,)
49
  norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
50
  features = features / norms
51
- return features
52
 
53
 
54
  def precompute_image_features(model, loader):
@@ -62,8 +62,7 @@ def precompute_image_features(model, loader):
62
 
63
 
64
  def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
65
- zeroshot_weights = text_encoder(text_query, model, tokenizer)
66
- zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
67
  distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
68
  file_paths = []
69
  for i in range(1, n + 1):
 
48
  features = model.get_image_features(image,)
49
  norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
50
  features = features / norms
51
+ return features, norms
52
 
53
 
54
  def precompute_image_features(model, loader):
 
62
 
63
 
64
  def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
65
+ zeroshot_weights, _ = text_encoder(text_query, model, tokenizer)
 
66
  distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
67
  file_paths = []
68
  for i in range(1, n + 1):