File size: 532 Bytes
d137e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import numpy as np
from pyserini.search import QueryEncoder
from sentence_transformers import SentenceTransformer


class SentenceTransformerEncoder(QueryEncoder):
    """Query encoder that embeds text with a ``SentenceTransformer`` model.

    ``encode`` returns the model's embedding scaled to unit L2 norm.
    """

    def __init__(self, model_name: str, device: str = 'cpu'):
        """Load the sentence-transformers model onto the requested device.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
            device: Device specification handed to ``torch.device``
                (defaults to ``'cpu'``).
        """
        # NOTE(review): the base-class __init__ is not invoked here; confirm
        # that QueryEncoder needs no initialization for this usage.
        self.device = torch.device(device)
        self.model = SentenceTransformer(model_name, device=self.device)

    def encode(self, query: str):
        """Embed ``query`` and L2-normalize the result.

        Args:
            query: The query text to embed.

        Returns:
            The embedding produced by ``self.model.encode`` divided by its
            L2 norm. No batch axis is added here; presumably the result is a
            1-D vector for a single string query (confirm against the
            installed sentence-transformers version). If the embedding has
            zero norm it is returned unchanged rather than divided.
        """
        emb = self.model.encode(query)
        norm = np.linalg.norm(emb)
        if norm == 0:
            # Dividing a zero vector by its zero norm would yield NaNs (and a
            # RuntimeWarning); hand back the zero vector as-is instead.
            return emb
        return emb / norm