Emroi committed
Commit e14cc31
1 Parent(s): a19ee9a

Update app.py

Files changed (1)
  1. app.py +31 -22
app.py CHANGED
@@ -1,23 +1,32 @@
- # pip install transformers==4.41.1
  from transformers import AutoTokenizer, AutoModelForCausalLM
- from huggingface_hub import login
- login()
-
- model_id = "CohereForAI/aya-23-8B"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id)
-
- # Format message with the command-r-plus chat template
- messages = [{"role": "user", "content": "Расскажи о себе"}]  # Russian: "Tell me about yourself"
- input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
- ## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Anneme onu ne kadar sevdiğimi anlatan bir mektup yaz<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>  # template dump; the Turkish sample means "Write a letter telling my mom how much I love her"
-
- gen_tokens = model.generate(
-     input_ids,
-     max_new_tokens=100,
-     do_sample=True,
-     temperature=0.3,
- )
-
- gen_text = tokenizer.decode(gen_tokens[0])
- print(gen_text)
+ import gradio as gr  # imported for the Space UI; not yet used in this file
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ # bfloat16 weights; flash_attention_2 requires the flash-attn package and a supported GPU
+ model = AutoModelForCausalLM.from_pretrained("Vikhrmodels/Vikhr-7B-instruct_0.4",
+                                              device_map="auto",
+                                              attn_implementation="flash_attention_2",
+                                              torch_dtype=torch.bfloat16)
+
+ tokenizer = AutoTokenizer.from_pretrained("Vikhrmodels/Vikhr-7B-instruct_0.4")
+ from transformers import pipeline
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+ prompts = [
+     "В чем разница между фруктом и овощем?",  # Russian: "What is the difference between a fruit and a vegetable?"
+     "Годы жизни колмагорова?"]  # Russian: "Kolmogorov's years of life?" (sic)
+
+ def test_inference(prompt):
+     prompt = pipe.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+     print(prompt)  # log the fully templated prompt
+     outputs = pipe(prompt, max_new_tokens=512, do_sample=True, num_beams=1, temperature=0.25, top_k=50, top_p=0.98, eos_token_id=79097)  # 79097: model-specific end-of-turn token id
+     return outputs[0]['generated_text'][len(prompt):].strip()
+
+
+ for prompt in prompts:
+     print(f" prompt:\n{prompt}")
+     print(f" response:\n{test_inference(prompt)}")
+     print("-"*50)
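The hunk above imports gradio as gr but never calls it. A minimal sketch of how test_inference might be exposed as a Gradio interface in a later revision; this wiring is hypothetical, not part of this commit, and the title string is invented.

import gradio as gr

# Hypothetical wiring (not in this commit): expose test_inference,
# defined in the diff above, through a simple text-in/text-out UI.
demo = gr.Interface(
    fn=test_inference,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Response"),
    title="Vikhr-7B-instruct demo",  # invented title
)
demo.launch()

A gr.Interface plus launch() is the standard entry point for a Hugging Face Space, which is presumably why the import was added.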