# File-viewer header residue: Matt · "Bugfix" · 5005844 · 2.2 kB
import gradio as gr
from transformers import AutoTokenizer
import json
# Checkpoint whose tokenizer is used to render the user's templates.
TOKENIZER_CHECKPOINT = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
# Example conversation (JSON text) pre-filled into the "Conversation" box;
# apply_chat_template parses it with json.loads before rendering.
demo_conversation = """[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hi there!"},
{"role": "assistant", "content": "Hello, human!"},
{"role": "user", "content": "Can I ask a question?"}
]"""
# Starter template pre-filled into the "Chat Template" box. It emits each
# message as <|im_start|>ROLE\nCONTENT<|im_end|>\n and, when
# add_generation_prompt is set, a trailing "<|im_start|>assistant\n" header.
# The Python "\\n" escapes leave the two characters \n inside the Jinja string
# literals (presumably turned into a real newline by Jinja at render time).
default_template = """{% for message in messages %}
{{ "<|im_start|>" + message["role"] + "\\n" + message["content"] + "<|im_end|>\\n" }}
{% endfor %}
{% if add_generation_prompt %}
{{ "<|im_start|>assistant\\n" }}
{% endif %}"""
# Markdown-formatted header text passed as the Gradio interface description.
description_text = """### This space is a helper app for writing [Chat Templates](https://huggingface.co/docs/transformers/main/en/chat_templating).
### When you're happy with the outputs from your template, you can use the code block at the end to add it to a PR!"""
def apply_chat_template(template, test_conversation, add_generation_prompt, cleanup_whitespace):
    """Render a test conversation with a candidate chat template.

    Args:
        template: Jinja chat-template source typed by the user.
        test_conversation: JSON text for the conversation to render
            (presumably a list of {"role": ..., "content": ...} dicts, like
            the demo conversation).
        add_generation_prompt: Forwarded unchanged to
            ``tokenizer.apply_chat_template``.
        cleanup_whitespace: If true, strip leading/trailing whitespace from
            every template line and join the lines with no separator, so the
            template's own line layout does not leak into the output.

    Returns:
        A ``(formatted, pr_snippet)`` tuple: the rendered conversation string,
        and Python code (as text) that sets this template on a tokenizer and
        pushes it to the Hub as a PR.

    Raises:
        json.JSONDecodeError: if ``test_conversation`` is not valid JSON.
        Errors raised while rendering the template propagate unchanged.
    """
    if cleanup_whitespace:
        template = "".join(line.strip() for line in template.split("\n"))
    # NOTE: mutates the shared module-level tokenizer; the render below
    # relies on it picking up this template.
    tokenizer.chat_template = template
    conversation = json.loads(test_conversation)
    # The snippet is text for the user to copy, so every line, including the
    # push_to_hub call, must be a string here rather than executed code.
    # repr() yields a valid Python literal even when the template contains
    # quotes, backslashes or newlines; a hand-built "..." literal does not.
    pr_snippet = "\n".join((
        "tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)",
        f"tokenizer.chat_template = {template!r}",
        "tokenizer.push_to_hub(CHECKPOINT, create_pr=True)",
    ))
    formatted = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=add_generation_prompt
    )
    return formatted, pr_snippet
# UI wiring: the four inputs map positionally onto apply_chat_template's
# parameters, and its (formatted, pr_snippet) result fills the two outputs.
input_components = [
    gr.TextArea(value=default_template, lines=10, max_lines=30, label="Chat Template"),
    gr.TextArea(value=demo_conversation, lines=6, label="Conversation"),
    gr.Checkbox(value=False, label="Add generation prompt"),
    gr.Checkbox(value=True, label="Cleanup template whitespace"),
]
output_components = [
    gr.TextArea(label="Formatted conversation"),
    gr.TextArea(label="Code snippet to create PR", lines=3, show_copy_button=True),
]
iface = gr.Interface(
    fn=apply_chat_template,
    inputs=input_components,
    outputs=output_components,
    description=description_text,
)
# Unguarded: the app launches whenever this module is executed or imported.
iface.launch()