Spaces:
Build error
Build error
import argparse | |
import jax | |
import jax.numpy as jnp | |
import json | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import requests | |
import os | |
from PIL import Image | |
from transformers import CLIPProcessor, FlaxCLIPModel | |
def encode_image(image_file, model, processor):
    """Return the CLIP image embedding of one image as a flat numpy array.

    Args:
        image_file: file name of the image, relative to IMAGES_DIR.
        model: model whose ``get_image_features`` computes the embedding
            (a FlaxCLIPModel in this script).
        processor: processor that turns the PIL image into model inputs
            (a CLIPProcessor in this script).

    Returns:
        The image features converted with ``np.array`` and flattened with
        ``reshape(-1)``, i.e. a 1-D array.
    """
    image_path = os.path.join(IMAGES_DIR, image_file)
    # Load with PIL directly and force 3-channel RGB. The former
    # plt.imread -> Image.fromarray round-trip returns float data in [0, 1]
    # for PNG files, which Image.fromarray does not turn into an 8-bit RGB
    # image. The context manager closes the file once the RGB copy is made.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="jax")
    image_vec = model.get_image_features(**inputs)
    return np.array(image_vec).reshape(-1)
# Root of the shared data tree; the dataset paths below all live under it.
DATA_DIR = "/home/shared/data"
# Image folder, caption JSON, and output folder for the vector TSV files,
# each resolved relative to DATA_DIR.
IMAGES_DIR, CAPTIONS_FILE, VECTORS_DIR = (
    os.path.join(DATA_DIR, leaf)
    for leaf in ("rsicd_images", "dataset_rsicd.json", "vectors")
)
# Hub id of the stock CLIP checkpoint. It is used when model_dir is
# "baseline" and is always used for the processor.
BASELINE_MODEL = "openai/clip-vit-base-patch32"
# Command line: exactly one positional argument, `model_dir`. It is either a
# checkpoint passed to FlaxCLIPModel.from_pretrained or the literal word
# "baseline", which selects BASELINE_MODEL instead.
parser = argparse.ArgumentParser()
parser.add_argument(
    "model_dir",
    help="Path to model to use for encoding",
)
args = parser.parse_args()
# Map each image in the "test" split to the list of its raw caption strings,
# keyed by filename. The dict keeps the order of the JSON file, and that
# order also determines the order of the rows in the vector file.
print("Loading image list...", end="")
# JSON text is UTF-8 by specification; be explicit rather than relying on
# the platform's default encoding.
with open(CAPTIONS_FILE, "r", encoding="utf-8") as fcap:
    data = json.load(fcap)
image2captions = {
    image["filename"]: [sentence["raw"] for sentence in image["sentences"]]
    for image in data["images"]
    if image["split"] == "test"
}
print("{:d} images".format(len(image2captions)))
print("Loading model...")
# "baseline" is a sentinel that selects the stock hub checkpoint. Any other
# value is handed to from_pretrained unchanged.
model = FlaxCLIPModel.from_pretrained(
    BASELINE_MODEL if args.model_dir == "baseline" else args.model_dir
)
# Preprocessing is always loaded from BASELINE_MODEL, even when a custom
# model_dir is being encoded.
processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
# Name the output file after the last two "/"-separated components of
# model_dir (e.g. "runs/ckpt-1000" -> "runs-ckpt-1000"; "baseline" stays
# "baseline"). Trailing slashes are stripped first so that "runs/ckpt-1000/"
# does not produce "ckpt-1000-" from an empty final component.
model_basename = "-".join(args.model_dir.rstrip("/").split("/")[-2:])
vector_file = os.path.join(VECTORS_DIR, "test-{:s}.tsv".format(model_basename))
print("Vectors written to {:s}".format(vector_file))
# Encode every test image and write one TSV line per image:
#   <filename>\t<comma-separated vector components, each formatted "{:.7e}">
# A progress line is printed before images 0, 100, 200, ...
# Create the output directory if needed, so a missing vectors/ folder does
# not make open() fail.
os.makedirs(os.path.dirname(vector_file) or ".", exist_ok=True)
num_written = 0
# The context manager guarantees the file is flushed and closed even if
# encoding an image raises part-way through.
with open(vector_file, "w", encoding="utf-8") as fvec:
    for image_file in image2captions:
        if num_written % 100 == 0:
            print("{:d} images processed".format(num_written))
        image_vec = encode_image(image_file, model, processor)
        image_vec_s = ",".join("{:.7e}".format(x) for x in image_vec)
        fvec.write("{:s}\t{:s}\n".format(image_file, image_vec_s))
        num_written += 1
print("{:d} images processed, COMPLETE".format(num_written))