KoichiYasuoka commited on
Commit
856b7ac
1 Parent(s): a2d7c30

GPU support

Browse files
Files changed (1) hide show
  1. ud.py +1 -1
ud.py CHANGED
@@ -5,7 +5,7 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
5
  import torch
6
  v=[self.tokenizer.cls_token_id]+[t for t,(s,e) in zip(model_inputs["input_ids"][0].tolist(),model_inputs["offset_mapping"][0].tolist()) if s<e]+[self.tokenizer.sep_token_id]
7
  with torch.no_grad():
8
- e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)]))
9
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
10
  def postprocess(self,model_outputs,**kwargs):
11
  import numpy
 
5
  import torch
6
  v=[self.tokenizer.cls_token_id]+[t for t,(s,e) in zip(model_inputs["input_ids"][0].tolist(),model_inputs["offset_mapping"][0].tolist()) if s<e]+[self.tokenizer.sep_token_id]
7
  with torch.no_grad():
8
+ e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
9
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
10
  def postprocess(self,model_outputs,**kwargs):
11
  import numpy