KoichiYasuoka commited on
Commit
f3f71bf
1 Parent(s): 6da92f1

memory usage reduced

Browse files
Files changed (1) hide show
  1. ud.py +5 -2
ud.py CHANGED
@@ -84,10 +84,13 @@ class UniversalDependenciesCausalPipeline(BellmanFordTokenClassificationPipeline
84
  m.append(e[x,:].sum(axis=0))
85
  m.append(e[self.tokenizer.sep_token_id,:])
86
  m.append(e[self.tokenizer.pad_token_id,:])
87
- m=torch.stack(m)
88
  k=list(range(len(d)+1))
 
89
  with torch.no_grad():
90
- e=self.model(inputs_embeds=torch.stack([m[k+list(range(i,len(d)))+[-1]*i,:] for i in range(len(d))]).to(self.device)).logits[:,-len(d):,:].cpu().numpy()
 
 
91
  for i in range(len(d)):
92
  for j in range(i):
93
  e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
 
84
  m.append(e[x,:].sum(axis=0))
85
  m.append(e[self.tokenizer.sep_token_id,:])
86
  m.append(e[self.tokenizer.pad_token_id,:])
87
+ m=torch.stack(m).to(self.device)
88
  k=list(range(len(d)+1))
89
+ e=[]
90
  with torch.no_grad():
91
+ for i in range(len(d)):
92
+ e.append(self.model(inputs_embeds=torch.unsqueeze(m[k+list(range(i,len(d)))+[-1]*i,:],0)).logits[0,-len(d):,:])
93
+ e=torch.stack(e).cpu().numpy()
94
  for i in range(len(d)):
95
  for j in range(i):
96
  e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc