File size: 3,150 Bytes
57b690d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import io
import json
import re

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer

# Tokenizer ids (Hugging Face Hub model ids) offered as choices in the
# "Tokenizer" dropdown below. The dropdown is created with
# allow_custom_value=True, so these are suggestions rather than an exhaustive
# list; entries are shown in the order listed here.
tokenizers = [
    "google/gemma-7b",
    "meta-llama/Llama-2-7b",
    "mistralai/Mistral-7B-v0.1",
    "facebook/opt-2.7b",
    "microsoft/phi-2",
    "THUDM/chatglm3-6b",
    "Qwen/Qwen1.5-7B-Chat",
    "bigscience/bloom-560m",
    "ise-uiuc/Magicoder-S-DS-6.7B",
    "google/flan-t5-base",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]


def plot_histogram(data):
    """Render a histogram of *data* and return it as a PIL image.

    Args:
        data: Sequence of numbers (here: the token count of each dataset item).

    Returns:
        A ``PIL.Image.Image`` containing the rendered PNG.
    """
    # Draw on a dedicated figure rather than pyplot's implicit "current"
    # figure: otherwise every request draws onto the same global figure, so
    # bars from earlier requests accumulate in later plots, and the figure is
    # never released.
    fig, ax = plt.subplots()
    try:
        ax.hist(data)
        ax.set_title("Histogram of number of tokens per dataset item")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    buf.seek(0)
    im = Image.open(buf)
    # Image.open is lazy; decode the pixels now so the returned image is
    # self-contained.
    im.load()
    return im


def _summarize(values, label):
    """Return a one-row ``describe()`` summary of *values*, tagged with *label* in a leading "type" column."""
    # float64 dtype keeps describe() numeric (and well-defined) even for an
    # empty list; DataFrame([]).describe() would raise instead.
    stats = pd.Series(values, dtype="float64").describe().to_frame().T
    stats.insert(0, "type", label)
    return stats


def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
    """Tokenize one text column of a dataset split and summarize its lengths.

    Args:
        model_id: Hub id of the tokenizer to load with ``AutoTokenizer``.
        dataset_id: Hub id of the dataset to load with ``load_dataset``.
        config: Dataset config name; an empty string (or None) selects the
            dataset's default config.
        split: Name of the split to load (e.g. "test").
        column: Name of the text column to measure.
        add_special_tokens: Forwarded to the tokenizer call.

    Returns:
        A tuple ``(histogram_image, stats_df)``. ``histogram_image`` is the
        histogram of per-item token counts. ``stats_df`` has one row each for
        "tokens", "words" and "chars" with ``describe()`` statistics rounded
        to one decimal (the "count" column is dropped).
    """
    # An empty "Config" textbox means "use the default config". The original
    # `config is None` was a no-op comparison, so "" was passed through.
    if not config:
        config = None

    # SECURITY NOTE(review): model_id and dataset_id come straight from the
    # UI, and trust_remote_code=True lets the Hub repo run arbitrary Python on
    # this server. Kept because some listed models (e.g. chatglm3) need it.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)

    # Runs of ASCII letters: not 100% accurate as a word count but good enough.
    word_re = re.compile(r"[a-zA-Z]+")
    tokencounter, wordcounter, charcounter = [], [], []
    for item in dataset:
        text = item[column]
        tokens = tokenizer(text, add_special_tokens=add_special_tokens)["input_ids"]
        tokencounter.append(len(tokens))
        charcounter.append(len(text))
        wordcounter.append(len(word_re.findall(text)))

    df = pd.concat(
        [
            _summarize(tokencounter, "tokens"),
            _summarize(wordcounter, "words"),
            _summarize(charcounter, "chars"),
        ],
        ignore_index=True,
    )
    # Every row's "count" equals the number of items, so it adds nothing.
    df = df.round(1).drop(columns="count")

    return plot_histogram(tokencounter), df


# Gradio UI. The inputs map positionally onto count()'s parameters:
# (model_id, dataset_id, config, split, column, add_special_tokens).
# The outputs map onto count()'s return tuple: (histogram image, stats table).
demo = gr.Interface(
    fn=count,
    title="Dataset token counts and distribution",
    inputs=[
        gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
        gr.Textbox(label="Dataset"),
        gr.Textbox(label="Config"),
        gr.Textbox(label="Split"),
        gr.Textbox(label="Column"),
        gr.Checkbox(label="Add special tokens", value=True),
    ],
    outputs=[
        gr.Image(),
        gr.Dataframe(label="Token, word and character counts per dataset item"),
    ],
    # Each example supplies a value for every input, including the
    # "Add special tokens" checkbox, so clicking one never leaves it unset.
    # An empty config string selects the dataset's default config.
    examples=[
        ["mistralai/Mistral-7B-v0.1", "gsarti/flores_101", "eng", "dev", "sentence", True],
        ["mistralai/Mistral-7B-v0.1", "Muennighoff/flores200", "eng_Latn", "dev", "sentence", True],
        ["mistralai/Mistral-7B-v0.1", "wikitext", "wikitext-2-v1", "validation", "text", True],
        ["mistralai/Mistral-7B-v0.1", "hails/mmlu_no_train", "elementary_mathematics", "test", "question", True],
        ["mistralai/Mistral-7B-v0.1", "imdb", "", "test", "text", True],
        ["mistralai/Mistral-7B-v0.1", "gsm8k", "main", "test", "question", True],
        ["mistralai/Mistral-7B-v0.1", "locuslab/TOFU", "world_facts", "train", "question", True],
    ],
    cache_examples=False,
)

# Launch only when run as a script, so importing this module (tests, reload
# tooling) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()