Miaoran000 committed
Commit 6919058
1 Parent(s): 066863b

minor fix to work for recent models

Files changed (1):
  1. src/backend/model_operations.py (+122 -103)
src/backend/model_operations.py CHANGED
@@ -11,15 +11,19 @@ import pandas as pd
 import spacy
 import litellm
 from tqdm import tqdm
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForTokenClassification, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForTokenClassification, AutoConfig, Qwen2VLForConditionalGeneration, AutoProcessor
 from peft import PeftModel
 import torch
 import cohere
 from openai import OpenAI
+from together import Together
 import anthropic
 import replicate
-import google.generativeai as genai
+# import google.generativeai as genai
+import vertexai
+from vertexai.generative_models import GenerativeModel, Part, SafetySetting, FinishReason
 from mistralai import Mistral
+from qwen_vl_utils import process_vision_info


 import src.backend.util as util
@@ -156,7 +160,7 @@ class SummaryGenerator:
     def generate_summary(self, system_prompt: str, user_prompt: str):
         # Using Together AI API
         using_together_api = False
-        together_ai_api_models = ['mixtral', 'dbrx', 'wizardlm', 'llama-3-', 'qwen', 'zero-one-ai'] #, 'mistralai'
+        together_ai_api_models = ['mixtral', 'dbrx', 'wizardlm', 'llama-3-', 'qwen2-72b-instruct', 'zero-one-ai', 'llama-3.2-'] #, 'mistralai'
         using_replicate_api = False
         replicate_api_models = ['snowflake', 'llama-3.1-405b']
         using_pipeline = False
@@ -181,99 +185,80 @@ class SummaryGenerator:

         # if 'mixtral' in self.model_id.lower() or 'dbrx' in self.model_id.lower() or 'wizardlm' in self.model_id.lower(): # For mixtral and dbrx models, use Together AI API
         if using_together_api:
-            # print('using together api')
-            # suffix = "completions" if ('mixtral' in self.model_id.lower() or 'base' in self.model_id.lower()) else "chat/completions"
-            suffix = "chat/completions"
-            url = f"https://api.together.xyz/v1/{suffix}"
-
-            payload = {
-                "model": self.model_id,
-                'max_new_tokens': 250,
-                "temperature": 0.0,
-            }
-            payload['messages'] = [{"role": "system", "content": system_prompt},
-                                   {"role": "user", "content": user_prompt}]
-            headers = {
-                "accept": "application/json",
-                "content-type": "application/json",
-                "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}"
-            }
-
-            response = requests.post(url, json=payload, headers=headers)
-            print(response)
-            try:
-                result = json.loads(response.text)
-                # print(result)
-                result = result["choices"][0]
-                if 'message' in result:
-                    result = result["message"]["content"].strip()
-                else:
-                    result = result["text"]
-                    result_candidates = [result_cancdidate for result_cancdidate in result.split('\n\n') if len(result_cancdidate) > 0]
-                    result = result_candidates[0]
-                # print(result)
-            except:
-                # print(response)
-                result = ''
+            print('using together api')
+            client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))
+            if 'llama-3.2-90b-vision' in self.model_id.lower() or 'llama-3.2-11b-vision' in self.model_id.lower():
+                messages = [
+                    {"role": "system","content": system_prompt},
+                    {"role": "user","content": [{"type": "text","text": user_prompt}]}
+                ]
+            else:
+                messages = [{"role": "system", "content": system_prompt},
+                            {"role": "user", "content": user_prompt}]
+            response = client.chat.completions.create(
+                model=self.model_id,
+                messages = messages,
+                max_tokens=250,
+                temperature=0,
+            )
+            # print(response)
+            result = response.choices[0].message.content
             print(result)
             return result

         # Using OpenAI API
-        elif 'gpt' in self.model_id.lower():
+        elif 'openai' in self.model_id.lower():
             client = OpenAI()
             response = client.chat.completions.create(
                 model=self.model_id.replace('openai/',''),
                 messages=[{"role": "system", "content": system_prompt},
-                          {"role": "user", "content": user_prompt}],
-                temperature=0.0,
-                max_tokens=250,
+                          {"role": "user", "content": user_prompt}] if 'gpt' in self.model_id
+                          else [{"role": "user", "content": system_prompt + '\n' + user_prompt}],
+                temperature=0.0 if 'gpt' in self.model_id.lower() else 1.0, # fixed at 1 for o1 models
+                max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None, # not compatible with o1 series models
             )
             # print(response)
             result = response.choices[0].message.content
             print(result)
             return result
-
-        # Using Google AI API for Gemini models
+
         elif 'gemini' in self.model_id.lower():
-            genai.configure(api_key=os.getenv('GOOGLE_AI_API_KEY'))
+            vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
+            gemini_model_id_map = {'gemini-1.5-pro-exp-0827':'gemini-pro-experimental', 'gemini-1.5-flash-exp-0827': 'gemini-flash-experimental'}
+            model = GenerativeModel(
+                self.model_id.lower().split('google/')[-1],
+                system_instruction = [system_prompt]
+            )
             generation_config = {
                 "temperature": 0,
-                "top_p": 0.95, # cannot change
-                "top_k": 0,
-                "max_output_tokens": 250,
-                # "response_mime_type": "application/json",
+                "max_output_tokens": 250
             }
             safety_settings = [
-                {
-                    "category": "HARM_CATEGORY_HARASSMENT",
-                    "threshold": "BLOCK_NONE"
-                },
-                {
-                    "category": "HARM_CATEGORY_HATE_SPEECH",
-                    "threshold": "BLOCK_NONE"
-                },
-                {
-                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-                    "threshold": "BLOCK_NONE"
-                },
-                {
-                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                    "threshold": "BLOCK_NONE"
-                },
+                SafetySetting(
+                    category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
+                ),
+                SafetySetting(
+                    category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
+                ),
+                SafetySetting(
+                    category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
+                ),
+                SafetySetting(
+                    category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
+                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
+                )
             ]
-            model = genai.GenerativeModel(model_name=self.model_id.lower().split('google/')[-1],
-                                          generation_config=generation_config,
-                                          system_instruction=system_prompt,
-                                          safety_settings=safety_settings)
-            # print(model)
-            convo = model.start_chat(history=[])
-            convo.send_message(user_prompt)
-            # print(convo.last)
-            result = convo.last.text
+            response = model.generate_content(
+                user_prompt,
+                safety_settings=safety_settings,
+                generation_config=generation_config
+            )
+            result = response.text
             print(result)
             return result
-
         elif using_replicate_api:
             print("using replicate")
             if 'snowflake' in self.model_id.lower():
@@ -338,7 +323,6 @@ class SummaryGenerator:
             print(result)
             return result

-
         elif 'mistral-large' in self.model_id.lower():
             api_key = os.environ["MISTRAL_API_KEY"]
             client = Mistral(api_key=api_key)
@@ -363,35 +347,30 @@ class SummaryGenerator:
             print(result)
             return result

+        elif 'deepseek' in self.model_id.lower():
+            client = OpenAI(api_key=os.getenv("DeepSeek_API_KEY"), base_url="https://api.deepseek.com")
+            response = client.chat.completions.create(
+                model=self.model_id.split('/')[-1],
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt},
+                ],
+                max_tokens=250,
+                temperature=0,
+                stream=False
+            )
+            result = response.choices[0].message.content
+            print(result)
+            return result
+
         # Using HF API or download checkpoints
         elif self.local_model is None and self.local_pipeline is None:
-            # try: # try use HuggingFace API
-            #     print('** using huggingface api')
-            #     response = litellm.completion(
-            #         model=self.model,
-            #         messages=[{"role": "system", "content": system_prompt},
-            #                   {"role": "user", "content": user_prompt}],
-            #         temperature=0.0,
-            #         max_tokens=250,
-            #         api_base=self.api_base,
-            #     )
-            #     result = response['choices'][0]['message']['content']
-            #     result = result.split('<|im_end|>')[0]
-            #     print(result)
-            #     return result
-            # except Exception as e:
-            #     if 'Rate limit reached' in str(e) :
-            #         wait_time = 300
-            #         current_time = datetime.now().strftime('%H:%M:%S')
-            #         print(f"Rate limit hit at {current_time}. Waiting for 5 minutes before retrying...")
-            #         time.sleep(wait_time)
-            #     else:
             if using_pipeline:
                 self.local_pipeline = pipeline(
                     "text-generation",
                     model=self.model_id,
                     tokenizer=AutoTokenizer.from_pretrained(self.model_id),
-                    model_kwargs={"torch_dtype": torch.bfloat16},
+                    torch_dtype=torch.bfloat16 if 'llama-3.2' in self.model_id.lower() else "auto",
                     device_map="auto",
                     trust_remote_code=True
                 )
@@ -404,12 +383,18 @@ class SummaryGenerator:
                     attn_implementation="flash_attention_2",
                     device_map="auto",
                     use_mamba_kernels=False)
+
+            elif 'qwen2-vl' in self.model_id.lower():
+                self.local_model = Qwen2VLForConditionalGeneration.from_pretrained(
+                    self.model_id, torch_dtype="auto", device_map="auto"
+                )
+                self.processor = AutoProcessor.from_pretrained(self.model_id)
+
             else:
                 self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
                 # print(self.local_model.device)
             print("Local model loaded")
-
-
+
         # Using local model/pipeline
         if self.local_pipeline:
             print('Using Transformers pipeline')
@@ -432,7 +417,7 @@ class SummaryGenerator:
             if 'gemma' in self.model_id.lower() or 'mistral-7b' in self.model_id.lower():
                 messages=[
                     # gemma-1.1, mistral-7b does not accept system role
-                    {"role": "user", "content": system_prompt + ' ' + user_prompt}
+                    {"role": "user", "content": system_prompt + '\n' + user_prompt}
                 ]
                 prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)

@@ -442,6 +427,21 @@ class SummaryGenerator:
             elif 'intel' in self.model_id.lower():
                 prompt = f"### System:\n{system_prompt}\n### User:\n{user_prompt}\n### Assistant:\n"

+            elif 'qwen2-vl' in self.model_id.lower():
+                messages = [
+                    {
+                        "role": "system",
+                        "content": [
+                            {"type": "text", "text": system_prompt}
+                        ]
+                    },
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": user_prompt},
+                        ],
+                    }
+                ]
             else:
                 messages=[
                     {"role": "system", "content": system_prompt},
@@ -455,14 +455,27 @@ class SummaryGenerator:
             outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01, pad_token_id=self.tokenizer.eos_token_id)
             if 'glm' in self.model_id.lower():
                 outputs = outputs[:, input_ids['input_ids'].shape[1]:]
-
-            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
+                outputs = [
+                    out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
+                ]
+
+
+            if 'qwen2-vl' in self.model_id.lower():
+                result = self.processor.batch_decode(
+                    outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
+                )[0]
+            else:
+                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
             if 'gemma-2' in self.model_id.lower():
                 result = result.split(user_prompt + '\nmodel')[-1].strip()
             elif 'intel' in self.model_id.lower():
                 result = result.split("### Assistant:\n")[-1]
             elif 'jamba' in self.model_id.lower():
                 result = result.split(messages[-1]['content'])[1].strip()
+            elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
+                pass
             else:
                 # print(prompt)
                 # print('-'*50)
@@ -572,6 +585,12 @@ class EvaluationModel:
                     hem_scores.append(score)
                     sources.append(doc)
                     summaries.append(summary)
+                    if score < 0.5:
+                        print(score)
+                        print(doc)
+                        print('-'*20)
+                        print(summary)
+                        print('='*50)
                 except Exception as e:
                     logging.error(f"Error while running HEM: {e}")
                     raise
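
The Qwen2-VL support introduced above is split across several hunks (checkpoint loading, chat-message construction, prompt-token trimming, and processor-based decoding). The snippet below is a minimal, self-contained sketch of that text-only path; it is not part of the commit, and the checkpoint id, prompts, and token budget are illustrative placeholders.

```python
# Standalone sketch of the text-only Qwen2-VL summarization path wired in above.
# The checkpoint id, prompts, and max_new_tokens here are placeholders, not values from the commit.
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_id = "Qwen/Qwen2-VL-7B-Instruct"  # placeholder Qwen2-VL checkpoint
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

system_prompt = "You are a helpful assistant that writes faithful summaries."  # placeholder
user_prompt = "Summarize the following passage: ..."                            # placeholder

# Same text-only chat structure the diff builds for 'qwen2-vl' models.
messages = [
    {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
    {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
]

# Render the chat template and tokenize; no images or videos are passed.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], padding=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=250)

# Strip the prompt tokens before decoding, mirroring the out_ids[len(in_ids):]
# slicing added in the diff, so only newly generated tokens are decoded.
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
result = processor.batch_decode(
    trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(result)
```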