Spaces:
Runtime error
Runtime error
File size: 2,999 Bytes
bbc9b75 dfcdd52 bbc9b75 6c6aac8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
def generate_text(model, tokenizer, prompt, max_length, num_generations, temperature):
generated_texts = []
for _ in range(num_generations):
input_ids = tokenizer.encode(prompt, return_tensors="pt")
output = model.generate(
input_ids,
max_length=max_length,
temperature=temperature,
num_return_sequences=1
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
generated_texts.append(generated_text)
return generated_texts
button_style = """
<style>
.center-align {
display: flex;
justify-content: center;
</style>
"""
DEVICE = 'cpu'
tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2"
model = torch.load('srcs/gpt_weights.pth').to(DEVICE)
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
st.markdown("""
<style>
section[data-testid="stSidebar"][aria-expanded="true"]{
display: none;
}
</style>
""", unsafe_allow_html=True)
st.write("## Text generator")
st.page_link("app.py", label="Home", icon='🏠')
st.markdown(
"""
This streamlit-app can generate text using your prompt
"""
)
# Ввод пользовательского prompt
prompt = st.text_area("Enter your prompt:")
# Параметры генерации
max_length = st.slider("Max length of generated text:", min_value=10, max_value=500, value=100, step=10)
num_generations = st.slider("Number of generations:", min_value=1, max_value=10, value=3, step=1)
temperature = st.slider("Temperature:", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
try:
if st.button("Generate text"):
start_time = time.time()
generated_texts = generate_text(model, tokenizer, prompt, max_length, num_generations, temperature)
end_time = time.time()
st.subheader("Сгенерированный текст:")
for i, text in enumerate(generated_texts, start=1):
st.write(f"Генерация {i}:\n{text}")
generation_time = end_time - start_time
st.write(f"\nВремя генерации: {generation_time:.2f} секунд")
st.markdown(button_style, unsafe_allow_html=True) # Применяем стиль к кнопке
st.markdown(
"""
<style>
div[data-baseweb="textarea"] {
border: 2px solid #3498db; /* Цвет границы */
border-radius: 5px; /* Закругленные углы */
background-color: #ecf0f1; /* Цвет фона */
padding: 10px; /* Поля вокруг текстового поля */
}
</style>
""",
unsafe_allow_html=True,
)
except:
st.write('Модель в разработке ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ') |