Spaces:
No application file
No application file
File size: 532 Bytes
d137e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
import numpy as np
from pyserini.search import QueryEncoder
from sentence_transformers import SentenceTransformer
class SentenceTransformerEncoder(QueryEncoder):
    """Pyserini query encoder backed by a sentence-transformers model.

    Each query is embedded with ``SentenceTransformer.encode`` and then scaled
    to unit L2 length, so inner products between encoded queries and other
    unit-length vectors equal their cosine similarity.

    Args:
        model_name: Model identifier or local path passed to
            ``SentenceTransformer``.
        device: Torch device spec (e.g. ``'cpu'``, ``'cuda'``) the model is
            loaded onto.
    """

    def __init__(self, model_name: str, device: str = 'cpu'):
        # Initialise the base-class state before adding our own attributes.
        super().__init__()
        # NOTE(review): pyserini's own model-backed encoders flag themselves
        # this way after calling the base __init__; mirrored here so the
        # object's state is consistent — confirm against the installed
        # pyserini version.
        self.has_model = True
        self.device = torch.device(device)
        self.model = SentenceTransformer(model_name, device=self.device)

    def encode(self, query: str) -> np.ndarray:
        """Embed a single query and return it L2-normalised.

        Args:
            query: Query text.

        Returns:
            The model's embedding for ``query`` divided by its L2 norm. An
            embedding whose norm is zero is returned unchanged, because
            dividing it would produce NaNs.
        """
        emb = self.model.encode(query)
        norm = np.linalg.norm(emb)
        if norm > 0:
            emb = emb / norm
        # Returned without a leading batch axis. NOTE(review): presumably the
        # calling searcher reshapes it to (1, dim) itself — confirm.
        return emb
|