Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -60,11 +60,12 @@ if text:
|
|
60 |
return_tensors='pt',
|
61 |
return_length=True
|
62 |
)
|
63 |
-
input_ids = encoded_dict['
|
64 |
-
|
|
|
65 |
attn_mask = attn_mask[None, :] < input_ids_len[:, None]
|
66 |
-
outputs = model(
|
67 |
|
68 |
-
_, preds = torch.max(outputs, 1)
|
69 |
|
70 |
st.write(topics_raw[preds.squeeze(0)])
|
|
|
60 |
return_tensors='pt',
|
61 |
return_length=True
|
62 |
)
|
63 |
+
input_ids = encoded_dict['input_ids']
|
64 |
+
input_ids_len = encoded_dict['length'].unsqueeze(0)
|
65 |
+
attn_mask = torch.arange(input_ids.size(1))
|
66 |
attn_mask = attn_mask[None, :] < input_ids_len[:, None]
|
67 |
+
outputs = model(input_ids=input_ids, attention_mask=attn_mask)
|
68 |
|
69 |
+
_, preds = torch.max(outputs.logits, 1)
|
70 |
|
71 |
st.write(topics_raw[preds.squeeze(0)])
|