Spaces:
Sleeping
Sleeping
import re | |
import os | |
from transformers import (BartTokenizerFast, | |
TFAutoModelForSeq2SeqLM) | |
import tensorflow as tf | |
from scraper import scrape_text | |
from fastapi import FastAPI, Response | |
from typing import List | |
from pydantic import BaseModel | |
import uvicorn | |
import json | |
import logging | |
import multiprocessing | |
# Ask TensorFlow for the legacy (Keras 2) implementation.
# NOTE(review): TensorFlow/transformers appear to consult this variable when Keras
# is first loaded, and here it is assigned *after* `import tensorflow` and the
# transformers TF imports — confirm it takes effect, or move it above those imports.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

# Hugging Face checkpoint supplying both the tokenizer vocabulary and the model architecture.
SUMM_CHECKPOINT = "facebook/bart-base"
# Passed as `n_tokens` to summ_inference_tokenize by summ_inference (intended encoder input length).
SUMM_INPUT_N_TOKENS = 400
# Passed as `max_new_tokens` to `generate` by summ_inference.
SUMM_TARGET_N_TOKENS = 300
def load_summarizer_models():
    """Build the BART tokenizer and TF seq2seq model, then load fine-tuned weights.

    Vocabulary and architecture come from ``SUMM_CHECKPOINT``; the weights stored in
    ``models/bart_en_summarizer.h5`` (a path relative to the current working
    directory) are loaded on top of the freshly built model.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    weights_file = os.path.join("models", "bart_en_summarizer.h5")
    tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    # by_name=True matches stored weights to layers by layer name.
    model.load_weights(weights_file, by_name=True)
    logging.warning('Loaded summarizer models')
    return tokenizer, model
def summ_preprocess(txt):
    """Strip bylines, timestamps, dates and stray punctuation from article text.

    A fixed, ordered list of substitutions is applied (order matters: each rule
    sees the output of the previous one), then every whitespace run is collapsed
    to a single space and the ends are trimmed.

    Args:
        txt: Raw article text.

    Returns:
        str: The cleaned text.
    """
    # (kind, old, new): "re" rules go through re.sub, "str" rules through str.replace.
    # Note `^` is unanchored to lines (no re.MULTILINE), so it only matches at the
    # very start of the whole text.
    steps = (
        ("re", r'^By \. [\w\s]+ \. ', ' '),          # By . Ellie Zolfagharifard .
        ("re", r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' '),   # 10:30 EST
        ("re", r'\d{1,2} [a-zA-Z]+ \d{4}', ' '),     # 10 November 1990
        ("str", 'PUBLISHED:', ' '),
        ("str", 'UPDATED', ' '),
        ("re", r' [\,\.\:\'\;\|] ', ' '),            # lone punctuation between spaces
        ("str", ' : ', ' '),
        ("str", '(CNN)', ' '),
        ("str", '--', ' '),
        ("re", r'^\s*[\,\.\:\'\;\|]', ' '),          # punctuation at the start of the text
        ("re", r' [\,\.\:\'\;\|] ', ' '),            # second pass after the removals above
        ("re", r'\n+', ' '),
    )
    for kind, old, new in steps:
        if kind == "re":
            txt = re.sub(old, new, txt)
        else:
            txt = txt.replace(old, new)
    return " ".join(txt.split())
def summ_inference_tokenize(input_: list, n_tokens: int):
    """Tokenize a batch of texts for summarizer inference.

    Every sequence is truncated or padded to exactly ``n_tokens`` tokens.

    Args:
        input_: List of (preprocessed) article strings.
        n_tokens: Fixed sequence length for the encoded batch.

    Returns:
        tuple: ``(tokenizer, encoded_batch)``, where ``encoded_batch`` holds
        TensorFlow tensors (``return_tensors="tf"``).
    """
    tokenized_data = summ_tokenizer(
        text=input_,
        # Use the requested input length; previously this was hard-wired to the
        # generation length (SUMM_TARGET_N_TOKENS) and n_tokens was ignored.
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return summ_tokenizer, tokenized_data
def summ_inference(txts: List[str]):
    """Summarize a batch of article texts.

    Each text is cleaned with ``summ_preprocess``, the batch is tokenized, and
    all summaries are produced by a single ``generate`` call.

    Args:
        txts: Article texts, one per entry.

    Returns:
        list[str]: One summary per input, in input order; ``""`` for any text
        that is empty after preprocessing.
    """
    if not txts:
        # Nothing to summarize; an empty batch would otherwise reach the
        # tokenizer and generate().
        return []
    cleaned = [summ_preprocess(t) for t in txts]
    if not any(cleaned):
        # Every summary would be discarded below, so skip generation entirely.
        return [""] * len(cleaned)
    inference_tokenizer, tokenized_data = summ_inference_tokenize(input_=cleaned, n_tokens=SUMM_INPUT_N_TOKENS)
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
    return [
        "" if not text else inference_tokenizer.decode(ids, skip_special_tokens=True).strip()
        for text, ids in zip(cleaned, pred)
    ]
# def scrape_multi_process(urls): | |
# logging.warning('Entering get_news_multi_process() to extract new news articles') | |
# ''' | |
# Scrape every URL in parallel with a process pool (one worker per CPU core)
# and return the scrape_text results in the same order as `urls`.
# ''' | |
# pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()) | |
# results = [] | |
# for url in urls: | |
# f = pool.apply_async(scrape_text, [url]) # submit this URL to the pool without blocking; workers scrape in parallel
# results.append(f) # appending result to results | |
# scraped_texts = [] | |
# for f in results: | |
# scraped_texts.append(f.get(timeout=120)) | |
# pool.close() | |
# pool.join() | |
# logging.warning('Exiting scrape_multi_process()') | |
# return scraped_texts | |
def scrape_urls(urls):
    """Scrape each URL in turn and split the results into two parallel lists.

    Args:
        urls: Iterable of article URLs.

    Returns:
        tuple: ``(texts, errors)`` where ``texts[i]`` and ``errors[i]`` are the
        two values ``scrape_text`` returned for the i-th URL.
    """
    pairs = []
    for link in urls:
        # Unpack immediately so a malformed result fails before the next URL is scraped.
        text, err = scrape_text(link)
        pairs.append((text, err))
    texts = [text for text, _ in pairs]
    errors = [err for _, err in pairs]
    return texts, errors
##### API #####
# ASGI application object; the __main__ block at the bottom of the file serves it with uvicorn.
app = FastAPI()
# Load tokenizer + fine-tuned model once into module-level globals that
# summ_inference_tokenize and summ_inference read at call time. This runs at
# import time, so merely importing this module performs the full model load.
summ_tokenizer, summ_model = load_summarizer_models()
class URLList(BaseModel):
    # Request body accepted by read_items: the articles to process plus the caller's key.
    # URLs of the articles to scrape and summarize.
    urls: List[str]
    # API key; read_items passes it to authenticate_key, which compares it with the API_KEY env var.
    key: str
class NewsSummarizerAPIAuthenticationError(Exception):
    """Raised when a request's API key does not match the configured key."""


