pere committed on
Commit dea1939
1 Parent(s): 24292e6
.run_distillation.py.un~ CHANGED
Binary files a/.run_distillation.py.un~ and b/.run_distillation.py.un~ differ
 
generation_config.json CHANGED
@@ -165,7 +165,7 @@
     "<|yue|>": 50358,
     "<|zh|>": 50260
   },
-  "language": "<|en|>",
+  "language": "<|no|>",
   "max_initial_timestamp_index": 1,
   "max_length": 448,
   "no_timestamps_token_id": 50364,
run_distillation.py CHANGED
@@ -1344,16 +1344,23 @@ def main():
                 else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
             )
 
+    # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
     def is_wer_in_range(ground_truth, whisper_transcript):
         norm_ground_truth = normalizer(ground_truth)
-        if len(norm_ground_truth) > 0 and whisper_transcript is not None:
+        if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
+            # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
+            return False
+        elif len(norm_ground_truth) == 0 and len(normalizer(whisper_transcript)) == 0:
+            return True
+        elif len(norm_ground_truth.strip()) > 0 and whisper_transcript is not None and len(normalizer(whisper_transcript).strip()) > 0:
             norm_whisper_transcript = normalizer(whisper_transcript)
             wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
             return wer < wer_threshold
         else:
-            # filter automatically since we can't know the WER
+            # filter automatically since we cant know WER
            return False
 
+
     filter_by_wer_threshold = partial(
         raw_datasets["train"].filter,
         function=is_wer_in_range,
@@ -1517,20 +1524,30 @@ def main():
         ]
         wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)
 
-        # normalize everything and re-compute the WER
-        norm_pred_str = [normalizer(pred) for pred in pred_str]
-        norm_label_str = [normalizer(label) for label in label_str]
-        # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
-        pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
-        label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
-        # filtering step to only evaluate the samples that correspond to non-zero normalized references:
-        norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
-        norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
+        # Iterate through all predictions and labels
+        for pred, label in zip(pred_str, label_str):
+            # Normalize the prediction and label
+            normalized_pred = normalizer(pred)
+            normalized_label = normalizer(label)
 
-        wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
+            # If either normalized string is empty after normalization, replace with "<|nospeech|>"
+            if not normalized_pred.strip():
+                normalized_pred = "<|nospeech|>"
+            if not normalized_label.strip():
+                normalized_label = "<|nospeech|>"
+
+            norm_pred_str.append(normalized_pred)
+            norm_label_str.append(normalized_label)
 
+        # Replace original strings with "<|nocaptions|>" where necessary for consistency
+        pred_str = [pred if len(pred.strip()) > 0 else "<|nospeech|>" for pred in pred_str]
+        label_str = [label if len(label.strip()) > 0 else "<|nospeech|>" for label in label_str]
+
+        # Compute WER using all entries, including those with "<|nocaptions|>"
+        wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
         return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
 
+
     # 9. Save feature extractor, tokenizer, config and generation config
     feature_extractor.save_pretrained(training_args.output_dir)
     tokenizer.save_pretrained(training_args.output_dir)
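The first hunk tightens the pseudo-label filter (step 10.3): all-caps transcripts are rejected outright, samples where both the reference and the pseudo-label normalize to empty strings are kept, and WER is only computed when both sides are non-empty. Below is a standalone sketch of that decision logic, not the script itself: a toy normalizer and an assumed wer_threshold stand in for the objects built earlier in main(), and a None guard is added up front so the sketch runs on any input.

```python
# Standalone sketch of the updated pseudo-label filter (assumed normalizer and threshold).
import evaluate

metric = evaluate.load("wer")
wer_threshold = 10.0  # assumed value; the script takes this from its arguments


def normalizer(text):
    # toy stand-in for the Whisper text normalizer used in run_distillation.py
    return " ".join(text.lower().split())


def is_wer_in_range(ground_truth, whisper_transcript):
    if whisper_transcript is None:
        return False  # no pseudo-label, nothing to compare against
    if whisper_transcript.upper() == whisper_transcript:
        return False  # all-caps pseudo-labels are treated as erroneous large-v3 generations
    norm_ground_truth = normalizer(ground_truth)
    norm_whisper_transcript = normalizer(whisper_transcript)
    if len(norm_ground_truth) == 0 and len(norm_whisper_transcript) == 0:
        return True   # both empty after normalization: keep the (silent) sample
    if len(norm_ground_truth.strip()) > 0 and len(norm_whisper_transcript.strip()) > 0:
        wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
        return wer < wer_threshold
    return False      # one side empty: WER is undefined, drop the sample


print(is_wer_in_range("god morgen Norge", "God morgen Norge"))   # True: identical after normalization
print(is_wer_in_range("god morgen Norge", "GOD MORGEN NORGE"))   # False: all-caps transcript rejected
```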
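The second hunk changes how evaluation WER is computed: instead of discarding samples whose normalized reference is empty, both sides are kept and empty normalized strings are replaced by a placeholder token, so no-speech segments still count towards the metric. A small sketch of that behaviour with assumed toy data and the same toy normalizer:

```python
# Sketch of the new WER computation (assumed toy data and toy normalizer).
import evaluate

metric = evaluate.load("wer")


def normalizer(text):
    return " ".join(text.lower().split())


pred_str = ["Hei verden", ""]   # second prediction is empty (predicted silence)
label_str = ["hei verden", ""]  # second reference is empty too

norm_pred_str, norm_label_str = [], []
for pred, label in zip(pred_str, label_str):
    normalized_pred = normalizer(pred)
    normalized_label = normalizer(label)
    # previously these pairs were dropped; now they become a placeholder pair
    norm_pred_str.append(normalized_pred if normalized_pred.strip() else "<|nospeech|>")
    norm_label_str.append(normalized_label if normalized_label.strip() else "<|nospeech|>")

wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
print(wer)  # 0.0: both samples match, including the placeholder pair
```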
tokenizer.json CHANGED
@@ -14503,7 +14503,7 @@
       },
       {
         "SpecialToken": {
-          "id": "<|en|>",
+          "id": "<|no|>",
           "type_id": 0
         }
       },
@@ -14541,7 +14541,7 @@
       },
       {
         "SpecialToken": {
-          "id": "<|en|>",
+          "id": "<|no|>",
           "type_id": 0
         }
       },
@@ -14586,22 +14586,22 @@
           "<|endoftext|>"
         ]
       },
-      "<|en|>": {
-        "id": "<|en|>",
+      "<|notimestamps|>": {
+        "id": "<|notimestamps|>",
         "ids": [
-          50259
+          50364
         ],
         "tokens": [
-          "<|en|>"
+          "<|notimestamps|>"
         ]
       },
-      "<|notimestamps|>": {
-        "id": "<|notimestamps|>",
+      "<|no|>": {
+        "id": "<|no|>",
         "ids": [
-          50364
+          50288
         ],
         "tokens": [
-          "<|notimestamps|>"
+          "<|no|>"
         ]
       },
       "<|startoftranscript|>": {