alex6095 commited on
Commit
f355e52
1 Parent(s): b13a6c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -60,11 +60,12 @@ if text:
60
  return_tensors='pt',
61
  return_length=True
62
  )
63
- input_ids = encoded_dict['length'].unsqueeze(0)
64
- attn_mask = torch.arange(input_ids.size(1)).to(device)
 
65
  attn_mask = attn_mask[None, :] < input_ids_len[:, None]
66
- outputs = model(encoded_dict['input_ids'], attn_mask)
67
 
68
- _, preds = torch.max(outputs, 1)
69
 
70
  st.write(topics_raw[preds.squeeze(0)])
 
60
  return_tensors='pt',
61
  return_length=True
62
  )
63
+ input_ids = encoded_dict['input_ids']
64
+ input_ids_len = encoded_dict['length'].unsqueeze(0)
65
+ attn_mask = torch.arange(input_ids.size(1))
66
  attn_mask = attn_mask[None, :] < input_ids_len[:, None]
67
+ outputs = model(input_ids=input_ids, attention_mask=attn_mask)
68
 
69
+ _, preds = torch.max(outputs.logits, 1)
70
 
71
  st.write(topics_raw[preds.squeeze(0)])