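"""
Streamlit demo for multiclass emotion classification of English text with a
DeepMind Language Perceiver model, using the GoEmotions label set
(27 emotions plus neutral). Model weights are downloaded from Google Drive on
first run (the Drive file id is read from the Streamlit secrets) and cached
under model/.
"""
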
import os

import gdown
import nltk
import streamlit as st
from nltk.tokenize import sent_tokenize

from source.pipeline import MultiLabelPipeline, inputs_to_dataset


def download_models(ids):
    """
    Download all models that are not already stored locally.
    :param ids: mapping of model names to their Google Drive file ids
    :return:
    """

    # Download the NLTK sentence tokenizer used for sentence splitting
    nltk.download('punkt')

    # Make sure the local model directory exists before downloading into it
    os.makedirs("model", exist_ok=True)

    # Download each model from Google Drive if it is not stored locally
    for key in ids:
        if not os.path.isfile(f"model/{key}.pt"):
            url = f"https://drive.google.com/uc?id={ids[key]}"
            gdown.download(url=url, output=f"model/{key}.pt")


@st.cache
def load_labels():
    """
    Load the model's output labels.
    :return: list of emotion labels (the GoEmotions label set)
    """

    return [
        "admiration",
        "amusement",
        "anger",
        "annoyance",
        "approval",
        "caring",
        "confusion",
        "curiosity",
        "desire",
        "disappointment",
        "disapproval",
        "disgust",
        "embarrassment",
        "excitement",
        "fear",
        "gratitude",
        "grief",
        "joy",
        "love",
        "nervousness",
        "optimism",
        "pride",
        "realization",
        "relief",
        "remorse",
        "sadness",
        "surprise",
        "neutral"
    ]


@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Load model and cache it.
    :param model_path: path to the model checkpoint (.pt file)
    :return: loaded MultiLabelPipeline
    """

    model = MultiLabelPipeline(model_path=model_path)

    return model


# Page config
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

maintenance = False
if maintenance:
    st.write("Unavailable for now (file downloads limit). ")
else:
    # Variables: model name -> Google Drive file id (read from Streamlit secrets)
    ids = {'perceiver-go-emotions': st.secrets['model']}
    labels = load_labels()

    # Download all models from drive
    download_models(ids)

    # Display labels
    st.markdown(f"__Labels:__ {', '.join(labels)}")

    # Text input and model selection
    left, right = st.columns([4, 2])
    inputs = left.text_area('', max_chars=4096, value='This is a space about multiclass emotion classification. Write '
                                                      'something here to see what happens!')
    model_path = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
    split = right.checkbox('Split into sentences', value=True)
    model = load_model(model_path=f"model/{model_path}.pt")
    right.write(model.device)  # show the device (CPU/GPU) the model runs on

    # Run inference once there is non-empty input, optionally sentence by sentence
    if not inputs.isspace() and inputs != "":
        with st.spinner('Processing text... This may take a while.'):
            if split:
                left.write(model(inputs_to_dataset(sent_tokenize(inputs)), batch_size=1))
            else:
                left.write(model(inputs_to_dataset([inputs]), batch_size=1))