"""FastAPI service that paraphrases text sentence-by-sentence with a T5 model."""
import logging

import nltk
import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from nltk.tokenize import sent_tokenize
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Tokenizer data needed by ``sent_tokenize``; fetched once at import time.
nltk.download('punkt_tab')

logger = logging.getLogger(__name__)

app = FastAPI()


class Sentence(BaseModel):
    """Request body for ``POST /rephrase``."""

    # Free text; it is split into sentences and each one is paraphrased separately.
    sentence: str


model_name = 'Vamsi/T5_Paraphrase_Paws'
# Alternative checkpoint kept for experimentation:
# model_name = 'google/flan-t5-large'
tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("--" * 10, device, "--" * 10)
model.to(device)


def my_rephrase(sentence):
    """Return one generated paraphrase of *sentence*.

    The text is prefixed with ``"paraphrase: "`` before encoding. Decoding uses
    beam search combined with sampling (``num_beams=2, do_sample=True``), so
    repeated calls with the same input may return different outputs.

    Args:
        sentence: A single sentence of text.

    Returns:
        The decoded paraphrase with special tokens removed.
    """
    prompt = "paraphrase: " + sentence
    # The input tensors must live on the same device as the model; otherwise
    # ``generate`` raises a device-mismatch error whenever the model is on CUDA.
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=256,
        num_beams=2,
        do_sample=True,
        top_k=120,
        top_p=0.95,
        early_stopping=True,
        num_return_sequences=1,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.get("/")
def index():
    """Return the literal ``"OK"`` (presumably used as a health check)."""
    return "OK"


@app.post("/rephrase")
def rephrase(sentence: Sentence):
    """Paraphrase every sentence in the request and join the results with spaces.

    Returns:
        200 with ``{"rephrased_sentence": <text>}`` on success, or
        422 with the body ``"Unable to rephrase"`` if any step raises.
    """
    try:
        rephrase_text = " ".join(
            my_rephrase(sent) for sent in sent_tokenize(sentence.sentence)
        )
    except Exception:
        # Top-level request boundary: log the full traceback rather than just
        # the message, then report failure to the client.
        # NOTE(review): 422 is kept for backward compatibility; a 5xx status
        # would describe a server-side failure more accurately.
        logger.exception("Unable to rephrase input")
        return JSONResponse(status_code=422, content="Unable to rephrase")
    return JSONResponse(status_code=200, content={"rephrased_sentence": rephrase_text})