# Spaces: Sleeping — non-code page header captured along with this source.
from html import unescape | |
from unicodedata import normalize | |
import gradio as gr | |
from transformers import pipeline, AutoModel | |
import re | |
# Pattern matching any run of (Unicode) whitespace characters.
re_multispace = re.compile(r"\s+")

# Task display name -> suffix of the model id "hynky/<suffix>" that serves it.
# Insertion order matters: predict() and the UI outputs both iterate these keys
# in this order.
model_task_mapping = dict(
    [
        ("Server", "Server"),
        ("Category", "Category"),
        ("Gender", "Gender"),
        ("Day Of Week", "Day_of_week"),
    ]
)
def normalize_text(text):
    """Return a cleaned-up copy of ``text`` for the classifiers.

    HTML character references are decoded, the result is NFKC-normalized,
    and finally every run of whitespace is collapsed to a single space with
    leading/trailing whitespace removed.

    Args:
        text: Raw article text, or None.

    Returns:
        The normalized string, or None when ``text`` is None.
    """
    if text is None:
        return None
    # Decode entities and apply NFKC *before* handling whitespace: both steps
    # can produce new whitespace (e.g. "&nbsp;" -> "\xa0" -> " "), which must
    # also be collapsed and stripped.
    text = normalize("NFKC", unescape(text))
    # str.split() with no separator splits on runs of any Unicode whitespace
    # (including \n, \t, \r) and drops empty fields, so joining with a single
    # space both collapses interior runs and strips the ends.
    return " ".join(text.split())
def _load_classifier(repo_suffix):
    """Build the text-classification pipeline for model ``hynky/<repo_suffix>``.

    All classifiers use the ``ufal/robeczech-base`` tokenizer; ``truncation``
    and ``max_length=512`` are forwarded as tokenizer options, and ``top_k=5``
    asks for up to five labels per input.
    """
    return pipeline(
        task="text-classification",
        model=f"hynky/{repo_suffix}",
        tokenizer="ufal/robeczech-base",
        truncation=True,
        max_length=512,
        top_k=5,
    )


# One classifier per task, keyed by the task's display name. Every model is
# loaded eagerly when this module runs.
pipelines = {
    task_name: _load_classifier(repo_suffix)
    for task_name, repo_suffix in model_task_mapping.items()
}
def predict(article):
    """Classify ``article`` with every task's pipeline.

    Args:
        article: Raw text from the input textbox.

    Returns:
        A list with one ``{label: score}`` dict per task, ordered like the
        keys of ``model_task_mapping``; scores are rounded to 3 decimals.
    """
    cleaned = normalize_text(article)
    results = []
    for task_name in model_task_mapping:
        # The first element of the pipeline output is treated as an iterable
        # of {"label": ..., "score": ...} dicts (the top-k predictions).
        top_predictions = pipelines[task_name](cleaned)[0]
        scores = {}
        for prediction in top_predictions:
            scores[prediction["label"]] = round(prediction["score"], 3)
        results.append(scores)
    return results
# One Label output per task, iterated in the same key order that predict()
# uses, so each returned dict lands in the matching component.
task_outputs = [
    gr.Label(num_top_classes=5, label=task_name) for task_name in model_task_mapping
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, placeholder="Paste a news article here..."),
    outputs=task_outputs,
    title="News Article Classifier",
)
demo.launch()