shripadbhat commited on
Commit
c2e898c
1 Parent(s): 2a2e764

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import gradio as gr
 
2
  from transformers import pipeline
3
  from sentence_transformers import CrossEncoder
4
 
 
5
  passage_retreival_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
6
  qa_model = pipeline("question-answering",'a-ware/bart-squadv2')
7
 
@@ -17,9 +19,17 @@ def fetch_answers(question, clincal_note ):
17
  top_5_query_paragraph_answer_list = ""
18
  count = 1
19
  for query, passage in top_5_query_paragraph_list:
 
20
  answer = qa_model(question = query, context = passage)['answer']
 
 
 
 
 
 
 
21
  result_str = "# RESULT "+str(count)+"\n"
22
- result_str = result_str + passage.replace(answer,"**"+answer.replace('.','')+"**") + "\n\n"
23
  top_5_query_paragraph_answer_list += result_str
24
  count+=1
25
 
 
1
  import gradio as gr
2
+ import pysbd
3
  from transformers import pipeline
4
  from sentence_transformers import CrossEncoder
5
 
6
+ sentence_segmenter = pysbd.Segmenter(language='en',clean=False)
7
  passage_retreival_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
8
  qa_model = pipeline("question-answering",'a-ware/bart-squadv2')
9
 
 
19
  top_5_query_paragraph_answer_list = ""
20
  count = 1
21
  for query, passage in top_5_query_paragraph_list:
22
+ passage_sentences = sentence_segmenter.segment(passage)
23
  answer = qa_model(question = query, context = passage)['answer']
24
+
25
+ for i in range(len(passage_sentences)):
26
+ if answer.startswith('.') or answer.startswith(':'):
27
+ answer = answer[1:].strip()
28
+ if answer in passage_sentences[i]:
29
+ passage_sentences[i] = "**"+passage_sentences[i].strip()+"**"
30
+
31
  result_str = "# RESULT "+str(count)+"\n"
32
+ result_str = result_str + " ".join(passage_sentences) + "\n\n"
33
  top_5_query_paragraph_answer_list += result_str
34
  count+=1
35