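"""Gradio Space: token, word and character counts per item of a Hugging Face dataset.

Pick a tokenizer and a dataset column; the app plots a histogram of token counts
and reports summary statistics for tokens, words and characters.
"""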
import io
import json
import re
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer
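
# Tokenizers offered in the dropdown; any other Hub model id can be typed in,
# since the Dropdown below has allow_custom_value enabled.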
tokenizers = [
    "google/gemma-7b",
    "meta-llama/Llama-2-7b",
    "mistralai/Mistral-7B-v0.1",
    "facebook/opt-2.7b",
    "microsoft/phi-2",
    "THUDM/chatglm3-6b",
    "Qwen/Qwen1.5-7B-Chat",
    "bigscience/bloom-560m",
    "ise-uiuc/Magicoder-S-DS-6.7B",
    "google/flan-t5-base",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]
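
# Render a histogram of the token counts and return it as a PIL image for Gradio.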
def plot_histogram(data):
    fig = plt.figure()
    plt.hist(data)
    plt.title("Histogram of number of tokens per dataset item")
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)  # close the figure so repeated runs do not draw on top of each other
    buf.seek(0)
    im = Image.open(buf)
    return im
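
# Tokenize every item of the chosen dataset column and collect token, word and
# character counts, returning a histogram image plus a summary-statistics table.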
def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
    tokencounter = []
    wordcounter = []
    charcounter = []
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if config == "":
        config = None  # an empty textbox means "no config" for load_dataset
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)
    pattern = r"[a-zA-Z]+"
    for item in dataset:
        tokens = tokenizer(item[column], add_special_tokens=add_special_tokens)["input_ids"]
        tokencounter.append(len(tokens))
        charcounter.append(len(item[column]))
        # not 100% accurate but good enough
        words = re.findall(pattern, item[column])
        wordcounter.append(len(words))
    # Build one describe() row per count type and stack them into a single table.
    df = pd.DataFrame(tokencounter).describe().T
    df.insert(0, "type", "tokens")
    dfc = pd.DataFrame(charcounter).describe().T
    dfc.insert(0, "type", "chars")
    dfw = pd.DataFrame(wordcounter).describe().T
    dfw.insert(0, "type", "words")
    df.loc[-1] = dfw.values[0]
    df.index = df.index + 1  # shifting index
    df.loc[-1] = dfc.values[0]
    df = df.round(1)
    df.drop("count", axis=1, inplace=True)
    return plot_histogram(tokencounter), df
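
# Wire the counting function into a simple Gradio interface.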
demo = gr.Interface(
    fn=count,
    title="Dataset token counts and distribution",
    inputs=[
        gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
        gr.Textbox(label="Dataset"),
        gr.Textbox(label="Config"),
        gr.Textbox(label="Split"),
        gr.Textbox(label="Column"),
        gr.Checkbox(label="Add special tokens", value=True),
    ],
    outputs=[
        gr.Image(),
        gr.Dataframe(label="Token, word and character counts per dataset item"),
    ],
    examples=[
        ["mistralai/Mistral-7B-v0.1", "gsarti/flores_101", "eng", "dev", "sentence"],
        ["mistralai/Mistral-7B-v0.1", "Muennighoff/flores200", "eng_Latn", "dev", "sentence"],
        ["mistralai/Mistral-7B-v0.1", "wikitext", "wikitext-2-v1", "validation", "text"],
        ["mistralai/Mistral-7B-v0.1", "hails/mmlu_no_train", "elementary_mathematics", "test", "question"],
        ["mistralai/Mistral-7B-v0.1", "imdb", "", "test", "text"],
        ["mistralai/Mistral-7B-v0.1", "gsm8k", "main", "test", "question"],
        ["mistralai/Mistral-7B-v0.1", "locuslab/TOFU", "world_facts", "train", "question"],
    ],
    cache_examples=False,
)
demo.launch()
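
# Run locally with `python app.py`; demo.launch() serves the UI on
# http://127.0.0.1:7860 by default.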