ksvmuralidhar committed
Commit bc12604
1 Parent(s): 7d6b7ab

Upload files

Files changed (5):
  1. Dockerfile +30 -0
  2. api.py +134 -0
  3. models/bart_en_summarizer.h5 +3 -0
  4. requirements.txt +8 -0
  5. scraper.py +58 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
+ FROM python:3.10-slim
+ WORKDIR /code
+ COPY ./requirements.txt /code/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+ RUN apt update && apt install -y ffmpeg
+ RUN apt -y install wget
+ RUN apt -y install firefox-esr
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     GECKODRIVERURL=https://github.com/mozilla/geckodriver/releases/download/v0.34.0/geckodriver-v0.34.0-linux64.tar.gz \
+     GECKODRIVERFILENAME=geckodriver-v0.34.0-linux64.tar.gz
+
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ RUN wget -P $HOME/app $GECKODRIVERURL
+ RUN tar --warning=no-file-changed -xzf $HOME/app/$GECKODRIVERFILENAME
+ RUN rm $HOME/app/$GECKODRIVERFILENAME
+
+ RUN chmod +x geckodriver
+
+ RUN ls -ltr
+
+ EXPOSE 7860
+ ENTRYPOINT ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
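
The image installs firefox-esr and unpacks geckodriver so Selenium can drive a headless Firefox inside the container. Below is a minimal smoke-test sketch for that setup, assuming the geckodriver binary is discoverable by Selenium (for example on PATH or via Selenium Manager); the script name and URL are illustrative and not part of this commit.

# smoke_test.py (hypothetical helper, not part of this commit)
from selenium import webdriver
from selenium.webdriver import FirefoxOptions

opts = FirefoxOptions()
opts.add_argument("--headless")            # same headless mode scraper.py uses
driver = webdriver.Firefox(options=opts)   # assumes geckodriver is discoverable
try:
    driver.get("https://example.com")      # placeholder page
    print(driver.title)                    # prints "Example Domain" if Firefox works
finally:
    driver.quit()
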
api.py ADDED
@@ -0,0 +1,134 @@
+ import os
+ # must be set before transformers/TensorFlow are imported so the legacy tf-keras backend is used
+ os.environ['TF_USE_LEGACY_KERAS'] = "1"
+ import re
+ from transformers import (BartTokenizerFast,
+                           TFAutoModelForSeq2SeqLM)
+ import tensorflow as tf
+ from scraper import scrape_text
+ from fastapi import FastAPI, Response
+ from typing import List
+ from pydantic import BaseModel
+ import uvicorn
+ import json
+ import logging
+ import multiprocessing
+
+
+ SUMM_CHECKPOINT = "facebook/bart-base"
+ SUMM_INPUT_N_TOKENS = 400
+ SUMM_TARGET_N_TOKENS = 300
+
+
+ def load_summarizer_models():
+     summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
+     summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
+     summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
+     logging.warning('Loaded summarizer models')
+     return summ_tokenizer, summ_model
+
+
+ def summ_preprocess(txt):
+     txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt)  # e.g. "By . Ellie Zolfagharifard ."
+     txt = re.sub(r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' ', txt)  # e.g. "10:30 EST"
+     txt = re.sub(r'\d{1,2} [a-zA-Z]+ \d{4}', ' ', txt)  # e.g. "10 November 1990"
+     txt = txt.replace('PUBLISHED:', ' ')
+     txt = txt.replace('UPDATED', ' ')
+     txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt)  # remove punctuation surrounded by spaces
+     txt = txt.replace(' : ', ' ')
+     txt = txt.replace('(CNN)', ' ')
+     txt = txt.replace('--', ' ')
+     txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt)  # remove punctuation at the start of the text
+     txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt)  # remove punctuation surrounded by spaces
+     txt = re.sub(r'\n+', ' ', txt)
+     txt = " ".join(txt.split())
+     return txt
+
+
+ def summ_inference_tokenize(input_: list, n_tokens: int):
+     tokenized_data = summ_tokenizer(text=input_, max_length=n_tokens, truncation=True, padding="max_length", return_tensors="tf")
+     return summ_tokenizer, tokenized_data
+
+
+ def summ_inference(txts: list):
+     txts = [*map(summ_preprocess, txts)]
+     inference_tokenizer, tokenized_data = summ_inference_tokenize(input_=txts, n_tokens=SUMM_INPUT_N_TOKENS)
+     pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
+     result = ["" if t == "" else inference_tokenizer.decode(p, skip_special_tokens=True).strip() for t, p in zip(txts, pred)]
+     return result
+
+
+ # def scrape_multi_process(urls):
+ #     logging.warning('Entering scrape_multi_process() to extract new news articles')
+ #     '''
+ #     Scrape the URLs in parallel by dispatching scrape_text() to a pool of
+ #     worker processes and collecting the results.
+ #     '''
+ #     pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
+
+ #     results = []
+ #     for url in urls:
+ #         f = pool.apply_async(scrape_text, [url])  # each worker starts scraping its URL asynchronously
+ #         results.append(f)  # collect the async result handles
+
+ #     scraped_texts = []
+ #     for f in results:
+ #         scraped_texts.append(f.get(timeout=120))
+ #     pool.close()
+ #     pool.join()
+ #     logging.warning('Exiting scrape_multi_process()')
+ #     return scraped_texts
+
+
+ def scrape_urls(urls):
+     scraped_texts = []
+     scrape_errors = []
+     for url in urls:
+         text, err = scrape_text(url)
+         scraped_texts.append(text)
+         scrape_errors.append(err)
+     return scraped_texts, scrape_errors
+
+
+ ##### API #####
+ app = FastAPI()
+ summ_tokenizer, summ_model = load_summarizer_models()
+
+
+ class URLList(BaseModel):
+     urls: List[str]
+     key: str
+
+
+ class NewsSummarizerAPIAuthenticationError(Exception):
+     pass
+
+
+ def authenticate_key(api_key: str):
+     if api_key != os.getenv('API_KEY'):
+         raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")
+
+
+ @app.post("/generate_summary/")
+ async def read_items(q: URLList):
+     try:
+         urls = ""
+         scraped_texts = ""
+         scrape_errors = ""
+         summaries = ""
+         request_json = q.json()
+         request_json = json.loads(request_json)
+         urls = request_json['urls']
+         api_key = request_json['key']
+         _ = authenticate_key(api_key)
+         scraped_texts, scrape_errors = scrape_urls(urls)
+         summaries = summ_inference(scraped_texts)
+         status_code = 200
+         response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
+     except Exception as e:
+         status_code = 500
+         if e.__class__.__name__ == "NewsSummarizerAPIAuthenticationError":
+             status_code = 401
+         response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'error: {e}'}
+
+     json_str = json.dumps(response_json, indent=5)  # convert dict to a JSON string
+     return Response(content=json_str, media_type='application/json', status_code=status_code)
+
+
+ if __name__ == '__main__':
+     uvicorn.run(app=app, host='0.0.0.0', port=7860)
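
For reference, the endpoint accepts a JSON body matching URLList (a list of article URLs plus an API key checked against the API_KEY environment variable) and returns the scraped texts, per-URL scrape errors, and summaries. The following client sketch assumes the service runs locally on port 7860 and that the requests package is installed; the URL and key values are placeholders and the script is not part of this commit.

# client_example.py (illustrative only, not part of this commit)
import requests

payload = {
    "urls": ["https://example.com/some-news-article"],  # placeholder article URL
    "key": "<API_KEY set on the server>",               # must match os.getenv('API_KEY')
}
resp = requests.post("http://localhost:7860/generate_summary/", json=payload)
print(resp.status_code)  # 200 on success, 401 on an invalid key, 500 on other errors
body = resp.json()
# Response fields defined in api.py:
# urls, scraped_texts, scrape_errors, summaries, summarizer_error
print(body["summaries"])
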
models/bart_en_summarizer.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e6381d18af41ddc2b674cde800281c5eb65ece6f0c964ab5a0e5f20b362d801
+ size 558172300
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers==4.39.3
+ tensorflow==2.15.0
+ unidecode
+ tf-keras==2.15.0
+ selenium==4.19.0
+ fastapi
+ pydantic
+ uvicorn
scraper.py ADDED
@@ -0,0 +1,58 @@
+ from selenium import webdriver
+ from selenium.webdriver.common.by import By
+ from selenium.webdriver import FirefoxOptions
+ import re
+ import logging
+ import os
+
+
+ def get_text(url, n_words=15):
+     try:
+         # geckodriver_path = '/home/user/app/geckodriver'
+         # os.environ['PATH'] += ':' + geckodriver_path
+         # os.environ['SELENIUM_DRIVER_CAPABILITIES'] = '{"alwaysLoadVcDrivers": true}'
+         driver = None
+         logging.warning(f"Initiated Scraping {url}")
+         user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"
+         opts = FirefoxOptions()
+         opts.add_argument("--headless")
+         opts.add_argument(f"user-agent={user_agent}")
+         # opts.binary = geckodriver_path
+         # webdriver.firefox.driver = geckodriver_path
+         driver = webdriver.Firefox(options=opts)
+         driver.set_page_load_timeout(50)
+         driver.get(url)
+         elem = driver.find_element(By.TAG_NAME, "body").text
+         sents = elem.split("\n")
+         sentence_list = []
+         for sent in sents:
+             sent = sent.strip()
+             # keep sentences with at least n_words words that start with a word character
+             # and do not end with a word character, ')' or whitespace
+             if (len(sent.split()) >= n_words) and (len(re.findall(r"^\w.+[^\w\)\s]$", sent)) > 0):
+                 sentence_list.append(sent)
+         driver.quit()
+         logging.warning("Closed Webdriver")
+         logging.warning("Successfully scraped text")
+         if len(sentence_list) < 3:
+             raise Exception("Found nothing to scrape.")
+         return "\n".join(sentence_list), ""
+     except Exception as e:
+         logging.warning(str(e))
+         if driver:
+             driver.quit()  # quit() terminates the browser session, not just the window
+             logging.warning("Closed Webdriver")
+         err_msg = str(e).split('\n')[0]
+         return "", err_msg
+
+
+ def scrape_text(url, n_words=15, max_retries=2):
+     scraped_text = ""
+     scrape_error = ""
+     try:
+         n_tries = 1
+         while (n_tries <= max_retries) and (scraped_text == ""):
+             scraped_text, scrape_error = get_text(url=url, n_words=n_words)
+             n_tries += 1
+         return scraped_text, scrape_error
+     except Exception as e:
+         err_msg = str(e).split('\n')[0]
+         return "", err_msg
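
scrape_text() wraps get_text() with a simple retry loop and always returns a (text, error) pair, so callers such as scrape_urls() in api.py never see an exception. A quick local usage sketch follows, assuming Firefox and geckodriver are available; the script name and URL are placeholders and not part of this commit.

# check_scraper.py (illustrative only, not part of this commit)
from scraper import scrape_text

text, err = scrape_text("https://example.com/some-news-article", n_words=15, max_retries=2)
if err:
    print(f"Scrape failed: {err}")  # err is the first line of the exception message
else:
    print(text[:500])               # first 500 characters of the extracted body text
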