KoichiYasuoka
committed on
Commit
•
f3f71bf
1
Parent(s):
6da92f1
memory usage reduced
Browse files
ud.py
CHANGED
@@ -84,10 +84,13 @@ class UniversalDependenciesCausalPipeline(BellmanFordTokenClassificationPipeline
|
|
84 |
m.append(e[x,:].sum(axis=0))
|
85 |
m.append(e[self.tokenizer.sep_token_id,:])
|
86 |
m.append(e[self.tokenizer.pad_token_id,:])
|
87 |
-
m=torch.stack(m)
|
88 |
k=list(range(len(d)+1))
|
|
|
89 |
with torch.no_grad():
|
90 |
-
|
|
|
|
|
91 |
for i in range(len(d)):
|
92 |
for j in range(i):
|
93 |
e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
|
|
|
84 |
m.append(e[x,:].sum(axis=0))
|
85 |
m.append(e[self.tokenizer.sep_token_id,:])
|
86 |
m.append(e[self.tokenizer.pad_token_id,:])
|
87 |
+
m=torch.stack(m).to(self.device)
|
88 |
k=list(range(len(d)+1))
|
89 |
+
e=[]
|
90 |
with torch.no_grad():
|
91 |
+
for i in range(len(d)):
|
92 |
+
e.append(self.model(inputs_embeds=torch.unsqueeze(m[k+list(range(i,len(d)))+[-1]*i,:],0)).logits[0,-len(d):,:])
|
93 |
+
e=torch.stack(e).cpu().numpy()
|
94 |
for i in range(len(d)):
|
95 |
for j in range(i):
|
96 |
e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
|