|
from transformers import Pipeline |
|
from sentence_transformers import SentenceTransformer |
|
import torch |
|
|
|
|
|
class SentimentModelPipe(Pipeline):
    """Emotion-classification pipeline on top of sentence-transformer embeddings.

    Each input is embedded with a ``SentenceTransformer`` model, the embedding
    is passed to ``self.model`` (the model handed to the ``Pipeline``
    constructor), and every value that model returns is labelled via
    ``class_map``.

    Extra keyword arguments read here (they are also forwarded unchanged to
    ``Pipeline.__init__``):
        embedding_model: name/path given to ``SentenceTransformer``
            (default ``"sentence-transformers/all-MiniLM-L6-v2"``).
        class_map: mapping from output index to label string (default: six
            labels, ``0: "sad"`` through ``5: "surprise"``).
    """

    def __init__(self, **kwargs):
        # super() rather than a hard-coded Pipeline.__init__ keeps the call
        # correct under cooperative multiple inheritance; with this class's
        # single base it resolves to Pipeline.__init__ exactly as before.
        super().__init__(**kwargs)

        # Sentence encoder that turns raw text into the classifier's input.
        self.smodel = SentenceTransformer(
            kwargs.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2")
        )

        # Output position -> human-readable label.
        self.class_map = kwargs.get(
            "class_map",
            {0: "sad", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"},
        )

    def _sanitize_parameters(self, **kw):
        """Return empty (preprocess, forward, postprocess) parameter dicts.

        This pipeline exposes no tunable stage parameters, so any extra
        keyword arguments routed here are ignored.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Encode *inputs* with the sentence-transformer into a torch tensor."""
        return self.smodel.encode(inputs, convert_to_tensor=True)

    def postprocess(self, outputs):
        """Turn model outputs into a list of ``{"label", "score"}`` dicts.

        Element ``i`` of *outputs* is labelled ``self.class_map[i]``; its score
        is the raw value the model produced (no softmax is applied here).

        NOTE(review): assumes the model emits one value per class, i.e. shape
        ``(num_classes,)``. A leading batch dimension of size 1 is dropped so
        ``(1, num_classes)`` outputs work too — confirm against the model.

        Raises:
            KeyError: if ``class_map`` has no label for some output index.
        """
        # Without this, a single-example batch dim makes enumerate() yield a
        # whole row, and ``.item()`` raises on a multi-element tensor.
        if (
            isinstance(outputs, torch.Tensor)
            and outputs.dim() > 1
            and outputs.shape[0] == 1
        ):
            outputs = outputs[0]

        return [
            {"label": self.class_map[idx], "score": value.item()}
            for idx, value in enumerate(outputs)
        ]

    def _forward(self, tensor):
        """Run ``self.model`` on the embedding *tensor* and return its output.

        Gradients are disabled so the call stays inference-only even when
        ``_forward`` is invoked directly rather than through the pipeline.
        """
        with torch.no_grad():
            return self.model(tensor)
|
|