File size: 1,397 Bytes
9746697
 
 
861d5e9
9746697
 
 
861d5e9
 
 
 
 
 
9746697
 
861d5e9
9746697
 
 
 
 
 
 
 
 
 
 
 
861d5e9
 
2de65f4
 
861d5e9
9746697
2de65f4
9746697
 
861d5e9
2de65f4
861d5e9
9746697
 
 
861d5e9
9746697
861d5e9
 
9746697
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from html import unescape
from unicodedata import normalize
import gradio as gr
from transformers import pipeline, AutoModel
import re

# Matches any run of one or more whitespace characters.
re_multispace = re.compile(r"\s+")
# Maps each task name to the name of the model that handles it.
# The key identifies the task (it is the key of the `pipelines` dict and the
# label of the corresponding output component); the value is interpolated
# into the "hynky/<value>" model identifier when the pipelines are created.
model_task_mapping = {
    "Server": "Server",
    "Category": "Category",
    "Gender": "Gender",
    "Day Of Week": "Day_of_week"
}

def normalize_text(text):
    """Return *text* cleaned up for classification, or None if *text* is None.

    The cleanup decodes HTML character references, applies Unicode NFKC
    normalization, collapses every run of whitespace into a single space and
    removes leading/trailing whitespace.

    Decoding and NFKC normalization run *before* the whitespace pass: both
    steps can introduce whitespace of their own (e.g. ``&nbsp;`` or ``&#10;``
    decode to whitespace, and NFKC turns a no-break space into a plain
    space), which would otherwise survive as doubled or dangling spaces.
    """
    if text is None:
        return None

    text = unescape(text)
    text = normalize("NFKC", text)
    # str.split() with no arguments splits on any whitespace run and drops
    # empty leading/trailing pieces, so joining with " " both collapses and
    # trims in one pass.
    return " ".join(text.split())


# Options that are identical for every task's pipeline.
_shared_pipeline_kwargs = dict(
    task="text-classification",
    tokenizer="ufal/robeczech-base",
    truncation=True,
    max_length=512,
    top_k=5,
)

# One text-classification pipeline per entry of model_task_mapping, keyed by
# the task name; the mapped value selects the "hynky/<name>" model.
pipelines = {
    task_name: pipeline(model=f"hynky/{model_name}", **_shared_pipeline_kwargs)
    for task_name, model_name in model_task_mapping.items()
}


def predict(article):
    """Classify *article* with every configured pipeline.

    The text is first passed through ``normalize_text``. The result is a
    list with one dict per task, in the iteration order of
    ``model_task_mapping``; each dict maps a predicted label to its score
    rounded to three decimal places.
    """
    cleaned = normalize_text(article)
    results = []
    for task_name in model_task_mapping:
        # Element 0 of the pipeline output is the sequence of label/score
        # entries for this (single) input text.
        ranked = pipelines[task_name](cleaned)[0]
        results.append(
            {entry["label"]: round(entry["score"], 3) for entry in ranked}
        )
    return results

# Build the web UI: a single text box in, one label component per task out
# (in the same order in which predict() returns its per-task results).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, placeholder="Paste a news article here..."),
    outputs=[
        gr.Label(num_top_classes=5, label=task_name)
        for task_name in model_task_mapping
    ],
    title="News Article Classifier",
)
demo.launch()