def authenticate_key(api_key: str):
    """Check ``api_key`` against the ``API_KEY`` environment variable.

    The comparison is constant-time, so response timing does not reveal how
    much of the key matched. If ``API_KEY`` is not set, every key is rejected.

    Args:
        api_key: Key supplied by the client.

    Raises:
        NewsSummarizerAPIAuthenticationError: If the key does not match.
    """
    import secrets

    expected = os.getenv('API_KEY')
    # compare_digest rejects non-ASCII str arguments, so compare UTF-8 bytes;
    # a non-str key (or an unset API_KEY) can never authenticate.
    if (
        expected is None
        or not isinstance(api_key, str)
        or not secrets.compare_digest(api_key.encode("utf-8"), expected.encode("utf-8"))
    ):
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")
async def read_items(q: URLList):
    """Authenticate, scrape and summarize the URLs in a request.

    The JSON body always has the keys ``urls``, ``scraped_texts``,
    ``scrape_errors``, ``summaries`` and ``summarizer_error``. Status is 200 on
    success, 401 for an invalid API key and 500 for any other failure; on
    failure, ``summaries`` is ``""`` and any field not yet computed is ``""``.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
    function, so as written FastAPI never registers it — confirm it was not lost.
    NOTE(review): scraping and model inference are blocking calls inside an
    ``async def`` and stall the event loop while they run; consider a plain
    ``def`` handler (FastAPI runs those in a threadpool).
    """
    urls = ""
    scraped_texts = ""
    scrape_errors = ""
    try:
        # Captured before authenticating so error responses still echo the URLs.
        urls = q.urls
        authenticate_key(q.key)
        scraped_texts, scrape_errors = scrape_urls(urls)
        summaries = summ_inference(scraped_texts)
    except NewsSummarizerAPIAuthenticationError as e:
        status_code, summaries, error_msg = 401, "", f'error: {e}'
    except Exception as e:
        # Top-level boundary: report the failure to the client, keep the traceback in the logs.
        logging.exception('Summarization request failed')
        status_code, summaries, error_msg = 500, "", f'error: {e}'
    else:
        status_code, error_msg = 200, ''
    response_json = {
        'urls': urls,
        'scraped_texts': scraped_texts,
        'scrape_errors': scrape_errors,
        'summaries': summaries,
        'summarizer_error': error_msg,
    }
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)
if __name__ == '__main__':
    # Run the API directly with uvicorn, listening on all interfaces at port 7860.
    uvicorn.run(app, host='0.0.0.0', port=7860)