niknikita commited on
Commit
b58d194
1 Parent(s): 7895958

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -1
app.py CHANGED
@@ -37,7 +37,6 @@ def get_top95(y_predict, convert_target):
37
  from transformers import DistilBertModel, DistilBertTokenizer
38
 
39
 
40
- model = torch.load("pytorch_distilbert_news (3).bin", map_location=torch.device('cpu'))
41
  # model.load_state_dict(checkpoint['model'])
42
  # optimizer.load_state_dict(checkpoint['opt'])
43
  # model.to("cpu")
@@ -50,6 +49,37 @@ model = torch.load("pytorch_distilbert_news (3).bin", map_location=torch.device(
50
  # num_labels=8,
51
  # return_dict=False)
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def get_predict(title, abstract):
54
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
55
  # encoded_dict = tokenizer.encode_plus(
 
37
  from transformers import DistilBertModel, DistilBertTokenizer
38
 
39
 
 
40
  # model.load_state_dict(checkpoint['model'])
41
  # optimizer.load_state_dict(checkpoint['opt'])
42
  # model.to("cpu")
 
49
  # num_labels=8,
50
  # return_dict=False)
51
 
52
+
53
+
54
+ class DistillBERTClass(torch.nn.Module):
55
+ def __init__(self):
56
+ super(DistillBERTClass, self).__init__()
57
+ self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
58
+ self.pre_classifier = torch.nn.Linear(768, 768)
59
+ self.dropout = torch.nn.Dropout(0.3)
60
+ self.classifier = torch.nn.Linear(768, 8)
61
+
62
+ def forward(self, input_ids, attention_mask):
63
+ output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
64
+ hidden_state = output_1[0]
65
+ pooler = hidden_state[:, 0]
66
+ pooler = self.pre_classifier(pooler)
67
+ pooler = torch.nn.ReLU()(pooler)
68
+ pooler = self.dropout(pooler)
69
+ output = self.classifier(pooler)
70
+ return output
71
+
72
+
73
+ model = DistillBERTClass()
74
+ LEARNING_RATE = 1e-05
75
+
76
+ optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
77
+
78
+
79
+
80
+ model = torch.load("pytorch_distilbert_news (3).bin", map_location=torch.device('cpu'))
81
+
82
+
83
  def get_predict(title, abstract):
84
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
85
  # encoded_dict = tokenizer.encode_plus(