import re import os from transformers import (BartTokenizerFast, TFAutoModelForSeq2SeqLM) import tensorflow as tf from scraper import scrape_text from fastapi import FastAPI, Response, Request from typing import List from pydantic import BaseModel, Field from fastapi.exceptions import RequestValidationError import uvicorn import json import logging import multiprocessing os.environ['TF_USE_LEGACY_KERAS'] = "1" SUMM_CHECKPOINT = "facebook/bart-base" SUMM_INPUT_N_TOKENS = 400 SUMM_TARGET_N_TOKENS = 300 def load_summarizer_models(): summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT) summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT) summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True) logging.warning('Loaded summarizer models') return summ_tokenizer, summ_model def summ_preprocess(txt): txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt) # By . Ellie Zolfagharifard . txt = re.sub(r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' ', txt) # 10:30 EST txt = re.sub(r'\d{1,2} [a-zA-Z]+ \d{4}', ' ', txt) # 10 November 1990 txt = txt.replace('PUBLISHED:', ' ') txt = txt.replace('UPDATED', ' ') txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after txt = txt.replace(' : ', ' ') txt = txt.replace('(CNN)', ' ') txt = txt.replace('--', ' ') txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt) # remove puncts at beginning of sent txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after txt = re.sub(r'\n+',' ', txt) txt = " ".join(txt.split()) return txt async def summ_inference_tokenize(input_: list, n_tokens: int): tokenized_data = summ_tokenizer(text=input_, max_length=SUMM_TARGET_N_TOKENS, truncation=True, padding="max_length", return_tensors="tf") return summ_tokenizer, tokenized_data async def summ_inference(txts: str): logging.warning("Entering summ_inference()") txts = [*map(summ_preprocess, txts)] inference_tokenizer, tokenized_data = await summ_inference_tokenize(input_=txts, n_tokens=SUMM_INPUT_N_TOKENS) pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS) result = ["" if t=="" else inference_tokenizer.decode(p, skip_special_tokens=True).strip() for t, p in zip(txts, pred)] return result async def scrape_urls(urls): logging.warning('Entering scrape_urls()') pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()) results = [] for url in urls: f = pool.apply_async(scrape_text, [url]) # asynchronously applying function to chunk. Each worker parallely begins to work on the job results.append(f) # appending result to results scraped_texts = [] scrape_errors = [] for f in results: t, e = f.get(timeout=120) scraped_texts.append(t) scrape_errors.append(e) pool.close() pool.join() logging.warning('Exiting scrape_urls()') return scraped_texts, scrape_errors description = "API to generate summaries of news articles from their URLs." app = FastAPI(title='News Summarizer API', description=description, version="0.0.1", contact={ "name": "Author: KSV Muralidhar", "url": "https://ksvmuralidhar.in" }, license_info={ "name": "License: MIT", "identifier": "MIT" }, swagger_ui_parameters={"defaultModelsExpandDepth": -1}) summ_tokenizer, summ_model = load_summarizer_models() class URLList(BaseModel): urls: List[str] = Field(..., description="List of URLs of news articles to generate summaries") key: str = Field(..., description="Authentication Key") class SuccessfulResponse(BaseModel): urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user") scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs") scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL") summaries: List[str] = Field(..., description="List of generated summaries of news articles") summarizer_error: str = Field("", description="Empty string as the response code is 200") class AuthenticationError(BaseModel): urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user") scraped_texts: str = Field("", description="Empty string as authentication failed") scrape_errors: str = Field("", description="Empty string as authentication failed") summaries: str = Field("", description="Empty string as authentication failed") summarizer_error: str = Field("Error: Authentication error: Invalid API key.") class SummaryError(BaseModel): urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user") scraped_texts: List[str] = Field(..., description="List of scraped text from input URLs") scrape_errors: List[str] = Field(..., description="List of errors raised during scraping. One item for corresponding URL") summaries: str = Field("", description="Empty string as summarizer encountered an error") summarizer_error: str = Field("Error: Summarizer Error with a message describing the error") class InputValidationError(BaseModel): urls: List[str] = Field(..., description="List of URLs of news articles inputted by the user") scraped_texts: str = Field("", description="Empty string as validation failed") scrape_errors: str = Field("", description="Empty string as validation failed") summaries: str = Field("", description="Empty string as validation failed") summarizer_error: str = Field("Validation Error with a message describing the error") class NewsSummarizerAPIAuthenticationError(Exception): pass class NewsSummarizerAPIScrapingError(Exception): pass def authenticate_key(api_key: str): if api_key != os.getenv('API_KEY'): raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.") @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): urls = request.query_params.getlist("urls") error_details = exc.errors() error_messages = [] for error in error_details: loc = [*map(str, error['loc'])][-1] msg = error['msg'] error_messages.append(f"{loc}: {msg}") error_message = "; ".join(error_messages) if error_messages else "" response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '', 'summaries': "", 'summarizer_error': f'Validation Error: {error_message}'} json_str = json.dumps(response_json, indent=5) # convert dict to JSON str return Response(content=json_str, media_type='application/json', status_code=422) @app.post("/generate_summary/", tags=["Generate Summary"], response_model=List[SuccessfulResponse], responses={ 401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"}, 500: {"model": SummaryError, "description": "Summarizer Error: Returned when the API couldn't generate the summary of even a single article"}, 422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't match the data type requirements"} }) async def generate_summary(q: URLList): """ Get summaries of news articles by passing the list of URLs as input. - **urls**: List of URLs (required) - **key**: Authentication key (required) """ try: logging.warning("Entering generate_summary()") urls = "" scraped_texts = "" scrape_errors = "" summaries = "" request_json = q.json() request_json = json.loads(request_json) urls = request_json['urls'] api_key = request_json['key'] _ = authenticate_key(api_key) scraped_texts, scrape_errors = await scrape_urls(urls) unique_scraped_texts = [*set(scraped_texts)] if (unique_scraped_texts[0] == "") and (len(unique_scraped_texts) == 1): raise NewsSummarizerAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs") summaries = await summ_inference(scraped_texts) status_code = 200 response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''} except Exception as e: status_code = 500 if e.__class__.__name__ == "NewsSummarizerAPIAuthenticationError": status_code = 401 response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'Error: {e}'} json_str = json.dumps(response_json, indent=5) # convert dict to JSON str return Response(content=json_str, media_type='application/json', status_code=status_code) if __name__ == '__main__': uvicorn.run(app=app, host='0.0.0.0', port=7860, workers=3)