Spaces:
No application file
No application file
import torch | |
import streamlit as st | |
import io | |
import imageio | |
from PIL import Image | |
import torch.nn as nn | |
import time | |
from task2 import predict_class | |
from task3 import generate_text | |
import tempfile | |
import os | |
from transformers import pipeline | |
# Замените "ваш-пользователь/ваш-новый-репозиторий" на ваш новый путь на Hugging Face | |
model_path = "HaggiVaggi/nlp_project" | |
generator = pipeline('text-generation', model=model_path) | |
st.title('Обработка естественного языка • Natural Language Processing') | |
with st.sidebar: | |
st.header('Выберите страницу') | |
page = st.selectbox("Выберите страницу", ["Главная", "Отзывы на рестораны",\ | |
"Тематика новостей", "GPT by GPT-team", "Итоги"]) | |
if page == "Главная": | |
st.header('Выполнила команда "GPT":') | |
st.subheader('🦁Рома') | |
st.subheader('🐯Руслан') | |
st.subheader('🐱Тата') | |
st.header(" 🌟 " * 10) | |
st.header('Наши задачи:') | |
st.subheader('*Задача №1*: Классификация отзыва на рестораны') | |
st.subheader('*Задача №2*: Классификация тематики новостей из телеграм каналов') | |
st.subheader('*Задача №2*: Генерация текста GPT-моделью по пользовательскому prompt') | |
elif page == "Отзывы на рестораны": | |
st.header("Отзывы на рестораны:") | |
elif page == "Тематика новостей": | |
st.header("Тематика новостей:") | |
st.markdown(f"<span style='font-size:{30}px; color:purple'>{'Модель: DeepPavlov/rubert-base-cased'}</span>", unsafe_allow_html=True) | |
st.info('Модель основана на архитектуре BERT (Bidirectional Encoder Representations from Transformers), представленной в [статье](https://arxiv.org/abs/1810.04805)') | |
st.info('Rubert-base-cased: "cased" означает, что в этой модели сохранен регистр слов. Это важно для русского языка, где регистр может влиять на смысл слов.') | |
st.info('В библиотеке [Transformers от Hugging Face](https://huggingface.co/DeepPavlov/rubert-base-cased), слой классификации представляется в виде BertForSequenceClassification. Этот классификатор добавляется к основной модели BERT и обучается на конкретной задаче классификации текста.') | |
user_input = st.text_area('Введите текст поста и мы узнаем, к какой тематике его отнести:') | |
if st.button("Предсказать"): | |
pred = predict_class(user_input) | |
st.subheader("Это текст по теме:" ) | |
st.markdown(f'<span style="font-size:{25}px; color:pink">{pred}</span>', unsafe_allow_html=True) | |
st.subheader("Accuracy и Loss на 5 эпохах" ) | |
image_1 = imageio.imread('pictures/im1.png')[:, :, :] | |
st.image(image_1) | |
elif page == "GPT by GPT-team": | |
st.header("GPT by GPT-team:") | |
st.markdown(f"<span style='font-size:{30}px; color:green'>{'Модель: GPT2LMHeadModel'}</span>", unsafe_allow_html=True) | |
st.info('[GPT2LMHeadModel](https://huggingface.co/docs/transformers/model_doc/gpt2) - это модель, способная генерировать текст, учитывая предшествующий контекст.') | |
st.info('[Sberbank-ai/rugpt3small_based_on_gpt2](https://huggingface.co/ai-forever/rugpt3small_based_on_gpt2): Это конкретная предобученная модель GPT-2, которая была дообучена на русском языке командой Sber AI.\ | |
Она обладает способностью генерировать текст, принимая на вход текстовый контекст.') | |
user_input2 = st.text_area("Введите текст:", "") | |
if st.button("Сгенерировать"): | |
generated = generate_text(user_input2) | |
st.subheader("Сгенерированный текст:") | |
st.markdown(f'<span style="font-size:{25}px; color:green">{generated}</span>', unsafe_allow_html=True) | |
# st.subheader("- Модель: *ConvAutoencoder()*") | |
# st.subheader("- Количество эпох обучения: *100*") | |
# st.info('Расширение картинки должно быть в формате .jpg /.jpeg /.png') | |
# image_url2 = st.text_input("Введите URL изображения") | |
# start_time2 = time.time() | |
# if image_url2: | |
# # Загрузка изображения по ссылке | |
# response2 = requests.get(image_url2) | |
# image2 = Image.open(io.BytesIO(response2.content)) | |
# st.subheader('Ваше фото до обработки:') | |
# st.image(image2) | |
# prediction_result = predict_1(image2) | |
# show_result_button3 = st.button("Показать результат", key="result_button_3") | |
# if show_result_button3: | |
# st.success("Ваш результат готов!") | |
# st.subheader("Ваше фото после обработки:") | |
# st.image(prediction_result, channels='GRAY') | |
# st.subheader(f'Время предсказания: {round((time.time() - start_time2), 2)} сек.') | |
# st.header('🎈' * 10) | |
elif page == "Итоги": | |
st.header('Результаты и выводы') | |
# st.subheader('*Задача №1*: Детектирование ветряных мельниц') | |
# st.subheader("Метрики из Clear ML") | |
# image_1 = Image.open("pictures/P_curve.png") | |
# image_2 = Image.open("pictures/PR_curve.png") | |
# image_3 = Image.open("pictures/R_curve.png") | |
# image_4 = Image.open("pictures/F1_curve.png") | |
# # Отображаем изображения в одной строке | |
# st.image([image_1, image_2, image_3, image_4], caption=['Image 1 - P_curve', 'Image 2 - PR_curve', 'Image 3 - R_curve', 'Image 4 - F1_curve'], width=300) | |
# st.subheader("Результативные графики из Clear ML") | |
# image_5 = imageio.imread('pictures/plots.jpg')[:, :, :] | |
# st.image(image_5) | |