Spaces:
Runtime error
Runtime error
from dataclasses import dataclass | |
from typing import Any | |
import japanese_clip as ja_clip | |
from s3_session import Bucket | |
from PIL import Image | |
import uuid | |
from db_session import get_db | |
# eq=False keeps identity equality/hashability; comparing models field-by-field
# is meaningless. The decorator is what makes __post_init__ actually run.
@dataclass(eq=False)
class MLModel:
    """Index images and search them by Japanese text with rinna's Japanese CLIP.

    ``save`` embeds an image, uploads the image to ``bucket`` under a fresh
    UUID and stores ``{"uuid", "vectorField"}`` in the ``embedding``
    collection of ``get_db()``. ``search`` embeds a text prompt, runs a
    ``$vectorSearch`` aggregation over that collection and returns presigned
    URLs for the matching images.

    Collaborators passed to the constructor are used as-is; any left as
    ``None`` are created in ``__post_init__``.
    """

    # Tokenizer from ja_clip.load_tokenizer(); used by search().
    tokenizer: Any = None
    # CLIP model from ja_clip.load(); provides get_image_features/get_text_features.
    model: Any = None
    # Image transform from ja_clip.load(); applied to a PIL image in save().
    preprocess: Any = None
    # Media store; must provide upload_file(image, key) and get_presigned_url(key).
    bucket: Any = None

    # Unannotated class attributes are NOT dataclass fields, so they do not
    # alter the generated __init__ signature.
    _MODEL_NAME = "rinna/japanese-clip-vit-b-16"
    _CACHE_DIR = "/tmp/japanese_clip"
    _DEVICE = "cpu"
    _COLLECTION = "embedding"

    def __post_init__(self):
        """Load every collaborator the caller did not inject."""
        if self.tokenizer is None:
            self.tokenizer = ja_clip.load_tokenizer()
        if self.model is None or self.preprocess is None:
            model, preprocess = ja_clip.load(
                self._MODEL_NAME, cache_dir=self._CACHE_DIR, device=self._DEVICE
            )
            # ja_clip.load returns both; only fill the slots left empty.
            if self.model is None:
                self.model = model
            if self.preprocess is None:
                self.preprocess = preprocess
        if self.bucket is None:
            self.bucket = Bucket()

    def save(self, image_path: str):
        """Embed the image at *image_path*, upload it and store its embedding.

        Args:
            image_path: Path of the image file, opened with ``PIL.Image.open``.

        Returns:
            The ``inserted_id`` of the document inserted into the embedding
            collection. The document's ``uuid`` is also the bucket key of the
            uploaded image.
        """
        import torch  # local import: only needed for inference

        image_uuid = str(uuid.uuid4())
        # `with` closes the file handle PIL keeps open for lazy loading.
        with Image.open(image_path) as pillow_image:
            image = self.preprocess(pillow_image).unsqueeze(0).to(self._DEVICE)
            # Pure inference: skip building an autograd graph.
            with torch.no_grad():
                image_features = self.model.get_image_features(image)
            # media upload (done while the image is still open)
            self.bucket.upload_file(pillow_image, image_uuid)

        # db insert
        db = get_db()
        result = db[self._COLLECTION].insert_one(
            {"uuid": image_uuid, "vectorField": image_features[0].tolist()}
        )
        return result.inserted_id

    def search(self, prompt: str, *, limit: int = 10, num_candidates: int = 150):
        """Return presigned URLs of the images most similar to *prompt*.

        Args:
            prompt: Query text (Japanese CLIP text encoder).
            limit: Maximum number of results to return.
            num_candidates: Candidate pool size for ``$vectorSearch``; must be
                at least ``limit``.

        Returns:
            A list with one ``bucket.get_presigned_url(uuid)`` result per
            matching document, in the order the aggregation yields them.

        Raises:
            ValueError: If ``limit < 1`` or ``num_candidates < limit``.
        """
        if limit < 1 or num_candidates < limit:
            raise ValueError(
                f"need num_candidates >= limit >= 1, got limit={limit}, "
                f"num_candidates={num_candidates}"
            )
        import torch  # local import: only needed for inference

        encodings = ja_clip.tokenize(
            texts=[prompt], max_seq_len=77, device=self._DEVICE, tokenizer=self.tokenizer
        )
        with torch.no_grad():
            text_features = self.model.get_text_features(**encodings)

        # NOTE(review): assumes a vector search index named "vector_index"
        # exists on "vectorField" of this collection — confirm in the DB setup.
        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_index",
                    "path": "vectorField",
                    "queryVector": text_features[0].tolist(),
                    "numCandidates": num_candidates,
                    "limit": limit,
                }
            },
            {
                "$project": {
                    "_id": {"$toString": "$_id"},
                    "uuid": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]
        db = get_db()
        result = db[self._COLLECTION].aggregate(pipeline)
        return [self.bucket.get_presigned_url(doc["uuid"]) for doc in result]