Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitattributes +4 -0
- IndicTransTokenizer/.gitignore +3 -0
- IndicTransTokenizer/IndicTransTokenizer/__init__.py +2 -0
- IndicTransTokenizer/IndicTransTokenizer/en-indic/dict.SRC.json +0 -0
- IndicTransTokenizer/IndicTransTokenizer/en-indic/dict.TGT.json +0 -0
- IndicTransTokenizer/IndicTransTokenizer/en-indic/model.SRC +0 -0
- IndicTransTokenizer/IndicTransTokenizer/en-indic/model.TGT +3 -0
- IndicTransTokenizer/IndicTransTokenizer/indic-en/dict.SRC.json +0 -0
- IndicTransTokenizer/IndicTransTokenizer/indic-en/dict.TGT.json +0 -0
- IndicTransTokenizer/IndicTransTokenizer/indic-en/model.SRC +3 -0
- IndicTransTokenizer/IndicTransTokenizer/indic-en/model.TGT +0 -0
- IndicTransTokenizer/IndicTransTokenizer/indic-indic/dict.SRC.json +0 -0
- IndicTransTokenizer/IndicTransTokenizer/indic-indic/dict.TGT.json +0 -0
- IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.SRC +3 -0
- IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.TGT +3 -0
- IndicTransTokenizer/IndicTransTokenizer/tokenizer.py +262 -0
- IndicTransTokenizer/IndicTransTokenizer/utils.py +530 -0
- IndicTransTokenizer/IndicTransTokenizer/version.py +1 -0
- IndicTransTokenizer/IndicTransTokenizer/version.txt +1 -0
- IndicTransTokenizer/LICENSE +21 -0
- IndicTransTokenizer/README.md +77 -0
- IndicTransTokenizer/requirements.txt +6 -0
- IndicTransTokenizer/setup.py +47 -0
- README.md +5 -4
- app.py +87 -0
- config.py +5 -0
- examples.py +11 -0
- indictrans2.py +98 -0
- requirements.txt +8 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
IndicTransTokenizer/IndicTransTokenizer/en-indic/model.TGT filter=lfs diff=lfs merge=lfs -text
|
37 |
+
IndicTransTokenizer/IndicTransTokenizer/indic-en/model.SRC filter=lfs diff=lfs merge=lfs -text
|
38 |
+
IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.SRC filter=lfs diff=lfs merge=lfs -text
|
39 |
+
IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.TGT filter=lfs diff=lfs merge=lfs -text
|
IndicTransTokenizer/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
dist/
|
2 |
+
IndicTransTokenizer.egg-info
|
3 |
+
IndicTransTokenizer/__pycache__/
|
IndicTransTokenizer/IndicTransTokenizer/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .tokenizer import IndicTransTokenizer
|
2 |
+
from .utils import IndicProcessor
|
IndicTransTokenizer/IndicTransTokenizer/en-indic/dict.SRC.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
IndicTransTokenizer/IndicTransTokenizer/en-indic/dict.TGT.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
IndicTransTokenizer/IndicTransTokenizer/en-indic/model.SRC
ADDED
Binary file (759 kB). View file
|
|
IndicTransTokenizer/IndicTransTokenizer/en-indic/model.TGT
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
|
3 |
+
size 3256903
|
IndicTransTokenizer/IndicTransTokenizer/indic-en/dict.SRC.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
IndicTransTokenizer/IndicTransTokenizer/indic-en/dict.TGT.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
IndicTransTokenizer/IndicTransTokenizer/indic-en/model.SRC
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
|
3 |
+
size 3256903
|
IndicTransTokenizer/IndicTransTokenizer/indic-en/model.TGT
ADDED
Binary file (759 kB). View file
|
|
IndicTransTokenizer/IndicTransTokenizer/indic-indic/dict.SRC.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
IndicTransTokenizer/IndicTransTokenizer/indic-indic/dict.TGT.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.SRC
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
|
3 |
+
size 3256903
|
IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.TGT
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
|
3 |
+
size 3256903
|
IndicTransTokenizer/IndicTransTokenizer/tokenizer.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
from transformers import BatchEncoding
|
5 |
+
from typing import Dict, List, Tuple, Union
|
6 |
+
from sentencepiece import SentencePieceProcessor
|
7 |
+
|
8 |
+
_PATH = os.path.dirname(os.path.realpath(__file__))
|
9 |
+
|
10 |
+
|
11 |
+
class IndicTransTokenizer:
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
direction=None,
|
15 |
+
model_name=None,
|
16 |
+
unk_token="<unk>",
|
17 |
+
bos_token="<s>",
|
18 |
+
eos_token="</s>",
|
19 |
+
pad_token="<pad>",
|
20 |
+
model_max_length=256,
|
21 |
+
):
|
22 |
+
self.model_max_length = model_max_length
|
23 |
+
|
24 |
+
self.supported_langs = [
|
25 |
+
"asm_Beng",
|
26 |
+
"awa_Deva",
|
27 |
+
"ben_Beng",
|
28 |
+
"bho_Deva",
|
29 |
+
"brx_Deva",
|
30 |
+
"doi_Deva",
|
31 |
+
"eng_Latn",
|
32 |
+
"gom_Deva",
|
33 |
+
"gon_Deva",
|
34 |
+
"guj_Gujr",
|
35 |
+
"hin_Deva",
|
36 |
+
"hne_Deva",
|
37 |
+
"kan_Knda",
|
38 |
+
"kas_Arab",
|
39 |
+
"kas_Deva",
|
40 |
+
"kha_Latn",
|
41 |
+
"lus_Latn",
|
42 |
+
"mag_Deva",
|
43 |
+
"mai_Deva",
|
44 |
+
"mal_Mlym",
|
45 |
+
"mar_Deva",
|
46 |
+
"mni_Beng",
|
47 |
+
"mni_Mtei",
|
48 |
+
"npi_Deva",
|
49 |
+
"ory_Orya",
|
50 |
+
"pan_Guru",
|
51 |
+
"san_Deva",
|
52 |
+
"sat_Olck",
|
53 |
+
"snd_Arab",
|
54 |
+
"snd_Deva",
|
55 |
+
"tam_Taml",
|
56 |
+
"tel_Telu",
|
57 |
+
"urd_Arab",
|
58 |
+
"unr_Deva",
|
59 |
+
]
|
60 |
+
|
61 |
+
if model_name is None and direction is None:
|
62 |
+
raise ValueError("Either model_name or direction must be provided!")
|
63 |
+
|
64 |
+
if model_name is not None:
|
65 |
+
direction = self.get_direction(model_name) # model_name overrides direction
|
66 |
+
|
67 |
+
self.src_vocab_fp = os.path.join(_PATH, direction, "dict.SRC.json")
|
68 |
+
self.tgt_vocab_fp = os.path.join(_PATH, direction, "dict.TGT.json")
|
69 |
+
self.src_spm_fp = os.path.join(_PATH, direction, "model.SRC")
|
70 |
+
self.tgt_spm_fp = os.path.join(_PATH, direction, "model.TGT")
|
71 |
+
|
72 |
+
self.unk_token = unk_token
|
73 |
+
self.pad_token = pad_token
|
74 |
+
self.eos_token = eos_token
|
75 |
+
self.bos_token = bos_token
|
76 |
+
|
77 |
+
self.encoder = self._load_json(self.src_vocab_fp)
|
78 |
+
if self.unk_token not in self.encoder:
|
79 |
+
raise KeyError("<unk> token must be in vocab")
|
80 |
+
assert self.pad_token in self.encoder
|
81 |
+
self.encoder_rev = {v: k for k, v in self.encoder.items()}
|
82 |
+
|
83 |
+
self.decoder = self._load_json(self.tgt_vocab_fp)
|
84 |
+
if self.unk_token not in self.encoder:
|
85 |
+
raise KeyError("<unk> token must be in vocab")
|
86 |
+
assert self.pad_token in self.encoder
|
87 |
+
self.decoder_rev = {v: k for k, v in self.decoder.items()}
|
88 |
+
|
89 |
+
# load SentencePiece model for pre-processing
|
90 |
+
self.src_spm = self._load_spm(self.src_spm_fp)
|
91 |
+
self.tgt_spm = self._load_spm(self.tgt_spm_fp)
|
92 |
+
|
93 |
+
self.unk_token_id = self.encoder[self.unk_token]
|
94 |
+
self.pad_token_id = self.encoder[self.pad_token]
|
95 |
+
self.eos_token_id = self.encoder[self.eos_token]
|
96 |
+
self.bos_token_id = self.encoder[self.bos_token]
|
97 |
+
|
98 |
+
def get_direction(self, model_name: str) -> str:
|
99 |
+
pieces = model_name.split("/")[-1].split("-")
|
100 |
+
return f"{pieces[1]}-{pieces[2]}"
|
101 |
+
|
102 |
+
def is_special_token(self, x: str):
|
103 |
+
return (x == self.pad_token) or (x == self.bos_token) or (x == self.eos_token)
|
104 |
+
|
105 |
+
def get_vocab_size(self, src: bool) -> int:
|
106 |
+
"""Returns the size of the vocabulary"""
|
107 |
+
return len(self.encoder) if src else len(self.decoder)
|
108 |
+
|
109 |
+
def _load_spm(self, path: str) -> SentencePieceProcessor:
|
110 |
+
return SentencePieceProcessor(model_file=path)
|
111 |
+
|
112 |
+
def _save_json(self, data, path: str) -> None:
|
113 |
+
with open(path, "w", encoding="utf-8") as f:
|
114 |
+
json.dump(data, f, indent=2)
|
115 |
+
|
116 |
+
def _load_json(self, path: str) -> Union[Dict, List]:
|
117 |
+
with open(path, "r", encoding="utf-8") as f:
|
118 |
+
return json.load(f)
|
119 |
+
|
120 |
+
def _convert_token_to_id(self, token: str, src: bool) -> int:
|
121 |
+
"""Converts an token (str) into an index (integer) using the source/target vocabulary map."""
|
122 |
+
return (
|
123 |
+
self.encoder.get(token, self.encoder[self.unk_token])
|
124 |
+
if src
|
125 |
+
else self.decoder.get(token, self.encoder[self.unk_token])
|
126 |
+
)
|
127 |
+
|
128 |
+
def _convert_id_to_token(self, index: int, src: bool) -> str:
|
129 |
+
"""Converts an index (integer) into a token (str) using the source/target vocabulary map."""
|
130 |
+
return (
|
131 |
+
self.encoder_rev.get(index, self.unk_token)
|
132 |
+
if src
|
133 |
+
else self.decoder_rev.get(index, self.unk_token)
|
134 |
+
)
|
135 |
+
|
136 |
+
def _convert_tokens_to_string(self, tokens: List[str], src: bool) -> str:
|
137 |
+
"""Uses sentencepiece model for detokenization"""
|
138 |
+
if src:
|
139 |
+
if tokens[0] in self.supported_langs and tokens[1] in self.supported_langs:
|
140 |
+
tokens = tokens[2:]
|
141 |
+
return " ".join(tokens)
|
142 |
+
else:
|
143 |
+
return " ".join(tokens)
|
144 |
+
|
145 |
+
def _remove_translation_tags(self, text: str) -> Tuple[List, str]:
|
146 |
+
"""Removes the translation tags before text normalization and tokenization."""
|
147 |
+
tokens = text.split(" ")
|
148 |
+
return tokens[:2], " ".join(tokens[2:])
|
149 |
+
|
150 |
+
def _tokenize_src_line(self, line: str) -> List[str]:
|
151 |
+
"""Tokenizes a source line."""
|
152 |
+
tags, text = self._remove_translation_tags(line)
|
153 |
+
tokens = self.src_spm.encode(text, out_type=str)
|
154 |
+
return tags + tokens
|
155 |
+
|
156 |
+
def _tokenize_tgt_line(self, line: str) -> List[str]:
|
157 |
+
"""Tokenizes a target line."""
|
158 |
+
return self.tgt_spm.encode(line, out_type=str)
|
159 |
+
|
160 |
+
def tokenize(self, text: str, src: bool) -> List[str]:
|
161 |
+
"""Tokenizes a string into tokens using the source/target vocabulary."""
|
162 |
+
return self._tokenize_src_line(text) if src else self._tokenize_tgt_line(text)
|
163 |
+
|
164 |
+
def batch_tokenize(self, batch: List[str], src: bool) -> List[List[str]]:
|
165 |
+
"""Tokenizes a list of strings into tokens using the source/target vocabulary."""
|
166 |
+
return [self.tokenize(line, src) for line in batch]
|
167 |
+
|
168 |
+
def _create_attention_mask(self, ids: List[int], max_seq_len: int, src: bool) -> List[int]:
|
169 |
+
"""Creates a attention mask for the input sequence."""
|
170 |
+
if src:
|
171 |
+
return [0] * (max_seq_len - len(ids)) + [1] * (len(ids) + 1)
|
172 |
+
else:
|
173 |
+
return [1] * (len(ids) + 1) + [0] * (max_seq_len - len(ids))
|
174 |
+
|
175 |
+
def _pad_batch(self, tokens: List[str], max_seq_len: int, src: bool) -> List[str]:
|
176 |
+
"""Pads a batch of tokens and adds BOS/EOS tokens."""
|
177 |
+
if src:
|
178 |
+
return [self.pad_token] * (max_seq_len - len(tokens)) + tokens + [self.eos_token]
|
179 |
+
else:
|
180 |
+
return tokens + [self.eos_token] + [self.pad_token] * (max_seq_len - len(tokens))
|
181 |
+
|
182 |
+
def _decode_line(self, ids: List[int], src: bool) -> List[str]:
|
183 |
+
return [self._convert_id_to_token(_id, src) for _id in ids]
|
184 |
+
|
185 |
+
def _encode_line(self, tokens: List[str], src: bool) -> List[int]:
|
186 |
+
return [self._convert_token_to_id(token, src) for token in tokens]
|
187 |
+
|
188 |
+
def _strip_special_tokens(self, tokens: List[str]) -> List[str]:
|
189 |
+
return [token for token in tokens if not self.is_special_token(token)]
|
190 |
+
|
191 |
+
def _single_input_preprocessing(
|
192 |
+
self, tokens: List[str], src: bool, max_seq_len: int
|
193 |
+
) -> Tuple[List[int], List[int], int]:
|
194 |
+
"""Tokenizes a string into tokens and also converts them into integers using source/target vocabulary map."""
|
195 |
+
attention_mask = self._create_attention_mask(tokens, max_seq_len, src)
|
196 |
+
padded_tokens = self._pad_batch(tokens, max_seq_len, src)
|
197 |
+
input_ids = self._encode_line(padded_tokens, src)
|
198 |
+
return input_ids, attention_mask
|
199 |
+
|
200 |
+
def _single_output_postprocessing(self, ids: List[int], src: bool) -> str:
|
201 |
+
"""Detokenizes a list of integer ids into a string using the source/target vocabulary."""
|
202 |
+
tokens = self._decode_line(ids, src)
|
203 |
+
tokens = self._strip_special_tokens(tokens)
|
204 |
+
return (
|
205 |
+
self._convert_tokens_to_string(tokens, src).replace(" ", "").replace("▁", " ").strip()
|
206 |
+
)
|
207 |
+
|
208 |
+
def __call__(
|
209 |
+
self,
|
210 |
+
batch: Union[list, str],
|
211 |
+
src: bool,
|
212 |
+
truncation: bool = False,
|
213 |
+
padding: str = "longest",
|
214 |
+
max_length: int = None,
|
215 |
+
return_tensors: str = "pt",
|
216 |
+
return_attention_mask: bool = True,
|
217 |
+
return_length: bool = False,
|
218 |
+
) -> BatchEncoding:
|
219 |
+
"""Tokenizes a string into tokens and also converts them into integers using source/target vocabulary map."""
|
220 |
+
assert padding in [
|
221 |
+
"longest",
|
222 |
+
"max_length",
|
223 |
+
], "Padding should be either 'longest' or 'max_length'"
|
224 |
+
|
225 |
+
if not isinstance(batch, list):
|
226 |
+
raise TypeError(f"Batch must be a list, but current batch is of type {type(batch)}")
|
227 |
+
|
228 |
+
# tokenize the source sentences
|
229 |
+
batch = self.batch_tokenize(batch, src)
|
230 |
+
|
231 |
+
# truncate the sentences if needed
|
232 |
+
if truncation and max_length is not None:
|
233 |
+
batch = [ids[:max_length] for ids in batch]
|
234 |
+
|
235 |
+
lengths = [len(ids) for ids in batch]
|
236 |
+
|
237 |
+
max_seq_len = max(lengths) if padding == "longest" else max_length
|
238 |
+
|
239 |
+
input_ids, attention_mask = zip(
|
240 |
+
*[
|
241 |
+
self._single_input_preprocessing(tokens=tokens, src=src, max_seq_len=max_seq_len)
|
242 |
+
for tokens in batch
|
243 |
+
]
|
244 |
+
)
|
245 |
+
|
246 |
+
_data = {"input_ids": input_ids}
|
247 |
+
|
248 |
+
if return_attention_mask:
|
249 |
+
_data["attention_mask"] = attention_mask
|
250 |
+
|
251 |
+
if return_length:
|
252 |
+
_data["lengths"] = lengths
|
253 |
+
|
254 |
+
return BatchEncoding(_data, tensor_type=return_tensors)
|
255 |
+
|
256 |
+
def batch_decode(self, batch: Union[list, torch.Tensor], src: bool) -> List[List[str]]:
|
257 |
+
"""Detokenizes a list of integer ids or a tensor into a list of strings using the source/target vocabulary."""
|
258 |
+
|
259 |
+
if isinstance(batch, torch.Tensor):
|
260 |
+
batch = batch.detach().cpu().tolist()
|
261 |
+
|
262 |
+
return [self._single_output_postprocessing(ids=ids, src=src) for ids in batch]
|
IndicTransTokenizer/IndicTransTokenizer/utils.py
ADDED
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import List, Tuple, Union
|
3 |
+
|
4 |
+
from indicnlp.tokenize import indic_tokenize, indic_detokenize
|
5 |
+
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
|
6 |
+
from sacremoses import MosesPunctNormalizer, MosesTokenizer, MosesDetokenizer
|
7 |
+
from indicnlp.transliterate.unicode_transliterate import UnicodeIndicTransliterator
|
8 |
+
|
9 |
+
|
10 |
+
class IndicProcessor:
|
11 |
+
def __init__(self, inference=True):
|
12 |
+
self.inference = inference
|
13 |
+
|
14 |
+
self._flores_codes = {
|
15 |
+
"asm_Beng": "as",
|
16 |
+
"awa_Deva": "hi",
|
17 |
+
"ben_Beng": "bn",
|
18 |
+
"bho_Deva": "hi",
|
19 |
+
"brx_Deva": "hi",
|
20 |
+
"doi_Deva": "hi",
|
21 |
+
"eng_Latn": "en",
|
22 |
+
"gom_Deva": "kK",
|
23 |
+
"gon_Deva": "hi",
|
24 |
+
"guj_Gujr": "gu",
|
25 |
+
"hin_Deva": "hi",
|
26 |
+
"hne_Deva": "hi",
|
27 |
+
"kan_Knda": "kn",
|
28 |
+
"kas_Arab": "ur",
|
29 |
+
"kas_Deva": "hi",
|
30 |
+
"kha_Latn": "en",
|
31 |
+
"lus_Latn": "en",
|
32 |
+
"mag_Deva": "hi",
|
33 |
+
"mai_Deva": "hi",
|
34 |
+
"mal_Mlym": "ml",
|
35 |
+
"mar_Deva": "mr",
|
36 |
+
"mni_Beng": "bn",
|
37 |
+
"mni_Mtei": "hi",
|
38 |
+
"npi_Deva": "ne",
|
39 |
+
"ory_Orya": "or",
|
40 |
+
"pan_Guru": "pa",
|
41 |
+
"san_Deva": "hi",
|
42 |
+
"sat_Olck": "or",
|
43 |
+
"snd_Arab": "ur",
|
44 |
+
"snd_Deva": "hi",
|
45 |
+
"tam_Taml": "ta",
|
46 |
+
"tel_Telu": "te",
|
47 |
+
"urd_Arab": "ur",
|
48 |
+
"unr_Deva": "hi",
|
49 |
+
}
|
50 |
+
|
51 |
+
self._indic_num_map = {
|
52 |
+
"\u09e6": "0",
|
53 |
+
"0": "0",
|
54 |
+
"\u0ae6": "0",
|
55 |
+
"\u0ce6": "0",
|
56 |
+
"\u0966": "0",
|
57 |
+
"\u0660": "0",
|
58 |
+
"\uabf0": "0",
|
59 |
+
"\u0b66": "0",
|
60 |
+
"\u0a66": "0",
|
61 |
+
"\u1c50": "0",
|
62 |
+
"\u06f0": "0",
|
63 |
+
"\u09e7": "1",
|
64 |
+
"1": "1",
|
65 |
+
"\u0ae7": "1",
|
66 |
+
"\u0967": "1",
|
67 |
+
"\u0ce7": "1",
|
68 |
+
"\u06f1": "1",
|
69 |
+
"\uabf1": "1",
|
70 |
+
"\u0b67": "1",
|
71 |
+
"\u0a67": "1",
|
72 |
+
"\u1c51": "1",
|
73 |
+
"\u0c67": "1",
|
74 |
+
"\u09e8": "2",
|
75 |
+
"2": "2",
|
76 |
+
"\u0ae8": "2",
|
77 |
+
"\u0968": "2",
|
78 |
+
"\u0ce8": "2",
|
79 |
+
"\u06f2": "2",
|
80 |
+
"\uabf2": "2",
|
81 |
+
"\u0b68": "2",
|
82 |
+
"\u0a68": "2",
|
83 |
+
"\u1c52": "2",
|
84 |
+
"\u0c68": "2",
|
85 |
+
"\u09e9": "3",
|
86 |
+
"3": "3",
|
87 |
+
"\u0ae9": "3",
|
88 |
+
"\u0969": "3",
|
89 |
+
"\u0ce9": "3",
|
90 |
+
"\u06f3": "3",
|
91 |
+
"\uabf3": "3",
|
92 |
+
"\u0b69": "3",
|
93 |
+
"\u0a69": "3",
|
94 |
+
"\u1c53": "3",
|
95 |
+
"\u0c69": "3",
|
96 |
+
"\u09ea": "4",
|
97 |
+
"4": "4",
|
98 |
+
"\u0aea": "4",
|
99 |
+
"\u096a": "4",
|
100 |
+
"\u0cea": "4",
|
101 |
+
"\u06f4": "4",
|
102 |
+
"\uabf4": "4",
|
103 |
+
"\u0b6a": "4",
|
104 |
+
"\u0a6a": "4",
|
105 |
+
"\u1c54": "4",
|
106 |
+
"\u0c6a": "4",
|
107 |
+
"\u09eb": "5",
|
108 |
+
"5": "5",
|
109 |
+
"\u0aeb": "5",
|
110 |
+
"\u096b": "5",
|
111 |
+
"\u0ceb": "5",
|
112 |
+
"\u06f5": "5",
|
113 |
+
"\uabf5": "5",
|
114 |
+
"\u0b6b": "5",
|
115 |
+
"\u0a6b": "5",
|
116 |
+
"\u1c55": "5",
|
117 |
+
"\u0c6b": "5",
|
118 |
+
"\u09ec": "6",
|
119 |
+
"6": "6",
|
120 |
+
"\u0aec": "6",
|
121 |
+
"\u096c": "6",
|
122 |
+
"\u0cec": "6",
|
123 |
+
"\u06f6": "6",
|
124 |
+
"\uabf6": "6",
|
125 |
+
"\u0b6c": "6",
|
126 |
+
"\u0a6c": "6",
|
127 |
+
"\u1c56": "6",
|
128 |
+
"\u0c6c": "6",
|
129 |
+
"\u09ed": "7",
|
130 |
+
"7": "7",
|
131 |
+
"\u0aed": "7",
|
132 |
+
"\u096d": "7",
|
133 |
+
"\u0ced": "7",
|
134 |
+
"\u06f7": "7",
|
135 |
+
"\uabf7": "7",
|
136 |
+
"\u0b6d": "7",
|
137 |
+
"\u0a6d": "7",
|
138 |
+
"\u1c57": "7",
|
139 |
+
"\u0c6d": "7",
|
140 |
+
"\u09ee": "8",
|
141 |
+
"8": "8",
|
142 |
+
"\u0aee": "8",
|
143 |
+
"\u096e": "8",
|
144 |
+
"\u0cee": "8",
|
145 |
+
"\u06f8": "8",
|
146 |
+
"\uabf8": "8",
|
147 |
+
"\u0b6e": "8",
|
148 |
+
"\u0a6e": "8",
|
149 |
+
"\u1c58": "8",
|
150 |
+
"\u0c6e": "8",
|
151 |
+
"\u09ef": "9",
|
152 |
+
"9": "9",
|
153 |
+
"\u0aef": "9",
|
154 |
+
"\u096f": "9",
|
155 |
+
"\u0cef": "9",
|
156 |
+
"\u06f9": "9",
|
157 |
+
"\uabf9": "9",
|
158 |
+
"\u0b6f": "9",
|
159 |
+
"\u0a6f": "9",
|
160 |
+
"\u1c59": "9",
|
161 |
+
"\u0c6f": "9",
|
162 |
+
}
|
163 |
+
|
164 |
+
self._placeholder_entity_maps = []
|
165 |
+
|
166 |
+
self._en_tok = MosesTokenizer(lang="en")
|
167 |
+
self._en_normalizer = MosesPunctNormalizer()
|
168 |
+
self._en_detok = MosesDetokenizer(lang="en")
|
169 |
+
self._xliterator = UnicodeIndicTransliterator()
|
170 |
+
|
171 |
+
self._multispace_regex = re.compile("[ ]{2,}")
|
172 |
+
self._digit_space_percent = re.compile(r"(\d) %")
|
173 |
+
self._double_quot_punc = re.compile(r"\"([,\.]+)")
|
174 |
+
self._digit_nbsp_digit = re.compile(r"(\d) (\d)")
|
175 |
+
self._end_bracket_space_punc_regex = re.compile(r"\) ([\.!:?;,])")
|
176 |
+
|
177 |
+
self._URL_PATTERN = r"\b(?<![\w/.])(?:(?:https?|ftp)://)?(?:(?:[\w-]+\.)+(?!\.))(?:[\w/\-?#&=%.]+)+(?!\.\w+)\b"
|
178 |
+
self._NUMERAL_PATTERN = r"(~?\d+\.?\d*\s?%?\s?-?\s?~?\d+\.?\d*\s?%|~?\d+%|\d+[-\/.,:']\d+[-\/.,:'+]\d+(?:\.\d+)?|\d+[-\/.:'+]\d+(?:\.\d+)?)"
|
179 |
+
self._EMAIL_PATTERN = r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"
|
180 |
+
self._OTHER_PATTERN = r"[A-Za-z0-9]*[#|@]\w+"
|
181 |
+
|
182 |
+
def _add_placeholder_entity_map(self, placeholder_entity_map):
|
183 |
+
self._placeholder_entity_maps.append(placeholder_entity_map)
|
184 |
+
|
185 |
+
def get_placeholder_entity_maps(self):
|
186 |
+
return self._placeholder_entity_maps
|
187 |
+
|
188 |
+
def _punc_norm(self, text) -> str:
|
189 |
+
text = (
|
190 |
+
text.replace("\r", "")
|
191 |
+
.replace("(", " (")
|
192 |
+
.replace(")", ") ")
|
193 |
+
.replace("( ", "(")
|
194 |
+
.replace(" )", ")")
|
195 |
+
.replace(" :", ":")
|
196 |
+
.replace(" ;", ";")
|
197 |
+
.replace("`", "'")
|
198 |
+
.replace("„", '"')
|
199 |
+
.replace("“", '"')
|
200 |
+
.replace("”", '"')
|
201 |
+
.replace("–", "-")
|
202 |
+
.replace("—", " - ")
|
203 |
+
.replace("´", "'")
|
204 |
+
.replace("‘", "'")
|
205 |
+
.replace("‚", "'")
|
206 |
+
.replace("’", "'")
|
207 |
+
.replace("''", '"')
|
208 |
+
.replace("´´", '"')
|
209 |
+
.replace("…", "...")
|
210 |
+
.replace(" « ", ' "')
|
211 |
+
.replace("« ", '"')
|
212 |
+
.replace("«", '"')
|
213 |
+
.replace(" » ", '" ')
|
214 |
+
.replace(" »", '"')
|
215 |
+
.replace("»", '"')
|
216 |
+
.replace(" %", "%")
|
217 |
+
.replace("nº ", "nº ")
|
218 |
+
.replace(" :", ":")
|
219 |
+
.replace(" ºC", " ºC")
|
220 |
+
.replace(" cm", " cm")
|
221 |
+
.replace(" ?", "?")
|
222 |
+
.replace(" !", "!")
|
223 |
+
.replace(" ;", ";")
|
224 |
+
.replace(", ", ", ")
|
225 |
+
)
|
226 |
+
|
227 |
+
text = self._multispace_regex.sub(" ", text)
|
228 |
+
text = self._end_bracket_space_punc_regex.sub(r")\1", text)
|
229 |
+
text = self._digit_space_percent.sub(r"\1%", text)
|
230 |
+
text = self._double_quot_punc.sub(r'\1"', text)
|
231 |
+
text = self._digit_nbsp_digit.sub(r"\1.\2", text)
|
232 |
+
return text.strip()
|
233 |
+
|
234 |
+
def _normalize_indic_numerals(self, line: str) -> str:
|
235 |
+
"""
|
236 |
+
Normalize the numerals in Indic languages from native script to Roman script (if present).
|
237 |
+
|
238 |
+
Args:
|
239 |
+
line (str): an input string with Indic numerals to be normalized.
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
str: an input string with the all Indic numerals normalized to Roman script.
|
243 |
+
"""
|
244 |
+
return "".join([self._indic_num_map.get(c, c) for c in line])
|
245 |
+
|
246 |
+
def _wrap_with_placeholders(self, text: str, patterns: list) -> str:
|
247 |
+
"""
|
248 |
+
Wraps substrings with matched patterns in the given text with placeholders and returns
|
249 |
+
the modified text along with a mapping of the placeholders to their original value.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
text (str): an input string which needs to be wrapped with the placeholders.
|
253 |
+
pattern (list): list of patterns to search for in the input string.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
text (str): a modified text.
|
257 |
+
"""
|
258 |
+
|
259 |
+
serial_no = 1
|
260 |
+
|
261 |
+
placeholder_entity_map = dict()
|
262 |
+
|
263 |
+
indic_failure_cases = [
|
264 |
+
"آی ڈی ",
|
265 |
+
"ꯑꯥꯏꯗꯤ",
|
266 |
+
"आईडी",
|
267 |
+
"आई . डी . ",
|
268 |
+
"आई . डी .",
|
269 |
+
"आई. डी. ",
|
270 |
+
"आई. डी.",
|
271 |
+
"ऐटि",
|
272 |
+
"آئی ڈی ",
|
273 |
+
"ᱟᱭᱰᱤ ᱾",
|
274 |
+
"आयडी",
|
275 |
+
"ऐडि",
|
276 |
+
"आइडि",
|
277 |
+
"ᱟᱭᱰᱤ",
|
278 |
+
]
|
279 |
+
|
280 |
+
for pattern in patterns:
|
281 |
+
matches = set(re.findall(pattern, text))
|
282 |
+
|
283 |
+
# wrap common match with placeholder tags
|
284 |
+
for match in matches:
|
285 |
+
if pattern == self._URL_PATTERN:
|
286 |
+
# Avoids false positive URL matches for names with initials.
|
287 |
+
if len(match.replace(".", "")) < 4:
|
288 |
+
continue
|
289 |
+
if pattern == self._NUMERAL_PATTERN:
|
290 |
+
# Short numeral patterns do not need placeholder based handling.
|
291 |
+
if (
|
292 |
+
len(match.replace(" ", "").replace(".", "").replace(":", ""))
|
293 |
+
< 4
|
294 |
+
):
|
295 |
+
continue
|
296 |
+
|
297 |
+
# Set of Translations of "ID" in all the suppported languages have been collated.
|
298 |
+
# This has been added to deal with edge cases where placeholders might get translated.
|
299 |
+
base_placeholder = f"<ID{serial_no}>"
|
300 |
+
|
301 |
+
placeholder_entity_map[f"<ID{serial_no}]"] = match
|
302 |
+
placeholder_entity_map[f"< ID{serial_no} ]"] = match
|
303 |
+
placeholder_entity_map[f"<ID{serial_no}>"] = match
|
304 |
+
placeholder_entity_map[f"< ID{serial_no} >"] = match
|
305 |
+
|
306 |
+
for i in indic_failure_cases:
|
307 |
+
placeholder_entity_map[f"<{i}{serial_no}>"] = match
|
308 |
+
placeholder_entity_map[f"< {i}{serial_no} >"] = match
|
309 |
+
placeholder_entity_map[f"< {i} {serial_no} >"] = match
|
310 |
+
placeholder_entity_map[f"<{i} {serial_no}]"] = match
|
311 |
+
placeholder_entity_map[f"< {i} {serial_no} ]"] = match
|
312 |
+
placeholder_entity_map[f"[{i} {serial_no}]"] = match
|
313 |
+
placeholder_entity_map[f"[ {i} {serial_no} ]"] = match
|
314 |
+
|
315 |
+
text = text.replace(match, base_placeholder)
|
316 |
+
serial_no += 1
|
317 |
+
|
318 |
+
text = re.sub("\s+", " ", text).replace(">/", ">").replace("]/", "]")
|
319 |
+
self._add_placeholder_entity_map(placeholder_entity_map)
|
320 |
+
return text
|
321 |
+
|
322 |
+
def _normalize(
|
323 |
+
self,
|
324 |
+
text: str,
|
325 |
+
) -> Tuple[str, dict]:
|
326 |
+
"""
|
327 |
+
Normalizes and wraps the spans of input string with placeholder tags. It first normalizes
|
328 |
+
the Indic numerals in the input string to Roman script. Later, it uses the input string with normalized
|
329 |
+
Indic numerals to wrap the spans of text matching the pattern with placeholder tags.
|
330 |
+
|
331 |
+
Args:
|
332 |
+
text (str): input string.
|
333 |
+
pattern (list): list of patterns to search for in the input string.
|
334 |
+
|
335 |
+
Returns:
|
336 |
+
text (str): the modified text
|
337 |
+
"""
|
338 |
+
patterns = [
|
339 |
+
self._EMAIL_PATTERN,
|
340 |
+
self._URL_PATTERN,
|
341 |
+
self._NUMERAL_PATTERN,
|
342 |
+
self._OTHER_PATTERN,
|
343 |
+
]
|
344 |
+
|
345 |
+
text = self._normalize_indic_numerals(text.strip())
|
346 |
+
|
347 |
+
if self.inference:
|
348 |
+
text = self._wrap_with_placeholders(text, patterns)
|
349 |
+
|
350 |
+
return text
|
351 |
+
|
352 |
+
def _apply_lang_tags(
|
353 |
+
self, sents: List[str], src_lang: str, tgt_lang: str, delimiter=" "
|
354 |
+
) -> List[str]:
|
355 |
+
"""
|
356 |
+
Add special tokens indicating source and target language to the start of the each input sentence.
|
357 |
+
Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".
|
358 |
+
|
359 |
+
Args:
|
360 |
+
sent (str): input sentence to be translated.
|
361 |
+
src_lang (str): flores lang code of the input sentence.
|
362 |
+
tgt_lang (str): flores lang code in which the input sentence will be translated.
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
List[str]: list of input sentences with the special tokens added to the start.
|
366 |
+
"""
|
367 |
+
return [f"{src_lang}{delimiter}{tgt_lang}{delimiter}{x.strip()}" for x in sents]
|
368 |
+
|
369 |
+
def _preprocess(
|
370 |
+
self,
|
371 |
+
sent: str,
|
372 |
+
lang: str,
|
373 |
+
normalizer: Union[MosesPunctNormalizer, IndicNormalizerFactory],
|
374 |
+
) -> str:
|
375 |
+
"""
|
376 |
+
Preprocess an input text sentence by normalizing, tokenization, and possibly transliterating it.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
sent (str): input text sentence to preprocess.
|
380 |
+
normalizer (Union[MosesPunctNormalizer, IndicNormalizerFactory]): an object that performs normalization on the text.
|
381 |
+
lang (str): flores language code of the input text sentence.
|
382 |
+
|
383 |
+
Returns:
|
384 |
+
sent (str): a preprocessed input text sentence
|
385 |
+
"""
|
386 |
+
iso_lang = self._flores_codes[lang]
|
387 |
+
sent = self._punc_norm(sent)
|
388 |
+
sent = self._normalize(sent)
|
389 |
+
|
390 |
+
transliterate = True
|
391 |
+
if lang.split("_")[1] in ["Arab", "Aran", "Olck", "Mtei", "Latn"]:
|
392 |
+
transliterate = False
|
393 |
+
|
394 |
+
if iso_lang == "en":
|
395 |
+
processed_sent = " ".join(
|
396 |
+
self._en_tok.tokenize(
|
397 |
+
self._en_normalizer.normalize(sent.strip()), escape=False
|
398 |
+
)
|
399 |
+
)
|
400 |
+
elif transliterate:
|
401 |
+
# transliterates from the any specific language to devanagari
|
402 |
+
# which is why we specify lang2_code as "hi".
|
403 |
+
processed_sent = self._xliterator.transliterate(
|
404 |
+
" ".join(
|
405 |
+
indic_tokenize.trivial_tokenize(
|
406 |
+
normalizer.normalize(sent.strip()), iso_lang
|
407 |
+
)
|
408 |
+
),
|
409 |
+
iso_lang,
|
410 |
+
"hi",
|
411 |
+
).replace(" ् ", "्")
|
412 |
+
else:
|
413 |
+
# we only need to transliterate for joint training
|
414 |
+
processed_sent = " ".join(
|
415 |
+
indic_tokenize.trivial_tokenize(
|
416 |
+
normalizer.normalize(sent.strip()), iso_lang
|
417 |
+
)
|
418 |
+
)
|
419 |
+
|
420 |
+
return processed_sent
|
421 |
+
|
422 |
+
def preprocess_batch(
|
423 |
+
self, batch: List[str], src_lang: str, tgt_lang: str, is_target: bool = False
|
424 |
+
) -> List[str]:
|
425 |
+
"""
|
426 |
+
Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. It also tokenizes the
|
427 |
+
normalized text sequences using sentence piece tokenizer and also adds language tags.
|
428 |
+
|
429 |
+
Args:
|
430 |
+
batch (List[str]): input list of sentences to preprocess.
|
431 |
+
src_lang (str): flores language code of the input text sentences.
|
432 |
+
tgt_lang (str): flores language code of the output text sentences.
|
433 |
+
is_target (bool): add language tags if false otherwise skip it.
|
434 |
+
|
435 |
+
Returns:
|
436 |
+
List[str]: a list of preprocessed input text sentences.
|
437 |
+
"""
|
438 |
+
# reset the placeholder entity map for each batch
|
439 |
+
|
440 |
+
normalizer = (
|
441 |
+
IndicNormalizerFactory().get_normalizer(self._flores_codes[src_lang])
|
442 |
+
if src_lang != "eng_Latn"
|
443 |
+
else None
|
444 |
+
)
|
445 |
+
|
446 |
+
preprocessed_sents = [
|
447 |
+
self._preprocess(sent, src_lang, normalizer) for sent in batch
|
448 |
+
]
|
449 |
+
|
450 |
+
tagged_sents = (
|
451 |
+
self._apply_lang_tags(preprocessed_sents, src_lang, tgt_lang)
|
452 |
+
if not is_target
|
453 |
+
else preprocessed_sents
|
454 |
+
)
|
455 |
+
|
456 |
+
return tagged_sents
|
457 |
+
|
458 |
+
def _postprocess(
|
459 |
+
self,
|
460 |
+
sent: str,
|
461 |
+
placeholder_entity_map: dict,
|
462 |
+
lang: str = "hin_Deva",
|
463 |
+
):
|
464 |
+
"""
|
465 |
+
Postprocesses a single input sentence after the translation generation.
|
466 |
+
|
467 |
+
Args:
|
468 |
+
sent (str): input sentence to postprocess.
|
469 |
+
placeholder_entity_map (dict): dictionary mapping placeholders to the original entity values.
|
470 |
+
lang (str): flores language code of the input sentence.
|
471 |
+
|
472 |
+
Returns:
|
473 |
+
text (str): postprocessed input sentence.
|
474 |
+
"""
|
475 |
+
|
476 |
+
lang_code, script_code = lang.split("_")
|
477 |
+
iso_lang = self._flores_codes[lang]
|
478 |
+
|
479 |
+
# Fixes for Perso-Arabic scripts
|
480 |
+
if script_code in ["Arab", "Aran"]:
|
481 |
+
sent = (
|
482 |
+
sent.replace(" ؟", "؟")
|
483 |
+
.replace(" ۔", "۔")
|
484 |
+
.replace(" ،", "،")
|
485 |
+
.replace("ٮ۪", "ؠ")
|
486 |
+
)
|
487 |
+
|
488 |
+
if lang_code == "ory":
|
489 |
+
sent = sent.replace("ଯ଼", "ୟ")
|
490 |
+
|
491 |
+
for k, v in placeholder_entity_map.items():
|
492 |
+
sent = sent.replace(k, v)
|
493 |
+
|
494 |
+
return (
|
495 |
+
self._en_detok.detokenize(sent.split(" "))
|
496 |
+
if lang == "eng_Latn"
|
497 |
+
else indic_detokenize.trivial_detokenize(
|
498 |
+
self._xliterator.transliterate(sent, "hi", iso_lang),
|
499 |
+
iso_lang,
|
500 |
+
)
|
501 |
+
)
|
502 |
+
|
503 |
+
def postprocess_batch(
|
504 |
+
self,
|
505 |
+
sents: List[str],
|
506 |
+
lang: str = "hin_Deva",
|
507 |
+
) -> List[str]:
|
508 |
+
"""
|
509 |
+
Postprocesses a batch of input sentences after the translation generations.
|
510 |
+
|
511 |
+
Args:
|
512 |
+
sents (List[str]): batch of translated sentences to postprocess.
|
513 |
+
placeholder_entity_map (List[dict]): dictionary mapping placeholders to the original entity values.
|
514 |
+
lang (str): flores language code of the input sentences.
|
515 |
+
|
516 |
+
Returns:
|
517 |
+
List[str]: postprocessed batch of input sentences.
|
518 |
+
"""
|
519 |
+
|
520 |
+
placeholder_entity_maps = self.get_placeholder_entity_maps()
|
521 |
+
|
522 |
+
postprocessed_sents = [
|
523 |
+
self._postprocess(sent, placeholder_entity_map, lang)
|
524 |
+
for sent, placeholder_entity_map in zip(sents, placeholder_entity_maps)
|
525 |
+
]
|
526 |
+
|
527 |
+
# reset the placeholder entity map after each batch
|
528 |
+
self._placeholder_entity_maps.clear()
|
529 |
+
|
530 |
+
return postprocessed_sents
|
IndicTransTokenizer/IndicTransTokenizer/version.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.1.1"
|
IndicTransTokenizer/IndicTransTokenizer/version.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
0.1.1
|
IndicTransTokenizer/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Varun Gumma.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE
|
IndicTransTokenizer/README.md
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# IndicTransTokenizer
|
2 |
+
|
3 |
+
The goal of this repository is to provide a simple, modular, and extendable tokenizer for [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2) and be compatible with the HuggingFace models released.
|
4 |
+
|
5 |
+
## Pre-requisites
|
6 |
+
- `Python 3.8+`
|
7 |
+
- [Indic NLP Library](https://github.com/VarunGumma/indic_nlp_library)
|
8 |
+
- Other requirements as listed in `requirements.txt`
|
9 |
+
|
10 |
+
## Configuration
|
11 |
+
- Editable installation (Note, this may take a while):
|
12 |
+
```bash
|
13 |
+
git clone https://github.com/VarunGumma/IndicTransTokenizer
|
14 |
+
cd IndicTransTokenizer
|
15 |
+
|
16 |
+
pip install --editable ./
|
17 |
+
```
|
18 |
+
|
19 |
+
## Usage
|
20 |
+
```python
|
21 |
+
import torch
|
22 |
+
from transformers import AutoModelForSeq2SeqLM
|
23 |
+
from IndicTransTokenizer import IndicProcessor, IndicTransTokenizer
|
24 |
+
|
25 |
+
tokenizer = IndicTransTokenizer(direction="en-indic")
|
26 |
+
ip = IndicProcessor(inference=True)
|
27 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("ai4bharat/indictrans2-en-indic-dist-200M", trust_remote_code=True)
|
28 |
+
|
29 |
+
sentences = [
|
30 |
+
"This is a test sentence.",
|
31 |
+
"This is another longer different test sentence.",
|
32 |
+
"Please send an SMS to 9876543210 and an email on [email protected] by 15th October, 2023.",
|
33 |
+
]
|
34 |
+
|
35 |
+
batch = ip.preprocess_batch(sentences, src_lang="eng_Latn", tgt_lang="hin_Deva")
|
36 |
+
batch = tokenizer(batch, src=True, return_tensors="pt")
|
37 |
+
|
38 |
+
with torch.inference_mode():
|
39 |
+
outputs = model.generate(**batch, num_beams=5, num_return_sequences=1, max_length=256)
|
40 |
+
|
41 |
+
outputs = tokenizer.batch_decode(outputs, src=False)
|
42 |
+
outputs = ip.postprocess_batch(outputs, lang="hin_Deva")
|
43 |
+
print(outputs)
|
44 |
+
|
45 |
+
>>> ['यह एक परीक्षण वाक्य है।', 'यह एक और लंबा अलग परीक्षण वाक्य है।', 'कृपया 9876543210 पर एक एस. एम. एस. भेजें और 15 अक्टूबर, 2023 तक [email protected] पर एक ईमेल भेजें।']
|
46 |
+
```
|
47 |
+
|
48 |
+
For using the tokenizer to train/fine-tune the model, just set the `inference` argument of IndicProcessor to `False`.
|
49 |
+
|
50 |
+
## Authors
|
51 |
+
- Varun Gumma ([email protected])
|
52 |
+
- Jay Gala ([email protected])
|
53 |
+
- Pranjal Agadh Chitale ([email protected])
|
54 |
+
- Raj Dabre ([email protected])
|
55 |
+
|
56 |
+
|
57 |
+
## Bugs and Contribution
|
58 |
+
Since this a bleeding-edge module, you may encounter broken stuff and import issues once in a while. In case you encounter any bugs or want additional functionalities, please feel free to raise `Issues`/`Pull Requests` or contact the authors.
|
59 |
+
|
60 |
+
|
61 |
+
## Citation
|
62 |
+
If you use our codebase, models or tokenizer, please do cite the following paper:
|
63 |
+
```bibtex
|
64 |
+
@article{
|
65 |
+
gala2023indictrans,
|
66 |
+
title={IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
|
67 |
+
author={Jay Gala and Pranjal A Chitale and A K Raghavan and Varun Gumma and Sumanth Doddapaneni and Aswanth Kumar M and Janki Atul Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M Khapra and Raj Dabre and Anoop Kunchukuttan},
|
68 |
+
journal={Transactions on Machine Learning Research},
|
69 |
+
issn={2835-8856},
|
70 |
+
year={2023},
|
71 |
+
url={https://openreview.net/forum?id=vfT4YuzAYA},
|
72 |
+
note={}
|
73 |
+
}
|
74 |
+
```
|
75 |
+
|
76 |
+
## Note
|
77 |
+
This tokenizer module is currently **not** compatible with the [PreTrainedTokenizer](https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/tokenizer#transformers.PreTrainedTokenizer) module from HuggingFace. Hence, we are actively looking for `Pull Requests` to port this tokenizer to HF. Any leads on that front are welcome!
|
IndicTransTokenizer/requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
setuptools==68.2.2
|
2 |
+
torch
|
3 |
+
sacremoses
|
4 |
+
sentencepiece
|
5 |
+
transformers
|
6 |
+
indic-nlp-library-IT2 @ git+https://github.com/VarunGumma/indic_nlp_library
|
IndicTransTokenizer/setup.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
from sys import version_info, exit
|
4 |
+
from setuptools import setup, find_packages
|
5 |
+
from pkg_resources import parse_requirements
|
6 |
+
|
7 |
+
|
8 |
+
def write_version_py():
|
9 |
+
with open(os.path.join("IndicTransTokenizer", "version.txt"), "r") as f:
|
10 |
+
version = f.read().strip()
|
11 |
+
|
12 |
+
with open(os.path.join("IndicTransTokenizer", "version.py"), "w") as f:
|
13 |
+
f.write(f'__version__ = "{version}"\n')
|
14 |
+
return version
|
15 |
+
|
16 |
+
|
17 |
+
if version_info < (3, 8):
|
18 |
+
exit("Sorry, Python >= 3.8 is required for IndicTransTokenizer.")
|
19 |
+
|
20 |
+
|
21 |
+
with open("README.md", "r", errors="ignore", encoding="utf-8") as fh:
|
22 |
+
long_description = fh.read().strip()
|
23 |
+
|
24 |
+
version = write_version_py()
|
25 |
+
|
26 |
+
setup(
|
27 |
+
name="IndicTransTokenizer",
|
28 |
+
version=version,
|
29 |
+
author="Varun Gumma",
|
30 |
+
author_email="[email protected]",
|
31 |
+
description="A simple, consistent, and extendable module for IndicTrans2 tokenizer compatible with the HuggingFace models",
|
32 |
+
long_description=long_description,
|
33 |
+
long_description_content_type="text/markdown",
|
34 |
+
url="https://github.com/VarunGumma/IndicTransTokenizer",
|
35 |
+
packages=find_packages(),
|
36 |
+
license="MIT",
|
37 |
+
classifiers=[
|
38 |
+
"Programming Language :: Python :: 3",
|
39 |
+
"License :: OSI Approved :: MIT License",
|
40 |
+
"Operating System :: OS Independent",
|
41 |
+
],
|
42 |
+
python_requires=">=3.8",
|
43 |
+
install_requires=[
|
44 |
+
str(requirement)
|
45 |
+
for requirement in parse_requirements(pathlib.Path(f"requirements.txt").open())
|
46 |
+
],
|
47 |
+
)
|
README.md
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.21.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: IndicTrans2 for Conversation
|
3 |
+
emoji: 📚
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.21.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import time
|
3 |
+
from config import model_repo_id, src_lang, tgt_lang
|
4 |
+
from indictrans2 import initialize_model_and_tokenizer, batch_translate
|
5 |
+
from examples import example_sentences
|
6 |
+
|
7 |
+
|
8 |
+
def load_models():
|
9 |
+
model_dict = {}
|
10 |
+
|
11 |
+
print("\tLoading model: %s" % model_repo_id)
|
12 |
+
|
13 |
+
# build model and tokenizer
|
14 |
+
en_indic_tokenizer, en_indic_model, en_indic_lora_model = (
|
15 |
+
initialize_model_and_tokenizer()
|
16 |
+
)
|
17 |
+
|
18 |
+
model_dict["_tokenizer"] = en_indic_tokenizer
|
19 |
+
model_dict["_model"] = en_indic_model
|
20 |
+
model_dict["_lora_model"] = en_indic_lora_model
|
21 |
+
|
22 |
+
return model_dict
|
23 |
+
|
24 |
+
|
25 |
+
def translation(text):
|
26 |
+
|
27 |
+
start_time = time.time()
|
28 |
+
|
29 |
+
tokenizer = model_dict["_tokenizer"]
|
30 |
+
model = model_dict["_model"]
|
31 |
+
lora_model = model_dict["_lora_model"]
|
32 |
+
|
33 |
+
# org translation
|
34 |
+
org_translation = batch_translate(
|
35 |
+
[text],
|
36 |
+
model=model,
|
37 |
+
tokenizer=tokenizer,
|
38 |
+
)
|
39 |
+
org_output = org_translation[0]
|
40 |
+
end_time = time.time()
|
41 |
+
|
42 |
+
# lora translation
|
43 |
+
lora_translation = batch_translate(
|
44 |
+
[text],
|
45 |
+
model=lora_model,
|
46 |
+
tokenizer=tokenizer,
|
47 |
+
)
|
48 |
+
lora_output = lora_translation[0]
|
49 |
+
end_time2 = time.time()
|
50 |
+
|
51 |
+
result = {
|
52 |
+
"source": src_lang,
|
53 |
+
"target": tgt_lang,
|
54 |
+
"input": text,
|
55 |
+
"it2_result": org_output,
|
56 |
+
"it2_conv_result": lora_output,
|
57 |
+
"it2_inference_time": end_time - start_time,
|
58 |
+
"it2_conv_inference_time": end_time2 - end_time,
|
59 |
+
}
|
60 |
+
|
61 |
+
return result
|
62 |
+
|
63 |
+
|
64 |
+
print("\tinit models")
|
65 |
+
|
66 |
+
global model_dict
|
67 |
+
|
68 |
+
model_dict = load_models()
|
69 |
+
|
70 |
+
inputs = gr.Textbox(lines=5, label="Input text")
|
71 |
+
outputs = gr.JSON(container=True)
|
72 |
+
submit_btn = gr.Button("Translate", variant="primary")
|
73 |
+
|
74 |
+
title = "IndicTrans2 fine-tuned on conversation"
|
75 |
+
description = f"Note: LoRA is trained only on En-Hi pair.\nDetails: https://github.com/AI4Bharat/IndicTrans2.\nLoRA Model: https://huggingface.co/sam749/IndicTrans2-Conv"
|
76 |
+
|
77 |
+
gr.Interface(
|
78 |
+
fn=translation,
|
79 |
+
inputs=inputs,
|
80 |
+
outputs=outputs,
|
81 |
+
title=title,
|
82 |
+
description=description,
|
83 |
+
submit_btn=submit_btn,
|
84 |
+
examples=example_sentences,
|
85 |
+
examples_per_page=10,
|
86 |
+
cache_examples=False,
|
87 |
+
).launch(share=True)
|
config.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_repo_id = "ai4bharat/indictrans2-en-indic-dist-200M"
|
2 |
+
lora_repo_id = "sam749/IndicTrans2-Conv"
|
3 |
+
src_lang = "eng_Latn"
|
4 |
+
tgt_lang = "hin_Deva"
|
5 |
+
batch_size = 8
|
examples.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
example_sentences = [
|
2 |
+
['Avantika to Prakash: Did you mean "I play cricket"? What position do you play?'],
|
3 |
+
["'do you eat pizza?', Manoj said to Jaya"],
|
4 |
+
["Ankita to Avantika: can you come with me to tour?"],
|
5 |
+
[
|
6 |
+
'Sudha to Sakshi: Did you mean "I\'ll grab some coffee before the meeting starts."? Can I join you too?'
|
7 |
+
],
|
8 |
+
[
|
9 |
+
'Anil to Sakshi: Did you mean "I\'ll grab some coffee before the meeting starts."? Can I join you too?'
|
10 |
+
],
|
11 |
+
]
|
indictrans2.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
|
3 |
+
from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
|
4 |
+
from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
|
5 |
+
from peft import PeftModel
|
6 |
+
from config import lora_repo_id, model_repo_id, batch_size, src_lang, tgt_lang
|
7 |
+
|
8 |
+
|
9 |
+
DIRECTION = "en-indic"
|
10 |
+
QUANTIZATION = None
|
11 |
+
IP = IndicProcessor(inference=True)
|
12 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
HALF = True if torch.cuda.is_available() else False
|
14 |
+
|
15 |
+
|
16 |
+
def initialize_model_and_tokenizer():
|
17 |
+
|
18 |
+
if QUANTIZATION == "4-bit":
|
19 |
+
qconfig = BitsAndBytesConfig(
|
20 |
+
load_in_4bit=True,
|
21 |
+
bnb_4bit_use_double_quant=True,
|
22 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
23 |
+
)
|
24 |
+
elif QUANTIZATION == "8-bit":
|
25 |
+
qconfig = BitsAndBytesConfig(
|
26 |
+
load_in_8bit=True,
|
27 |
+
bnb_8bit_use_double_quant=True,
|
28 |
+
bnb_8bit_compute_dtype=torch.bfloat16,
|
29 |
+
)
|
30 |
+
else:
|
31 |
+
qconfig = None
|
32 |
+
|
33 |
+
tokenizer = IndicTransTokenizer(direction=DIRECTION)
|
34 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
35 |
+
model_repo_id,
|
36 |
+
trust_remote_code=True,
|
37 |
+
low_cpu_mem_usage=True,
|
38 |
+
quantization_config=qconfig,
|
39 |
+
)
|
40 |
+
model2 = AutoModelForSeq2SeqLM.from_pretrained(
|
41 |
+
model_repo_id,
|
42 |
+
trust_remote_code=True,
|
43 |
+
low_cpu_mem_usage=True,
|
44 |
+
quantization_config=qconfig,
|
45 |
+
)
|
46 |
+
|
47 |
+
if qconfig == None:
|
48 |
+
model = model.to(DEVICE)
|
49 |
+
model2 = model2.to(DEVICE)
|
50 |
+
|
51 |
+
model.eval()
|
52 |
+
model2.eval()
|
53 |
+
|
54 |
+
lora_model = PeftModel.from_pretrained(model2, lora_repo_id)
|
55 |
+
|
56 |
+
return tokenizer, model, lora_model
|
57 |
+
|
58 |
+
|
59 |
+
def batch_translate(input_sentences, model, tokenizer):
|
60 |
+
translations = []
|
61 |
+
for i in range(0, len(input_sentences), batch_size):
|
62 |
+
batch = input_sentences[i : i + batch_size]
|
63 |
+
|
64 |
+
# Preprocess the batch and extract entity mappings
|
65 |
+
batch = IP.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)
|
66 |
+
|
67 |
+
# Tokenize the batch and generate input encodings
|
68 |
+
inputs = tokenizer(
|
69 |
+
batch,
|
70 |
+
src=True,
|
71 |
+
truncation=True,
|
72 |
+
padding="longest",
|
73 |
+
return_tensors="pt",
|
74 |
+
return_attention_mask=True,
|
75 |
+
).to(DEVICE)
|
76 |
+
|
77 |
+
# Generate translations using the model
|
78 |
+
with torch.inference_mode():
|
79 |
+
generated_tokens = model.generate(
|
80 |
+
**inputs,
|
81 |
+
use_cache=True,
|
82 |
+
min_length=0,
|
83 |
+
max_length=256,
|
84 |
+
num_beams=5,
|
85 |
+
num_return_sequences=1,
|
86 |
+
)
|
87 |
+
|
88 |
+
# Decode the generated tokens into text
|
89 |
+
generated_tokens = tokenizer.batch_decode(
|
90 |
+
generated_tokens.detach().cpu().tolist(), src=False
|
91 |
+
)
|
92 |
+
|
93 |
+
# Postprocess the translations, including entity replacement
|
94 |
+
translations += IP.postprocess_batch(generated_tokens, lang=tgt_lang)
|
95 |
+
|
96 |
+
del inputs
|
97 |
+
|
98 |
+
return translations
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
indic-nlp-library-IT2 @ git+https://github.com/VarunGumma/indic_nlp_library
|
2 |
+
setuptools==68.2.2
|
3 |
+
transformers
|
4 |
+
gradio
|
5 |
+
torch
|
6 |
+
peft
|
7 |
+
sacremoses
|
8 |
+
sentencepiece
|