freQuensy23 committed on
Commit
79b2407
1 Parent(s): f0c7657

Add mixtral

Files changed (2)
  1. app.py +6 -6
  2. generators.py +17 -26
app.py CHANGED
@@ -9,20 +9,20 @@ load_dotenv()
 
 async def handle(system_input: str, user_input: str):
     print(system_input, user_input)
-    buffers = ["", "", "", "", ""]
+    buffers = ["", "", "", "", "", ""]
     async for outputs in async_zip_stream(
             generate_gpt2(system_input, user_input),
             generate_mistral_7bvo1(system_input, user_input),
             generate_llama2(system_input, user_input),
             generate_llama3(system_input, user_input),
             generate_t5(system_input, user_input),
+            generate_mixtral(system_input, user_input),
     ):
         # gpt_output, mistral_output, llama_output, llama2_output, llama3_output, llama4_output = outputs
         for i, b in enumerate(buffers):
             buffers[i] += str(outputs[i])
 
-        yield list(buffers) + ["", ""]
-    yield list(buffers) + [generate_bloom(system_input, user_input)]
+        yield list(buffers)
 
 
 with gr.Blocks() as demo:
@@ -30,10 +30,10 @@ with gr.Blocks() as demo:
     with gr.Row():
         gpt = gr.Textbox(label='gpt-2', lines=4, interactive=False, info='OpenAI\n14 February 2019')
         t5 = gr.Textbox(label='t5', lines=4, interactive=False, info='Google\n12 Dec 2019')
-        bloom = gr.Textbox(label='bloom [GPU]', lines=4, interactive=False, info='Big Science\n11 Jul 2022')
-    with gr.Row():
         llama2 = gr.Textbox(label='llama-2', lines=4, interactive=False, info='MetaAI\n18 Jul 2023')
+    with gr.Row():
         mistral = gr.Textbox(label='mistral-v01', lines=4, interactive=False, info='MistralAI\n20 Sep 2023')
+        mixtral = gr.Textbox(label='mixtral', lines=4, interactive=False, info='Mistral AI\n11 Dec 2023')
         llama3 = gr.Textbox(label='llama-3.1', lines=4, interactive=False, info='MetaAI\n18 Jul 2024')
 
     user_input = gr.Textbox(label='User Input', lines=2, value='Calculate expression: 7-3=')
@@ -42,7 +42,7 @@ with gr.Blocks() as demo:
     gen_button.click(
         fn=handle,
         inputs=[system_input, user_input],
-        outputs=[gpt, mistral, llama2, llama3, t5, bloom],
+        outputs=[gpt, mistral, llama2, llama3, t5, mixtral],
     )
 
 demo.launch()
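Note on app.py: `handle` relies on `async_zip_stream` (imported elsewhere in app.py) to advance all model generators in lockstep, so every textbox streams at the same time. With Mixtral added, `buffers` grows to six slots and each yield pushes the six partial outputs to the six textboxes listed in `outputs=` in the same order as the generators. A minimal sketch of what such a helper could look like, assuming a pure-asyncio implementation (hypothetical, not necessarily the Space's actual utility):

```python
import asyncio


async def async_zip_stream(*generators):
    # Hypothetical sketch: drive several async generators concurrently and
    # yield a tuple with one chunk per generator on every round. Generators
    # that have finished keep contributing empty strings until all are done
    # (the final round may therefore be all-empty, which is harmless for
    # callers that simply concatenate chunks).
    iterators = [gen.__aiter__() for gen in generators]
    finished = [False] * len(iterators)

    async def next_chunk(i):
        if finished[i]:
            return ""
        try:
            return await iterators[i].__anext__()
        except StopAsyncIteration:
            finished[i] = True
            return ""

    while not all(finished):
        chunks = await asyncio.gather(*(next_chunk(i) for i in range(len(iterators))))
        yield tuple(chunks)
```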
generators.py CHANGED
@@ -67,36 +67,27 @@ async def generate_llama2(system_input, user_input):
             yield message.choices[0].delta.content
 
 
-@spaces.GPU(duration=120)
-def generate_openllama(system_input, user_input):
-    model_path = 'openlm-research/open_llama_3b_v2'
-    tokenizer = LlamaTokenizer.from_pretrained(model_path)
-    model = LlamaForCausalLM.from_pretrained(
-        model_path, torch_dtype=torch.float16, device_map='cuda',
-    )
-    print('model openllama loaded')
-    input_text = f"{system_input}\n{user_input}"
-    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
-    output = model.generate(input_ids, max_length=128)
-    return tokenizer.decode(output[0], skip_special_tokens=True)
-
-
-@spaces.GPU(duration=120)
-def generate_bloom(system_input, user_input):
-    model_path = 'bigscience/bloom-7b1'
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, torch_dtype=torch.float16, device_map='cuda',
+async def generate_llama3(system_input, user_input):
+    client = AsyncInferenceClient(
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        token=os.getenv('HF_TOKEN')
     )
-    input_text = f"{system_input}\n{user_input}"
-    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
-    output = model.generate(input_ids, max_length=128)
-    return tokenizer.decode(output[0], skip_special_tokens=True)
+    try:
+        async for message in await client.chat_completion(
+            messages=[
+                {"role": "system", "content": system_input},
+                {"role": "user", "content": user_input}, ],
+            max_tokens=256,
+            stream=True,
+        ):
+            yield message.choices[0].delta.content
+    except json.JSONDecodeError:
+        pass
 
 
-async def generate_llama3(system_input, user_input):
+async def generate_mixtral(system_input, user_input):
     client = AsyncInferenceClient(
-        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        token=os.getenv('HF_TOKEN')
     )
     try:
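Note on generators.py: besides adding Mixtral, this diff drops the GPU-bound `generate_openllama` and `generate_bloom` helpers, leaving only async generators that stream chat-completion deltas from the Hugging Face Inference API and swallow the occasional `json.JSONDecodeError` raised when a stream ends. `generate_mixtral` mirrors `generate_llama3` with a different model id. A quick way to smoke-test the new generator outside Gradio (assuming `HF_TOKEN` is set in the environment; the prompt below is just an example) might be:

```python
import asyncio

from generators import generate_mixtral


async def main():
    # Stream Mixtral-8x7B-Instruct output chunk by chunk and print it as it arrives.
    async for chunk in generate_mixtral("You are a calculator.",
                                        "Calculate expression: 7-3="):
        print(chunk or "", end="", flush=True)  # delta.content can be None on some chunks
    print()


if __name__ == "__main__":
    asyncio.run(main())
```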