test
Files changed:
- .run_distillation.py.un~ +0 -0
- generation_config.json +1 -1
- run_distillation.py +29 -12
- tokenizer.json +10 -10
.run_distillation.py.un~
CHANGED
Binary files a/.run_distillation.py.un~ and b/.run_distillation.py.un~ differ
generation_config.json
CHANGED
@@ -165,7 +165,7 @@
     "<|yue|>": 50358,
     "<|zh|>": 50260
   },
-  "language": "<|
+  "language": "<|no|>",
  "max_initial_timestamp_index": 1,
  "max_length": 448,
  "no_timestamps_token_id": 50364,
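The effect of this one-line change: transformers reads the "language" field of generation_config.json as the default decoding language, so the model transcribes in Norwegian ("<|no|>") unless told otherwise. A minimal sketch, assuming the edited configs are saved to a hypothetical local checkpoint directory ./output_dir (any Whisper checkpoint carrying this generation config behaves the same):

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

ckpt = "./output_dir"  # hypothetical path to a checkpoint with the edited config
model = WhisperForConditionalGeneration.from_pretrained(ckpt)
processor = WhisperProcessor.from_pretrained(ckpt)

# One second of silence as stand-in audio.
audio = np.zeros(16000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# With "language": "<|no|>" in generation_config.json, no language argument is
# needed here: generate() prepends the Norwegian language token by default.
pred_ids = model.generate(inputs.input_features)

# The default can still be overridden per call:
pred_ids = model.generate(inputs.input_features, language="en", task="transcribe")
print(processor.batch_decode(pred_ids, skip_special_tokens=True))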
run_distillation.py
CHANGED
@@ -1344,16 +1344,23 @@ def main():
         else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
     )
 
+    # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
     def is_wer_in_range(ground_truth, whisper_transcript):
         norm_ground_truth = normalizer(ground_truth)
-        if len(norm_ground_truth) > 0 and whisper_transcript is not None:
+        if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
+            # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
+            return False
+        elif len(norm_ground_truth) == 0 and len(normalizer(whisper_transcript)) == 0:
+            return True
+        elif len(norm_ground_truth.strip()) > 0 and whisper_transcript is not None and len(normalizer(whisper_transcript).strip()) > 0:
             norm_whisper_transcript = normalizer(whisper_transcript)
             wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
             return wer < wer_threshold
         else:
             # filter automatically since we can't know the WER
             return False
 
+
     filter_by_wer_threshold = partial(
         raw_datasets["train"].filter,
         function=is_wer_in_range,
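Before computing a WER, the new filter rejects all-upper-case pseudo-labels (a known erroneous output mode of large-v3), accepts pairs that are both empty after normalization, and only then scores the remaining pairs against the WER threshold. A self-contained sketch of the same heuristic, with normalizer, metric, and wer_threshold stubbed locally (in the script they are built from the training arguments):

import evaluate
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

metric = evaluate.load("wer")
normalizer = BasicTextNormalizer()
wer_threshold = 10.0  # stub value; the script reads this from its arguments

def is_wer_in_range(ground_truth: str, whisper_transcript: str) -> bool:
    norm_ground_truth = normalizer(ground_truth)
    if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
        # All-caps pseudo-labels are treated as erroneous large-v3 generations.
        return False
    elif len(norm_ground_truth) == 0 and len(normalizer(whisper_transcript)) == 0:
        # Both sides empty after normalization: keep the sample.
        return True
    elif len(norm_ground_truth.strip()) > 0 and whisper_transcript is not None and len(normalizer(whisper_transcript).strip()) > 0:
        norm_whisper_transcript = normalizer(whisper_transcript)
        wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
        return wer < wer_threshold
    else:
        # One side is empty or missing, so the WER is undefined: filter it out.
        return False

print(is_wer_in_range("hei på deg", "hei på deg"))          # True: 0% WER
print(is_wer_in_range("hei på deg", "HEI PÅ DEG"))          # False: all-caps heuristic
print(is_wer_in_range("hei på deg", "noe helt annet her"))  # False: WER above threshold

One caveat visible in the committed code: the second branch normalizes whisper_transcript before the `is not None` guard in the third branch, so a None pseudo-label paired with an empty reference would raise rather than be filtered.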
@@ -1517,20 +1524,30 @@ def main():
         ]
         wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)
 
-        # normalize everything and re-compute the WER
-        norm_pred_str = [normalizer(pred) for pred in pred_str]
-        norm_label_str = [normalizer(label) for label in label_str]
-        # for logging, we need the pred/labels to match the norm_pred/norm_label, so discard any filtered samples here
-        pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
-        label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
-        # filtering step to only evaluate the samples that correspond to non-zero normalized references:
-        norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
-        norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
+        # Iterate through all predictions and labels
+        for pred, label in zip(pred_str, label_str):
+            # Normalize the prediction and label
+            normalized_pred = normalizer(pred)
+            normalized_label = normalizer(label)
 
-        wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
+            # If either normalized string is empty after normalization, replace it with "<|nospeech|>"
+            if not normalized_pred.strip():
+                normalized_pred = "<|nospeech|>"
+            if not normalized_label.strip():
+                normalized_label = "<|nospeech|>"
+
+            norm_pred_str.append(normalized_pred)
+            norm_label_str.append(normalized_label)
 
+        # Replace original strings with "<|nospeech|>" where necessary for consistency
+        pred_str = [pred if len(pred.strip()) > 0 else "<|nospeech|>" for pred in pred_str]
+        label_str = [label if len(label.strip()) > 0 else "<|nospeech|>" for label in label_str]
+
+        # Compute the WER using all entries, including those replaced with "<|nospeech|>"
+        wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
         return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
 
+
     # 9. Save feature extractor, tokenizer, config and generation config
     feature_extractor.save_pretrained(training_args.output_dir)
     tokenizer.save_pretrained(training_args.output_dir)
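This second hunk changes how empty strings enter the eval metric: instead of dropping samples whose normalized reference is empty, both sides of an empty pair are mapped to a "<|nospeech|>" sentinel so silent segments count toward the WER. A self-contained sketch of that behavior (note that the committed hunk appends to norm_pred_str / norm_label_str, which it assumes are initialized as empty lists earlier in compute_metrics; the sketch initializes them explicitly):

import evaluate
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

metric = evaluate.load("wer")
normalizer = BasicTextNormalizer()

# Toy decoded predictions and references; the middle pair is a silent segment.
pred_str = ["hei på deg", "", "god morgen"]
label_str = ["hei på deg", "", "god kveld"]

norm_pred_str, norm_label_str = [], []
for pred, label in zip(pred_str, label_str):
    normalized_pred = normalizer(pred)
    normalized_label = normalizer(label)
    # Empty strings would otherwise be dropped from the metric; mapping them
    # to a sentinel keeps silent segments in the score.
    norm_pred_str.append(normalized_pred if normalized_pred.strip() else "<|nospeech|>")
    norm_label_str.append(normalized_label if normalized_label.strip() else "<|nospeech|>")

wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
print(f"WER: {wer:.1f}%")  # the empty pair counts as a correct "<|nospeech|>" match

The net effect is that a correctly predicted silent segment now scores as a match instead of being excluded from evaluation.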
tokenizer.json
CHANGED
@@ -14503,7 +14503,7 @@
       },
       {
         "SpecialToken": {
-          "id": "<|
+          "id": "<|no|>",
           "type_id": 0
         }
       },
@@ -14541,7 +14541,7 @@
       },
       {
         "SpecialToken": {
-          "id": "<|
+          "id": "<|no|>",
           "type_id": 0
         }
       },
@@ -14586,22 +14586,22 @@
           "<|endoftext|>"
         ]
       },
-      "<|
-        "id": "<|
+      "<|notimestamps|>": {
+        "id": "<|notimestamps|>",
         "ids": [
-
+          50364
         ],
         "tokens": [
-          "<|
+          "<|notimestamps|>"
         ]
       },
-      "<|
-        "id": "<|
+      "<|no|>": {
+        "id": "<|no|>",
         "ids": [
-
+          50288
         ],
         "tokens": [
-          "<|
+          "<|no|>"
        ]
      },
      "<|startoftranscript|>": {
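To sanity-check the ids patched into the post-processor and special-tokens map above, the tokens can be looked up against a Whisper tokenizer with the same vocabulary. A small sketch, again assuming the edited files live under the hypothetical ./output_dir:

from transformers import WhisperTokenizerFast

# WhisperTokenizerFast loads tokenizer.json directly, so this exercises the edited file.
tokenizer = WhisperTokenizerFast.from_pretrained("./output_dir")
for token in ["<|no|>", "<|notimestamps|>"]:
    print(token, tokenizer.convert_tokens_to_ids(token))
# Expected to match the diff: <|no|> -> 50288, <|notimestamps|> -> 50364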