Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -36,29 +36,7 @@ def get_top95(y_predict, convert_target):
|
|
36 |
# Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model.
|
37 |
from transformers import DistilBertModel, DistilBertTokenizer
|
38 |
|
39 |
-
class DistillBERTClass(torch.nn.Module):
|
40 |
-
def __init__(self):
|
41 |
-
super(DistillBERTClass, self).__init__()
|
42 |
-
self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
|
43 |
-
self.pre_classifier = torch.nn.Linear(768, 768)
|
44 |
-
self.dropout = torch.nn.Dropout(0.3)
|
45 |
-
self.classifier = torch.nn.Linear(768, 8)
|
46 |
|
47 |
-
def forward(self, input_ids, attention_mask):
|
48 |
-
output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
|
49 |
-
hidden_state = output_1[0]
|
50 |
-
pooler = hidden_state[:, 0]
|
51 |
-
pooler = self.pre_classifier(pooler)
|
52 |
-
pooler = torch.nn.ReLU()(pooler)
|
53 |
-
pooler = self.dropout(pooler)
|
54 |
-
output = self.classifier(pooler)
|
55 |
-
return output
|
56 |
-
|
57 |
-
|
58 |
-
model = DistillBERTClass()
|
59 |
-
LEARNING_RATE = 1e-05
|
60 |
-
|
61 |
-
optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
|
62 |
model = torch.load("pytorch_distilbert_news (3).bin", map_location=torch.device('cpu'))
|
63 |
# model.load_state_dict(checkpoint['model'])
|
64 |
# optimizer.load_state_dict(checkpoint['opt'])
|
@@ -90,6 +68,7 @@ def get_predict(title, abstract):
|
|
90 |
attention_mask=inputs['attention_mask'],
|
91 |
)
|
92 |
logits = outputs[0]
|
|
|
93 |
y_predict = torch.nn.functional.softmax(logits).cpu().detach().numpy()
|
94 |
file_path = "sample.json"
|
95 |
|
|
|
36 |
# Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model.
|
37 |
from transformers import DistilBertModel, DistilBertTokenizer
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
model = torch.load("pytorch_distilbert_news (3).bin", map_location=torch.device('cpu'))
|
41 |
# model.load_state_dict(checkpoint['model'])
|
42 |
# optimizer.load_state_dict(checkpoint['opt'])
|
|
|
68 |
attention_mask=inputs['attention_mask'],
|
69 |
)
|
70 |
logits = outputs[0]
|
71 |
+
print(logits)
|
72 |
y_predict = torch.nn.functional.softmax(logits).cpu().detach().numpy()
|
73 |
file_path = "sample.json"
|
74 |
|