Joshua Lochner commited on
Commit
74c1216
1 Parent(s): 3f7ce4e

Improve prediction output format

Browse files
Files changed (1) hide show
  1. src/predict.py +15 -11
src/predict.py CHANGED
@@ -25,9 +25,10 @@ import logging
25
 
26
 
27
  def seconds_to_time(seconds):
 
28
  h, remainder = divmod(abs(int(seconds)), 3600)
29
  m, s = divmod(remainder, 60)
30
- return f"{'-' if seconds < 0 else ''}{h:02}:{m:02}:{s:02}"
31
 
32
 
33
  @dataclass
@@ -266,20 +267,23 @@ def main():
266
  tokenizer = AutoTokenizer.from_pretrained(predict_args.model_path)
267
 
268
  predict_args.video_id = predict_args.video_id.strip()
269
- print(
270
- f'Predicting for https://www.youtube.com/watch?v={predict_args.video_id}')
271
  predictions = predict(predict_args.video_id, model, tokenizer,
272
  segmentation_args, classifier_args=classifier_args)
273
 
274
- for prediction in predictions:
275
- print(' '.join([w['text'] for w in prediction['words']]))
276
- print(seconds_to_time(prediction['start']),
277
- '-->', seconds_to_time(prediction['end']))
278
- print(prediction['start'], '-->', prediction['end'])
279
- print(prediction['probability'])
280
- print()
281
 
282
- print()
 
 
 
 
 
 
 
 
283
 
284
 
285
  if __name__ == '__main__':
 
25
 
26
 
27
  def seconds_to_time(seconds):
28
+ fractional = str(round(seconds % 1, 3))[1:]
29
  h, remainder = divmod(abs(int(seconds)), 3600)
30
  m, s = divmod(remainder, 60)
31
+ return f"{'-' if seconds < 0 else ''}{h:02}:{m:02}:{s:02}{fractional}"
32
 
33
 
34
  @dataclass
 
267
  tokenizer = AutoTokenizer.from_pretrained(predict_args.model_path)
268
 
269
  predict_args.video_id = predict_args.video_id.strip()
 
 
270
  predictions = predict(predict_args.video_id, model, tokenizer,
271
  segmentation_args, classifier_args=classifier_args)
272
 
273
+ video_url = f'https://www.youtube.com/watch?v={predict_args.video_id}'
274
+ if not predictions:
275
+ print('No predictions found for', video_url)
276
+ return
 
 
 
277
 
278
+ print(len(predictions), 'predictions found for', video_url)
279
+ for index, prediction in enumerate(predictions, start=1):
280
+ print(f'Prediction #{index}:')
281
+ print('Text: "',
282
+ ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
283
+ print('Time:', seconds_to_time(
284
+ prediction['start']), '-->', seconds_to_time(prediction['end']))
285
+ print('Probability:', prediction['probability'])
286
+ print()
287
 
288
 
289
  if __name__ == '__main__':