awacke1's picture
Update app.py
24be11a
raw
history blame contribute delete
No virus
2.38 kB
import streamlit as st
import altair as alt
import torch
from transformers import AlbertTokenizer, AlbertForSequenceClassification
import sentencepiece as spm
import pandas as pd
# Load pre-trained model and tokenizer.
# Streamlit re-executes this whole script on every widget interaction, so the
# loading is wrapped in st.cache_resource (when this Streamlit version has it)
# to load the weights once per server process instead of on every rerun.
model_name = "albert-base-v2"

# Fall back to a no-op decorator on Streamlit versions without cache_resource.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model_and_tokenizer(name):
    """Return ``(tokenizer, model)`` loaded from the pretrained checkpoint *name*."""
    loaded_tokenizer = AlbertTokenizer.from_pretrained(name)
    loaded_model = AlbertForSequenceClassification.from_pretrained(name)
    # Inference only: make sure dropout is disabled.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


# NOTE(review): "albert-base-v2" looks like a base (not fine-tuned) checkpoint,
# in which case the sequence-classification head is freshly initialized and the
# Negative/Positive scores shown by the app are not meaningful — confirm, or
# swap in a checkpoint fine-tuned for sentiment.
tokenizer, model = _load_model_and_tokenizer(model_name)
# Define function to classify input text
def classify_text(text, *, clf_model=None, clf_tokenizer=None):
    """Return the class probabilities the model assigns to *text*.

    Args:
        text: The string to classify.
        clf_model: Optional sequence-classification model to use instead of
            the module-level ``model`` (e.g. for tests or another checkpoint).
        clf_tokenizer: Optional tokenizer to use instead of the module-level
            ``tokenizer``.

    Returns:
        A list of Python floats, one per output label, obtained by applying
        softmax to the logits of the first (and only) item in the batch.
        The values sum to (approximately) 1.
    """
    if clf_model is None:
        clf_model = model
    if clf_tokenizer is None:
        clf_tokenizer = tokenizer
    inputs = clf_tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = clf_model(**inputs).logits[0]
    # Softmax directly on the tensor; no NumPy round-trip is needed.
    return torch.softmax(logits, dim=0).tolist()
# Set up Streamlit app
st.title("ALBERT Text Classification App")

# Display names for the model's output logits, in index order.
# NOTE(review): assumes a two-label head with index 0 = negative and
# index 1 = positive — confirm against the checkpoint's label config.
LABELS = ['Negative', 'Positive']

# Create input box for user to enter text
default_text = "Streamlit-Altair: A component that allows the creation of Altair visualizations within Streamlit.\nStreamlit-Bokeh: A component that allows the creation of Bokeh visualizations within Streamlit.\nStreamlit-Plotly: A component that allows the creation of Plotly visualizations within Streamlit.\nStreamlit-Mapbox: A component that allows the creation of Mapbox maps within Streamlit.\nStreamlit-DeckGL: A component that allows the creation of Deck.GL visualizations within Streamlit.\nStreamlit-Wordcloud: A component that allows the creation of word clouds within Streamlit.\nStreamlit-Audio: A component that allows the playing of audio files within Streamlit.\nStreamlit-Video: A component that allows the playing of video files within Streamlit.\nStreamlit-EmbedCode: A component that allows the embedding of code snippets within Streamlit.\nStreamlit-Components: A component that provides a library of custom Streamlit components created by the Streamlit community."
text_input = st.text_area("Enter text to classify", default_text, height=200)

# Classify input text and display results
if st.button("Classify"):
    # Treat whitespace-only input the same as empty input instead of
    # sending it to the model.
    if text_input.strip():
        probabilities = classify_text(text_input)
        df = pd.DataFrame({
            'Label': LABELS,
            'Probability': probabilities,
        })
        chart = alt.Chart(df).mark_bar().encode(
            # Pin the axis to [0, 1] so bar lengths are comparable across runs.
            x=alt.X('Probability', scale=alt.Scale(domain=[0, 1])),
            y=alt.Y('Label', sort=LABELS),
        )
        st.altair_chart(chart)
    else:
        st.write("Please enter some text to classify.")