Spaces:
Build error
Build error
Create token_weighter.py
Browse files
rhyme_with_ai/token_weighter.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class TokenWeighter:
    """Builds a per-token weight mask over a tokenizer's vocabulary.

    Tokens are "valid" (weight 1.0) when they are longer than one character
    and contain no ``#``; all other tokens get weight 0.0.  The ``#`` filter
    presumably targets WordPiece ``##`` continuation pieces (e.g. a BERT
    tokenizer) — TODO confirm against the tokenizer actually passed in.

    Attributes:
        tokenizer_: The wrapped tokenizer; must expose a ``vocab`` mapping
            of token string -> integer token id.
        proba: 1-D ``np.ndarray`` of length ``len(vocab)`` holding the
            0/1 weight per token id.  (Named "proba" but it is a binary
            mask, not a normalized distribution.)
    """

    def __init__(self, tokenizer):
        """Store the tokenizer and precompute the token weight mask."""
        self.tokenizer_ = tokenizer
        self.proba = self.get_token_proba()

    def get_token_proba(self):
        """Return the 0/1 validity mask over the tokenizer vocabulary."""
        return self._filter_short_partial(self.tokenizer_.vocab)

    def _filter_short_partial(self, vocab):
        """Mask out single-character tokens and partial ('#'-containing) tokens.

        Args:
            vocab: Mapping of token string -> integer token id; ids are
                assumed to lie in ``range(len(vocab))`` since they index
                the mask array directly.

        Returns:
            ``np.ndarray`` of shape ``(len(vocab),)`` with 1.0 at each
            valid token's id and 0.0 elsewhere.
        """
        valid_token_ids = [
            token_id
            for token, token_id in vocab.items()
            if len(token) > 1 and "#" not in token
        ]
        # len(vocab) instead of len(vocab.keys()) — same value, idiomatic form.
        is_valid = np.zeros(len(vocab))
        is_valid[valid_token_ids] = 1
        return is_valid