Mattral committed
Commit 8e79faf
1 Parent(s): b3cf22a

Create app.py

Files changed (1)
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+ import random
+ import textwrap
+
+ # Define the model to be used
+ model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ client = InferenceClient(model)
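+ # Passing a model id to InferenceClient routes every request through the
+ # hosted Hugging Face Inference API, so no local weights are needed.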
+
+ # Embedded system prompt
+ system_prompt_text = "You are a smart and helpful health consultant and therapist named CareNetAI, owned by YAiC. You help and support with any kind of request and provide a detailed answer or suggestion to the question. You are friendly and willing to help depressed people, and you also help people identify manipulators and protect themselves from them. If you are asked about something unethical or dangerous, you must refuse and suggest a safe and respectful way to handle it."
+
+ # Read the content of the info.md file
+ with open("info.md", "r") as file:
+     info_md_content = file.read()
+
+ # Chunk the info.md content into smaller sections
+ chunk_size = 2000  # Adjust this size as needed
+ info_md_chunks = textwrap.wrap(info_md_content, chunk_size)
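+ # Note: textwrap.wrap treats the text as a single paragraph and reflows its
+ # whitespace, so line breaks and markdown layout in info.md are not preserved.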
+
+ def get_all_chunks(chunks):
+     return "\n\n".join(chunks)
+
+ def format_prompt_mixtral(message, history, info_md_chunks):
+     prompt = "<s>"
+     all_chunks = get_all_chunks(info_md_chunks)
+     prompt += f"{all_chunks}\n\n"  # Add all chunks of info.md at the beginning
+     prompt += f"{system_prompt_text}\n\n"  # Add the system prompt
+
+     if history:
+         for user_prompt, bot_response in history:
+             prompt += f"[INST] {user_prompt} [/INST]"
+             prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
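+ # For illustration (hypothetical values), a first-turn call produces:
+ #   format_prompt_mixtral("Hi", [], ["<info chunk>"])
+ #   -> "<s><info chunk>\n\n<system prompt text>\n\n[INST] Hi [/INST]"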
+
+ def chat_inf(prompt, history, seed, temp, tokens, top_p, rep_p):
+     history = history or []  # The Chatbot value can be None on the first turn
+     generate_kwargs = dict(
+         temperature=temp,
+         max_new_tokens=tokens,
+         top_p=top_p,
+         repetition_penalty=rep_p,
+         do_sample=True,
+         seed=int(seed),
+     )
+
+     formatted_prompt = format_prompt_mixtral(prompt, history, info_md_chunks)
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+     for response in stream:
+         output += response.token.text
+         yield history + [(prompt, output)]  # Keep earlier turns visible while streaming
+     history.append((prompt, output))
+     yield history
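+ # Because chat_inf is a generator, Gradio streams each yielded history into
+ # the Chatbot component as tokens arrive; the final yield commits the turn.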
+
+ def clear_fn():
+     # Clear both the prompt textbox and the chat history
+     return None, None
+
+ rand_val = random.randint(1, 1111111111111111)
+
+ def check_rand(inp, val):
+     # Re-roll the seed when "Random Seed" is checked; otherwise keep the user's value
+     if inp:
+         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
+     else:
+         return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
+
+ with gr.Blocks() as app:
+     gr.HTML("""<center><h1 style='font-size:xx-large;'>PTT Chatbot</h1><h3>Running on Hugging Face Inference</h3><p>EXPERIMENTAL</p></center>""")
+     with gr.Row():
+         chat = gr.Chatbot(height=500)
+     with gr.Group():
+         with gr.Row():
+             with gr.Column(scale=3):
+                 inp = gr.Textbox(label="Prompt", lines=5, interactive=True)
+         with gr.Row():
+             with gr.Column(scale=2):
+                 btn = gr.Button("Chat")
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     stop_btn = gr.Button("Stop")
+                     clear_btn = gr.Button("Clear")
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     rand = gr.Checkbox(label="Random Seed", value=True)
+                     seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
+                     tokens = gr.Slider(label="Max new tokens", value=3840, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of new tokens to generate")
+                     temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                     top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                     rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
+
+     hid1 = gr.Number(value=1, visible=False)  # Unused placeholder
+
+     go = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [inp, chat, seed, temp, tokens, top_p, rep_p], chat)
+
+     stop_btn.click(None, None, None, cancels=[go])
+     clear_btn.click(clear_fn, None, [inp, chat])
+
+ app.queue(default_concurrency_limit=10).launch(share=True, auth=("admin", "0112358"))
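+ # share=True exposes a public Gradio link; the (username, password) tuple
+ # enables Gradio's built-in basic auth for the app.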