Miaoran000 commited on
Commit
298d825
1 Parent(s): d4bf693

Minor updates for new models; add post-processing for Markdown-formatted output

Browse files
requirements.txt CHANGED
@@ -21,4 +21,6 @@ anthropic
21
  openai
22
  cohere
23
  mistralai
24
- peft
 
 
 
21
  openai
22
  cohere
23
  mistralai
24
+ peft
25
+ mdit_plain
26
+ markdown_it
src/backend/model_operations.py CHANGED
@@ -160,7 +160,7 @@ class SummaryGenerator:
160
  using_replicate_api = False
161
  replicate_api_models = ['snowflake', 'llama-3.1-405b']
162
  using_pipeline = False
163
- pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5']
164
 
165
  for replicate_api_model in replicate_api_models:
166
  if replicate_api_model in self.model_id.lower():
@@ -325,7 +325,20 @@ class SummaryGenerator:
325
  result = message.content[0].text
326
  print(result)
327
  return result
328
-
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  elif 'mistral-large' in self.model_id.lower():
330
  api_key = os.environ["MISTRAL_API_KEY"]
331
  client = Mistral(api_key=api_key)
@@ -554,14 +567,8 @@ class EvaluationModel:
554
  for doc, summary in source_summary_pairs:
555
  if util.is_summary_valid(summary):
556
  try:
557
- summary = summary.replace('<bos>','').replace('<eos>','').strip()
558
  score = self.predict([(doc, summary)])[0]
559
- # print(score)
560
- # if score < 0.5:
561
- # print(doc)
562
- # print('-'*10)
563
- # print(summary)
564
- # print('='*20)
565
  hem_scores.append(score)
566
  sources.append(doc)
567
  summaries.append(summary)
 
160
  using_replicate_api = False
161
  replicate_api_models = ['snowflake', 'llama-3.1-405b']
162
  using_pipeline = False
163
+ pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5', 'mistral-nemo']
164
 
165
  for replicate_api_model in replicate_api_models:
166
  if replicate_api_model in self.model_id.lower():
 
325
  result = message.content[0].text
326
  print(result)
327
  return result
328
+
329
+ elif 'command-r' in self.model_id.lower():
330
+ co = cohere.Client(os.getenv('COHERE_API_TOKEN'))
331
+ response = co.chat(
332
+ chat_history=[
333
+ {"role": "SYSTEM", "message": system_prompt},
334
+ ],
335
+ message=user_prompt,
336
+ )
337
+ result = response.text
338
+ print(result)
339
+ return result
340
+
341
+
342
  elif 'mistral-large' in self.model_id.lower():
343
  api_key = os.environ["MISTRAL_API_KEY"]
344
  client = Mistral(api_key=api_key)
 
567
  for doc, summary in source_summary_pairs:
568
  if util.is_summary_valid(summary):
569
  try:
570
+ summary = util.normalize_summary(summary)
571
  score = self.predict([(doc, summary)])[0]
 
 
 
 
 
 
572
  hem_scores.append(score)
573
  sources.append(doc)
574
  summaries.append(summary)
src/backend/util.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  def is_summary_valid(summary: str) -> bool:
2
  """
3
  Checks if the summary is valid.
@@ -76,3 +81,12 @@ def format_results(model_name: str, revision: str, precision: str,
76
  }
77
 
78
  return results
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from markdown_it import MarkdownIt
3
+ from mdit_plain.renderer import RendererPlain
4
+
5
+
6
  def is_summary_valid(summary: str) -> bool:
7
  """
8
  Checks if the summary is valid.
 
81
  }
82
 
83
  return results
84
+
85
# Module-level Markdown parser configured to render to plain text,
# built once so normalize_summary does not recreate it per call.
parser = MarkdownIt(renderer_cls=RendererPlain)

def normalize_summary(summary: str) -> str:
    """Clean a model-generated summary for scoring.

    Removes model sentinel tokens, renders any Markdown markup to plain
    text, and collapses whitespace so the evaluation model sees plain prose.

    Args:
        summary: Raw model output, possibly containing ``<bos>``/``<eos>``
            tokens and Markdown formatting.

    Returns:
        The cleaned plain-text summary with normalized whitespace.
    """
    summary = summary.replace('<bos>', '').replace('<eos>', '')
    # Render Markdown (headings, lists, emphasis, ...) down to plain text.
    summary = parser.render(summary)
    # The plain renderer can leave stray literal asterisks; drop them.
    summary = summary.replace('*', '')
    # Raw string: '\s' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). Collapse whitespace runs to one space.
    summary = re.sub(r'\s{2,}', ' ', summary)
    # Strip the ends as the pre-refactor inline code did via .strip() —
    # the collapse above can still leave one leading/trailing space.
    return summary.strip()