freQuensy23 committed
Commit 79b2407 • 1 Parent(s): f0c7657
Add mixtral

Browse files:
- app.py: +6 -6
- generators.py: +17 -26
app.py CHANGED

@@ -9,20 +9,20 @@ load_dotenv()
 
 async def handle(system_input: str, user_input: str):
     print(system_input, user_input)
-    buffers = ["", "", "", "", ""]
+    buffers = ["", "", "", "", "", ""]
     async for outputs in async_zip_stream(
             generate_gpt2(system_input, user_input),
             generate_mistral_7bvo1(system_input, user_input),
             generate_llama2(system_input, user_input),
             generate_llama3(system_input, user_input),
             generate_t5(system_input, user_input),
+            generate_mixtral(system_input, user_input),
     ):
         # gpt_output, mistral_output, llama_output, llama2_output, llama3_output, llama4_output = outputs
         for i, b in enumerate(buffers):
             buffers[i] += str(outputs[i])
 
-        yield list(buffers)
-    yield list(buffers) + [generate_bloom(system_input, user_input)]
+        yield list(buffers)
 
 
 with gr.Blocks() as demo:
@@ -30,10 +30,10 @@ with gr.Blocks() as demo:
     with gr.Row():
         gpt = gr.Textbox(label='gpt-2', lines=4, interactive=False, info='OpenAI\n14 February 2019')
         t5 = gr.Textbox(label='t5', lines=4, interactive=False, info='Google\n12 Dec 2019')
-        bloom = gr.Textbox(label='bloom [GPU]', lines=4, interactive=False, info='Big Science\n11 Jul 2022')
-    with gr.Row():
         llama2 = gr.Textbox(label='llama-2', lines=4, interactive=False, info='MetaAI\n18 Jul 2023')
+    with gr.Row():
         mistral = gr.Textbox(label='mistral-v01', lines=4, interactive=False, info='MistralAI\n20 Sep 2023')
+        mixtral = gr.Textbox(label='mixtral', lines=4, interactive=False, info='Mistral AI\n11 Dec 2023')
         llama3 = gr.Textbox(label='llama-3.1', lines=4, interactive=False, info='MetaAI\n18 Jul 2024')
 
     user_input = gr.Textbox(label='User Input', lines=2, value='Calculate expression: 7-3=')
@@ -42,7 +42,7 @@ with gr.Blocks() as demo:
     gen_button.click(
         fn=handle,
         inputs=[system_input, user_input],
-        outputs=[gpt, mistral, llama2, llama3, t5, bloom],
+        outputs=[gpt, mistral, llama2, llama3, t5, mixtral],
     )
 
 demo.launch()
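Note on async_zip_stream: in handle() it drives all six streaming generators in lockstep and yields a tuple with one chunk per model each round; its implementation lives elsewhere in the Space and is not part of this diff. A minimal sketch of what such a helper could look like, assuming every generator yields string chunks and exhausted generators should contribute empty strings until the last one finishes:

# Illustrative sketch only -- the Space's real async_zip_stream is not shown in
# this commit, so the behaviour below is an assumption, not the actual code.
async def async_zip_stream(*generators):
    # Track which generators have finished; finished slots keep yielding "".
    done = [False] * len(generators)
    while not all(done):
        chunk = []
        produced = False
        for i, gen in enumerate(generators):
            if done[i]:
                chunk.append("")
                continue
            try:
                chunk.append(await gen.__anext__())
                produced = True
            except StopAsyncIteration:
                done[i] = True
                chunk.append("")
        if produced:
            # One tuple per round: (gpt2_chunk, mistral_chunk, ..., mixtral_chunk)
            yield tuple(chunk)

A production version would more likely advance the generators concurrently (for example via asyncio.gather) rather than one after another, but the contract handle() relies on is the same: each iteration produces one chunk per model, which handle() appends to buffers before yielding the running texts to the six Gradio textboxes.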
generators.py CHANGED

@@ -67,36 +67,27 @@ async def generate_llama2(system_input, user_input):
         yield message.choices[0].delta.content
 
 
-
-
-
-
-    model = LlamaForCausalLM.from_pretrained(
-        model_path, torch_dtype=torch.float16, device_map='cuda',
-    )
-    print('model openllama loaded')
-    input_text = f"{system_input}\n{user_input}"
-    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
-    output = model.generate(input_ids, max_length=128)
-    return tokenizer.decode(output[0], skip_special_tokens=True)
-
-
-@spaces.GPU(duration=120)
-def generate_bloom(system_input, user_input):
-    model_path = 'bigscience/bloom-7b1'
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, torch_dtype=torch.float16, device_map='cuda',
+async def generate_llama3(system_input, user_input):
+    client = AsyncInferenceClient(
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        token=os.getenv('HF_TOKEN')
     )
-
-
-
-
+    try:
+        async for message in await client.chat_completion(
+                messages=[
+                    {"role": "system", "content": system_input},
+                    {"role": "user", "content": user_input}, ],
+                max_tokens=256,
+                stream=True,
+        ):
+            yield message.choices[0].delta.content
+    except json.JSONDecodeError:
+        pass
 
 
-async def
+async def generate_mixtral(system_input, user_input):
     client = AsyncInferenceClient(
-        "
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
         token=os.getenv('HF_TOKEN')
     )
     try:
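The hunk ends at try:, so the rest of generate_mixtral falls outside this diff; it presumably mirrors the streaming loop shown above for generate_llama3. As an illustrative sketch of how the new generator could be exercised on its own, assuming it is importable from a generators module and that HF_TOKEN is set in the environment (the prompts below are arbitrary examples, with the user input borrowed from the demo's default):

# Hypothetical smoke test for the newly added generator; adjust the import to
# wherever generate_mixtral actually lives in the Space.
import asyncio

from generators import generate_mixtral  # assumed module path


async def main():
    # Stream Mixtral's answer chunk by chunk, mirroring what handle() does in app.py.
    async for chunk in generate_mixtral("You are a calculator.", "Calculate expression: 7-3="):
        if chunk:  # the final delta of a stream can be empty or None
            print(chunk, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())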