Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,8 @@ from transformers import pipeline
|
|
10 |
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
|
11 |
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
12 |
import streamlit as st
|
|
|
|
|
13 |
|
14 |
import pandas as pd
|
15 |
import json
|
@@ -32,22 +34,6 @@ def get_top95(y_predict, convert_target):
|
|
32 |
if cumsum > 0.95:
|
33 |
break
|
34 |
return lst_labels
|
35 |
-
#
|
36 |
-
# Create the customized model by adding a dropout layer and a dense layer on top of DistilBERT to produce the model's final output.
|
37 |
-
from transformers import DistilBertModel, DistilBertTokenizer
|
38 |
-
|
39 |
-
|
40 |
-
# model.load_state_dict(checkpoint['model'])
|
41 |
-
# optimizer.load_state_dict(checkpoint['opt'])
|
42 |
-
# model.to("cpu")
|
43 |
-
|
44 |
-
# print(model)
|
45 |
-
# model = DistilBertForSequenceClassification.from_pretrained("model/distilbert-model1.pt", local_files_only=True)
|
46 |
-
# tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-pegasus-large-arxiv')
|
47 |
-
|
48 |
-
# model = BigBirdPegasusForSequenceClassification.from_pretrained('google/bigbird-pegasus-large-arxiv',
|
49 |
-
# num_labels=8,
|
50 |
-
# return_dict=False)
|
51 |
|
52 |
|
53 |
|
@@ -82,16 +68,6 @@ model = torch.load("pytorch_distilbert_news (4).bin", map_location=torch.device(
|
|
82 |
|
83 |
def get_predict(title, abstract):
|
84 |
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
85 |
-
# encoded_dict = tokenizer.encode_plus(
|
86 |
-
# text, # document to encode.
|
87 |
-
# add_special_tokens=True, # add '[CLS]' and '[SEP]'
|
88 |
-
# max_length=512, # set max length
|
89 |
-
# truncation=True, # truncate longer messages
|
90 |
-
# pad_to_max_length=True, # add padding
|
91 |
-
# return_attention_mask=True, # create attn. masks
|
92 |
-
# return_tensors='pt' # return pytorch tensors
|
93 |
-
# )
|
94 |
-
|
95 |
inputs = tokenizer(title, abstract, return_tensors="pt")
|
96 |
outputs = model(
|
97 |
input_ids=inputs['input_ids'],
|
@@ -105,21 +81,13 @@ def get_predict(title, abstract):
|
|
105 |
with open(file_path, 'r') as json_file:
|
106 |
decode_target = json.load(json_file)
|
107 |
return get_top95(y_predict, decode_target)
|
108 |
-
#
|
109 |
-
#
|
110 |
-
#
|
111 |
-
#
|
112 |
-
#
|
113 |
-
# get_predict('''physics physics physics physics physics
|
114 |
-
# physics physics physics physics''')
|
115 |
-
#
|
116 |
|
117 |
-
|
118 |
-
st.markdown("
|
119 |
# ^-- you can show the user text, images, and a limited subset of HTML - just like in Jupyter
|
120 |
|
121 |
-
title = st.text_area("
|
122 |
-
abstract = st.text_area("
|
123 |
|
124 |
# ^-- show text areas. Each returned value (title, abstract) is the string currently in that field
|
125 |
|
|
|
10 |
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
|
11 |
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
12 |
import streamlit as st
|
13 |
+
from transformers import DistilBertModel, DistilBertTokenizer
|
14 |
+
|
15 |
|
16 |
import pandas as pd
|
17 |
import json
|
|
|
34 |
if cumsum > 0.95:
|
35 |
break
|
36 |
return lst_labels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
|
39 |
|
|
|
68 |
|
69 |
def get_predict(title, abstract):
|
70 |
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
inputs = tokenizer(title, abstract, return_tensors="pt")
|
72 |
outputs = model(
|
73 |
input_ids=inputs['input_ids'],
|
|
|
81 |
with open(file_path, 'r') as json_file:
|
82 |
decode_target = json.load(json_file)
|
83 |
return get_top95(y_predict, decode_target)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
|
86 |
+
st.markdown("Классификатор статей")
|
87 |
# ^-- you can show the user text, images, and a limited subset of HTML - just like in Jupyter
|
88 |
|
89 |
+
title = st.text_area("Title", key=1)
|
90 |
+
abstract = st.text_area("Abstract", key=2)
|
91 |
|
92 |
# ^-- show text areas. Each returned value (title, abstract) is the string currently in that field
|
93 |
|