niknikita commited on
Commit
238ca26
1 Parent(s): d02d7c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -38
app.py CHANGED
@@ -10,6 +10,8 @@ from transformers import pipeline
10
  from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
11
  from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
12
  import streamlit as st
 
 
13
 
14
  import pandas as pd
15
  import json
@@ -32,22 +34,6 @@ def get_top95(y_predict, convert_target):
32
  if cumsum > 0.95:
33
  break
34
  return lst_labels
35
- #
36
- # Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model.
37
- from transformers import DistilBertModel, DistilBertTokenizer
38
-
39
-
40
- # model.load_state_dict(checkpoint['model'])
41
- # optimizer.load_state_dict(checkpoint['opt'])
42
- # model.to("cpu")
43
-
44
- # print(model)
45
- # model = DistilBertForSequenceClassification.from_pretrained("model/distilbert-model1.pt", local_files_only=True)
46
- # tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-pegasus-large-arxiv')
47
-
48
- # model = BigBirdPegasusForSequenceClassification.from_pretrained('google/bigbird-pegasus-large-arxiv',
49
- # num_labels=8,
50
- # return_dict=False)
51
 
52
 
53
 
@@ -82,16 +68,6 @@ model = torch.load("pytorch_distilbert_news (4).bin", map_location=torch.device(
82
 
83
  def get_predict(title, abstract):
84
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
85
- # encoded_dict = tokenizer.encode_plus(
86
- # text, # document to encode.
87
- # add_special_tokens=True, # add '[CLS]' and '[SEP]'
88
- # max_length=512, # set max length
89
- # truncation=True, # truncate longer messages
90
- # pad_to_max_length=True, # add padding
91
- # return_attention_mask=True, # create attn. masks
92
- # return_tensors='pt' # return pytorch tensors
93
- # )
94
-
95
  inputs = tokenizer(title, abstract, return_tensors="pt")
96
  outputs = model(
97
  input_ids=inputs['input_ids'],
@@ -105,21 +81,13 @@ def get_predict(title, abstract):
105
  with open(file_path, 'r') as json_file:
106
  decode_target = json.load(json_file)
107
  return get_top95(y_predict, decode_target)
108
- #
109
- #
110
- #
111
- #
112
- #
113
- # get_predict('''physics physics physics physics physics
114
- # physics physics physics physics''')
115
- #
116
 
117
- st.markdown("### Hello, world!")
118
- st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
119
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
120
 
121
- title = st.text_area("TEXT HERE", key=1)
122
- abstract = st.text_area("TEXT HERE", key=2)
123
 
124
  # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
125
 
 
10
  from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
11
  from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
12
  import streamlit as st
13
+ from transformers import DistilBertModel, DistilBertTokenizer
14
+
15
 
16
  import pandas as pd
17
  import json
 
34
  if cumsum > 0.95:
35
  break
36
  return lst_labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
 
 
68
 
69
  def get_predict(title, abstract):
70
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
 
 
 
 
 
 
 
 
 
 
71
  inputs = tokenizer(title, abstract, return_tensors="pt")
72
  outputs = model(
73
  input_ids=inputs['input_ids'],
 
81
  with open(file_path, 'r') as json_file:
82
  decode_target = json.load(json_file)
83
  return get_top95(y_predict, decode_target)
 
 
 
 
 
 
 
 
84
 
85
+
86
+ st.markdown("Классификатор статей")
87
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
88
 
89
+ title = st.text_area("Title", key=1)
90
+ abstract = st.text_area("Abstract", key=2)
91
 
92
  # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
93