Commit 90d1f68 (parent: 4678c9b), committed by Joshua Lochner
Add functionality to predict self-promo and interaction reminders

Files changed:
- src/evaluate.py +2 -4
- src/predict.py +31 -24
- src/preprocess.py +111 -55
- src/segment.py +3 -3
- src/shared.py +5 -6
- src/train.py +15 -15
- src/utils.py +6 -0
src/evaluate.py
CHANGED
@@ -105,13 +105,13 @@ def calculate_metrics(labelled_words, predictions):
 
         if predicted_sponsor:
             # total_positive_time += duration
-            if word['
+            if word['category'] is not None:  # Is actual sponsor
                 metrics['true_positive'] += duration
             else:
                 metrics['false_positive'] += duration
         else:
             # total_negative_time += duration
-            if word['
+            if word['category'] is not None:  # Is actual sponsor
                 metrics['false_negative'] += duration
             else:
                 metrics['true_negative'] += duration
@@ -176,8 +176,6 @@ def main():
     with open(final_path) as fp:
         final_data = json.load(fp)
 
-    classifier, vectorizer = get_classifier_vectorizer(classifier_args)
-
     total_accuracy = 0
     total_precision = 0
     total_recall = 0
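Since word labels are now categories rather than a boolean sponsor flag, the counters above still accumulate seconds of transcript per outcome. A small self-contained illustration of turning such duration totals into scores (the numbers are invented and the formulas are the standard definitions, not code from this commit):

metrics = {'true_positive': 40.0, 'false_positive': 5.0,
           'false_negative': 10.0, 'true_negative': 545.0}  # seconds, made up

total = sum(metrics.values())
accuracy = (metrics['true_positive'] + metrics['true_negative']) / total
precision = metrics['true_positive'] / (metrics['true_positive'] + metrics['false_positive'])
recall = metrics['true_positive'] / (metrics['true_positive'] + metrics['false_negative'])
f_score = 2 * precision * recall / (precision + recall)
print(round(accuracy, 3), round(precision, 3), round(recall, 3), round(f_score, 3))
# 0.975 0.889 0.8 0.842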
src/predict.py
CHANGED
@@ -1,4 +1,4 @@
-from 
+from utils import re_findall
 from shared import OutputArguments
 from typing import Optional
 from segment import (
@@ -11,21 +11,22 @@ from segment import (
     SegmentationArguments
 )
 import preprocess
-import re
 from errors import TranscriptError
 from model import get_classifier_vectorizer
 from transformers import (
     AutoModelForSeq2SeqLM,
-    AutoTokenizer
+    AutoTokenizer,
+    HfArgumentParser
 )
+from transformers.trainer_utils import get_last_checkpoint
 from dataclasses import dataclass, field
-from transformers import HfArgumentParser
 from shared import device
 import logging
 
 
 def seconds_to_time(seconds):
-    fractional = 
+    fractional = round(seconds % 1, 3)
+    fractional = '' if fractional == 0 else str(fractional)[1:]
     h, remainder = divmod(abs(int(seconds)), 3600)
     m, s = divmod(remainder, 60)
     return f"{'-' if seconds < 0 else ''}{h:02}:{m:02}:{s:02}{fractional}"
@@ -64,7 +65,7 @@ class PredictArguments(TrainingOutputArguments):
 )
 
 
-SPONSOR_MATCH_RE = fr'(?<={CustomTokens.
+SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SEGMENT.value})\s*_(?P<category>\S+)\s*(?P<text>.*?)\s*(?={CustomTokens.END_SEGMENT.value}|$)'
 
 MATCH_WINDOW = 25       # Increase for accuracy, but takes longer: O(n^3)
 MERGE_TIME_WITHIN = 8   # Merge predictions if they are within x seconds
@@ -97,11 +98,13 @@ class ClassifierArguments:
         default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
 
 
-def filter_predictions(predictions, classifier, vectorizer, 
+def filter_predictions(predictions, classifier_args):  # classifier, vectorizer,
     """Use classifier to filter predictions"""
     if not predictions:
         return predictions
 
+    classifier, vectorizer = get_classifier_vectorizer(classifier_args)
+
     transformed_segments = vectorizer.transform([
         preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
         for pred in predictions
@@ -142,9 +145,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
             words, prediction['start'], prediction['end'])
 
     if classifier_args is not None:
-
-        predictions = filter_predictions(
-            predictions, classifier, vectorizer, classifier_args)
+        predictions = filter_predictions(predictions, classifier_args)
 
     return predictions
 
@@ -166,13 +167,10 @@ def greedy_match(list, sublist):
     return best_i, best_j, best_k
 
 
-DEFAULT_TOKEN_PREFIX = 'summarize: '
-
-
 def predict_sponsor_text(text, model, tokenizer):
     """Given a body of text, predict the words which are part of the sponsor"""
     input_ids = tokenizer(
-        f'{
+        f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(device())
 
     # Can't be longer than input length + SAFETY_TOKENS or model input dim
     max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
@@ -183,10 +181,11 @@ def predict_sponsor_text(text, model, tokenizer):
 
 def predict_sponsor_matches(text, model, tokenizer):
     sponsorship_text = predict_sponsor_text(text, model, tokenizer)
-
+
+    if CustomTokens.NO_SEGMENT.value in sponsorship_text:
         return []
 
-    return 
+    return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
 
 
 def segments_to_prediction_times(segments, model, tokenizer):
@@ -202,7 +201,7 @@ def segments_to_prediction_times(segments, model, tokenizer):
         matches = predict_sponsor_matches(batch_text, model, tokenizer)
 
         for match in matches:
-            matched_text = match.split()
+            matched_text = match['text'].split()
             # TODO skip if too short
 
             i1, j1, k1 = greedy_match(
@@ -217,7 +216,8 @@ def segments_to_prediction_times(segments, model, tokenizer):
 
             predicted_time_ranges.append({
                 'start': word_start(extracted_words[0]),
-                'end': word_end(extracted_words[-1])
+                'end': word_end(extracted_words[-1]),
+                'category': match['category']
             })
 
     # Necessary to sort matches by start time
@@ -225,23 +225,29 @@ def segments_to_prediction_times(segments, model, tokenizer):
 
     # Merge overlapping predictions and sponsorships that are close together
    # Caused by model having max input size
-
+
+    prev_prediction = None
+
     final_predicted_time_ranges = []
     for range in predicted_time_ranges:
         start_time = range['start']
         end_time = range['end']
 
-        if 
-
+        if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
+            start_time <= prev_prediction['end'] <= end_time or start_time -
+                prev_prediction['end'] <= MERGE_TIME_WITHIN
+        ):
+            # Ending time of last segment is in this segment or close by, so we extend last prediction range
             final_predicted_time_ranges[-1]['end'] = end_time
 
         else:  # No overlap, is a new prediction
             final_predicted_time_ranges.append({
                 'start': start_time,
                 'end': end_time,
+                'category': range['category']
             })
 
-
+        prev_prediction = range
 
     return final_predicted_time_ranges
 
@@ -268,7 +274,7 @@ def main():
 
     predict_args.video_id = predict_args.video_id.strip()
     predictions = predict(predict_args.video_id, model, tokenizer,
-                          segmentation_args, classifier_args=classifier_args
+                          segmentation_args)  # TODO add back , classifier_args=classifier_args
 
     video_url = f'https://www.youtube.com/watch?v={predict_args.video_id}'
     if not predictions:
@@ -282,7 +288,8 @@ def main():
               ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
         print('Time:', seconds_to_time(
             prediction['start']), '-->', seconds_to_time(prediction['end']))
-        print('Probability:', prediction
+        print('Probability:', prediction.get('probability'))
+        print('Category:', prediction.get('category'))
         print()
 
 
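To see what the new SPONSOR_MATCH_RE captures, here is a self-contained sketch using the token values this commit defines in src/shared.py; the sample model output string is invented:

import re

# Values of CustomTokens.START_SEGMENT / END_SEGMENT from src/shared.py
START_SEGMENT = 'START_SEGMENT_TOKEN'
END_SEGMENT = 'END_SEGMENT_TOKEN'

SPONSOR_MATCH_RE = fr'(?<={START_SEGMENT})\s*_(?P<category>\S+)\s*(?P<text>.*?)\s*(?={END_SEGMENT}|$)'

# Hypothetical model output, in the format produced during preprocessing
output = (f'{START_SEGMENT}_SPONSOR this video is sponsored by example '
          f'{END_SEGMENT}_SPONSOR')

# Equivalent to utils.re_findall(SPONSOR_MATCH_RE, output)
matches = [m.groupdict() for m in re.finditer(SPONSOR_MATCH_RE, output)]
print(matches)
# [{'category': 'SPONSOR', 'text': 'this video is sponsored by example'}]

The named groups are what re_findall (added in src/utils.py) returns as dictionaries, which is why matches are now indexed as match['text'] and match['category'] above.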
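The category-aware merging added to segments_to_prediction_times can also be exercised on its own. A minimal sketch: the helper name and sample ranges are mine, the condition mirrors the one in the diff and assumes the ranges are already sorted by start time:

MERGE_TIME_WITHIN = 8  # seconds, same constant as in src/predict.py

def merge_prediction_ranges(ranges):
    """Extend the previous range when the next one has the same category and
    overlaps it or starts within MERGE_TIME_WITHIN seconds of its end."""
    merged = []
    prev = None
    for r in ranges:
        if prev is not None and r['category'] == prev['category'] and (
                r['start'] <= prev['end'] <= r['end']
                or r['start'] - prev['end'] <= MERGE_TIME_WITHIN):
            merged[-1]['end'] = r['end']
        else:
            merged.append(dict(r))
        prev = r
    return merged

print(merge_prediction_ranges([
    {'start': 10.0, 'end': 25.0, 'category': 'sponsor'},
    {'start': 28.0, 'end': 40.0, 'category': 'sponsor'},      # within 8 s, same category
    {'start': 120.0, 'end': 130.0, 'category': 'selfpromo'},  # different category
]))
# [{'start': 10.0, 'end': 40.0, 'category': 'sponsor'},
#  {'start': 120.0, 'end': 130.0, 'category': 'selfpromo'}]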
src/preprocess.py
CHANGED
@@ -1,5 +1,6 @@
+from datetime import datetime
 import itertools
-from typing import Optional
+from typing import Optional, List
 from datasets import load_dataset
 from model import ModelArguments
 import segment
@@ -24,8 +25,10 @@ def find(s, ch):
     return [i for i, ltr in enumerate(s) if ltr == ch]
 
 
-def wordify(transcript):
+def wordify(transcript, maximum_wps=1):
     """Try to replicate format for automatically generated transcripts"""
+
+    # Do not allow segments to be on screen for too long using maximum_wps
     words = []
 
     for line_index, line in enumerate(transcript):
@@ -34,9 +37,14 @@ def wordify(transcript):
             continue
 
         start = line['start']
-        next_start = transcript[line_index +
-
-
+        next_start = transcript[line_index + 1]['start'] \
+            if line_index < len(transcript) - 1 else float('inf')
+
+        # Use maximum wps to calculate latest end (to avoid segments which stay on screen too long)
+        longest_duration = maximum_wps * text.count(' ')
+        latest_end = start + longest_duration
+        end = min(start + line['duration'], next_start, latest_end)
+
         duration = end - start
 
         indices = find(text, ' ') + [len(text)]
@@ -52,9 +60,9 @@ def wordify(transcript):
             w_start = start + percentage * duration
 
             words.append({
-                'start': round(w_start,
-                'duration': round(w_duration,
-                'end': round(w_start + w_duration,
+                'start': round(w_start, 3),
+                'duration': round(w_duration, 3),
+                'end': round(w_start + w_duration, 3),
                 'text': word,
             })
 
@@ -69,6 +77,10 @@ def get_manual_words(transcript_list):
     return wordify(transcript)
 
 
+PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
+PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
+
+
 def get_auto_words(transcript_list):
     words = []
     transcript = transcript_list.find_generated_transcript(['en'])
@@ -82,7 +94,7 @@ def get_auto_words(transcript_list):
         offset_ms = word.get('tOffsetMs', 0)
 
         texts = word['utf8'].replace(
-
+            PROFANITY_RAW, PROFANITY_CONVERTED
         ).strip().split()
 
         for text in texts:
@@ -94,7 +106,7 @@ def get_auto_words(transcript_list):
     return words
 
 
-def get_words(video_id, process=True, fallback=
+def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
     """Get parsed video transcript with caching system
     returns None if not processed yet and process is False
     """
@@ -148,21 +160,31 @@ def extract_sponsors(words, min_sponsor_segment_length=5):
 
     paragraphs = []
     current = []
+    prev_category = None
     for word in words:
-        if 
-            continue
 
-        if word['
+        if word['category'] is None:  # and not current:
+            continue  # Skip unimportant
+
+        if word['category'] == prev_category:
             current.append(word['text'])
         else:
-            paragraphs.append(
+            paragraphs.append({
+                'words': current,
+                'category': prev_category,
+            })
             current = []
-
-
+
+            prev_category = word['category']
+
+    if current and prev_category is not None:
+        paragraphs.append({
+            'words': current,
+            'category': prev_category,
+        })
 
     # Remove all too short:
     paragraphs = list(filter(lambda x: len(
-        x) >= min_sponsor_segment_length, paragraphs))
+        x['words']) >= min_sponsor_segment_length, paragraphs))
 
     return paragraphs
 
@@ -203,10 +225,8 @@ def clean_text(text):
     text = re.sub(NUM_REGEX, CustomTokens.NUMBER.value, text)
 
     # Replace profanity with special token
-    text = text.replace(CustomTokens.
-
-    text = text.replace(CustomTokens.PROFANITY_CONVERTED.value,
-                        CustomTokens.PROFANITY.value)
+    text = text.replace(PROFANITY_RAW, CustomTokens.PROFANITY.value)
+    text = text.replace(PROFANITY_CONVERTED, CustomTokens.PROFANITY.value)
 
     return text.strip()
 
@@ -254,11 +274,25 @@ class PreprocessArguments:
     do_create: bool = field(
         default=False, metadata={'help': 'Merge sponsor segments into single file'}
     )
+
     min_votes: int = field(
         default=0, metadata={'help': 'Minimum number of votes'})
     # Downvotes will make this negative.
     # 1 = At least one positive vote
 
+    min_date: str = field(
+        default='20/08/2021', metadata={'help': 'Only use submissions from after this date, defaults to the release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)'})
+
+    categories: str = field(
+        default_factory=lambda: ['sponsor', 'selfpromo', 'interaction'],
+        metadata={
+            'nargs': '+',
+            'choices': ['intro', 'sponsor', 'interaction',
+                        'outro', 'selfpromo', 'preview',
+                        'poi_highlight', 'filler', 'music_offtopic']  # moreCategories
+        }
+    )
+
     do_transcribe: bool = field(
         default=False, metadata={'help': 'Get transcripts for videos'}
     )
@@ -266,7 +300,7 @@ class PreprocessArguments:
         default=4, metadata={'help': 'Number of transcripts to download in parallel'})
 
     overwrite: bool = field(
-        default=
+        default=True, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
     )
 
     do_generate: bool = field(
@@ -447,14 +481,26 @@ def main():
         preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
 
     def get_rows():
+
+        latest_time = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
+
         with open(raw_dataset_path, newline='') as csvfile:
             reader = csv.DictReader(csvfile)
+
            for line in reader:
+                submitted_time = datetime.fromtimestamp(
+                    float(line['timeSubmitted'])/1e3)
+
+                if submitted_time < latest_time:
+                    continue
+
                 if line['service'] != 'YouTube':
                     continue
+                if len(line['videoID']) != 11:
+                    continue  # Invalid youtube video ID
 
                 # TODO add support for other categories and action types?
-                if line['category']
+                if line['category'] not in preprocess_args.categories:
                     continue
                 if line['actionType'] != 'skip':
                     continue
@@ -463,9 +509,6 @@ def main():
                 if line['hidden'] == '1' or line['shadowHidden'] == '1':
                     continue
 
-                if len(line['videoID']) != 11:
-                    continue  # Invalid youtube video ID
-
                 # Skip those that aren't highly voted
                 line['votes'] = int(line['votes'])
                 # incorrect_votes = int(line['incorrectVotes'])
@@ -494,6 +537,8 @@ def main():
         for row in data_rows:
             video_ids.add(row['videoID'])
 
+        # TODO first set - os.listdir and do rest
+
         print('Start transcribing')
         with tqdm(total=len(video_ids)) as progress:
             def on_job_complete(job):
@@ -517,21 +562,18 @@ def main():
     final_path = os.path.join(
         processed_args.processed_dir, processed_args.processed_file)
 
-    if 
-        logging.info(f'{final_path} exists, opening file')
-        with open(final_path) as fp:
-            final_data = json.load(fp)
-    else:
+    if preprocess_args.do_create:
         print('Create final data')
 
         final_data = {}
 
         if data_rows is None:
             data_rows = get_rows()
+            # data_rows = itertools.islice(data_rows, 1000)  # TODO temp
 
         # TODO add progress bar
         # TODO parallelise?
-        for line in data_rows:
+        for index, line in enumerate(data_rows):
            video_id = line['videoID']
 
            if video_id not in final_data:
@@ -540,7 +582,10 @@ def main():
            segment_start = float(line['startTime'])
            segment_end = float(line['endTime'])
 
-            video_words = get_words(video_id, process=
+            video_words = get_words(video_id, process=False)
+            if not video_words:
+                continue
+
            segment_words = segment.extract_segment(
                video_words, segment_start, segment_end)
 
@@ -552,7 +597,8 @@ def main():
            wps = len(segment_words)/duration if duration > 0 else 0
 
            if wps < preprocess_args.min_wps:
-                print('bad segment in',
+                print(index, 'Skipping bad segment in',
+                      video_id, '| wps =', wps)
                continue
 
            final_data[video_id].append({
@@ -580,10 +626,16 @@ def main():
    #     raw_dataset_path, final_path, preprocess_args.min_votes)
    # # TODO save metadata in final.json?
 
-
+    elif os.path.exists(final_path):
+        # Already exists
+        logging.info(f'{final_path} exists, opening file')
+        with open(final_path) as fp:
+            final_data = json.load(fp)
+        logging.info(f'Found {len(final_data)} videos')
+    else:
+        return  # Do not continue
 
    # TODO shuffle final_data
-
    # if not os.path.exists(excess_path) or preprocess_args.overwrite
    # TODO use overwrite param
 
@@ -610,10 +662,8 @@ def main():
    write_mode = 'w' if preprocess_args.overwrite else 'a'
 
    get_all = preprocess_args.max_videos is None
-
-
-    else:
-        total = preprocess_args.max_videos
+
+    total = len(final_data) if get_all else preprocess_args.max_videos
 
    index = 0
    data = final_data.items()
@@ -641,7 +691,7 @@ def main():
            elif count_videos >= preprocess_args.max_videos:
                break
 
-            words = get_words(video_id, False)
+            words = get_words(video_id, process=False)
            if not words:
                continue
 
@@ -662,34 +712,40 @@ def main():
                progress.update()
 
                for seg in segments:
-
-                    segment_text = ' '.join((x['text'] for x in seg))
-
-                    extracted_text = ''
-                    for p in extract_sponsors(seg):
-                        p_text = ' '.join(p)
-                        extracted_text += f'{CustomTokens.START_SPONSOR.value} {p_text} {CustomTokens.END_SPONSOR.value}. '
-
                    duration = segment.word_end(
                        seg[-1]) - segment.word_start(seg[0])
                    wps = len(seg)/duration if duration > 0 else 0
+
                    # Ignore segments with "not enough words" in the transcript
                    if wps < preprocess_args.min_wps:
                        continue
 
+                    segment_text = ' '.join((x['text'] for x in seg))
+                    extracted_segments = extract_sponsors(seg)
                    d = {
                        'video_index': index,
                        'video_id': video_id,
                        'text': clean_text(segment_text),
-                        'words_per_second': wps,
+                        'words_per_second': round(wps, 3),
                    }
 
-
-
-
+                    if extracted_segments:
+                        extracted_texts = []
+                        for s in extracted_segments:
+                            w = ' '.join(s['words'])
+                            category = s['category'].upper()
+
+                            t = f"{CustomTokens.START_SEGMENT.value}_{category} {w} {CustomTokens.END_SEGMENT.value}_{category}"
+                            extracted_texts.append(t)
+
+                        extracted_text = '\n'.join(extracted_texts)
+
+                        d['extracted'] = clean_text(extracted_text)
+                        print(json.dumps(d), file=positive)
 
-
-
+                    else:
+                        d['extracted'] = CustomTokens.NO_SEGMENT.value
+                        print(json.dumps(d), file=negative)
 
    if preprocess_args.do_split:
        print('Splitting')
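A self-contained sketch of the idea behind the reworked extract_sponsors: consecutive words that share a category form one paragraph, unlabelled words are skipped, and short paragraphs are dropped. This is not a line-for-line copy of the function in the diff; the word list is invented and the length threshold is lowered so the example survives the filter:

def extract_sponsors(words, min_sponsor_segment_length=2):
    """Group consecutive labelled words into paragraphs of the same category."""
    paragraphs = []
    current = []
    prev_category = None
    for word in words:
        if word['category'] is None:
            continue  # Skip unimportant words
        if word['category'] == prev_category:
            current.append(word['text'])
        else:
            if current and prev_category is not None:
                paragraphs.append({'words': current, 'category': prev_category})
            current = [word['text']]
            prev_category = word['category']
    if current and prev_category is not None:
        paragraphs.append({'words': current, 'category': prev_category})
    return [p for p in paragraphs if len(p['words']) >= min_sponsor_segment_length]

words = [
    {'text': 'thanks', 'category': None},
    {'text': 'to', 'category': 'sponsor'},
    {'text': 'example', 'category': 'sponsor'},
    {'text': 'for', 'category': 'sponsor'},
    {'text': 'like', 'category': 'interaction'},
    {'text': 'and', 'category': 'interaction'},
    {'text': 'subscribe', 'category': 'interaction'},
]
print(extract_sponsors(words))
# [{'words': ['to', 'example', 'for'], 'category': 'sponsor'},
#  {'words': ['like', 'and', 'subscribe'], 'category': 'interaction'}]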
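For reference, the shape of the seq2seq target that the generation step above now writes out (token values from src/shared.py; the segment is invented, and the commit additionally passes the result through clean_text):

START_SEGMENT = 'START_SEGMENT_TOKEN'
END_SEGMENT = 'END_SEGMENT_TOKEN'
NO_SEGMENT = 'NO_SEGMENT_FOUND'

extracted_segments = [
    {'words': ['this', 'video', 'is', 'sponsored', 'by', 'example'], 'category': 'sponsor'},
]

if extracted_segments:
    extracted_texts = []
    for s in extracted_segments:
        w = ' '.join(s['words'])
        category = s['category'].upper()
        extracted_texts.append(f'{START_SEGMENT}_{category} {w} {END_SEGMENT}_{category}')
    target = '\n'.join(extracted_texts)
else:
    target = NO_SEGMENT

print(target)
# START_SEGMENT_TOKEN_SPONSOR this video is sponsored by example END_SEGMENT_TOKEN_SPONSOR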
src/segment.py
CHANGED
@@ -25,7 +25,7 @@ def get_overlapping_chunks_of_tokens(tokens, size, overlap):
 
 
 # Generate up to max_tokens - SAFETY_TOKENS
-SAFETY_TOKENS = 
+SAFETY_TOKENS = 12
 
 
 # TODO play around with this?
@@ -36,10 +36,10 @@ def add_labels_to_words(words, sponsor_segments):
 
     # TODO binary search
     for word in words:
-        word['
+        word['category'] = None
         for sponsor_segment in sponsor_segments:
             if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
-                word['
+                word['category'] = sponsor_segment['category']
 
     # TODO use extract_segment with mapping function?
     # TODO remove sponsor segments that contain mostly empty space?
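A quick standalone check of the relabelled add_labels_to_words logic, with invented words and segments (the function here simply mirrors the loop above and returns the list so it can be printed):

def add_labels_to_words(words, sponsor_segments):
    # Default every word to None, then tag it with the category of any
    # segment whose time range contains its start time.
    for word in words:
        word['category'] = None
        for sponsor_segment in sponsor_segments:
            if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
                word['category'] = sponsor_segment['category']
    return words

words = [{'text': 'check', 'start': 12.0}, {'text': 'out', 'start': 12.4},
         {'text': 'unrelated', 'start': 50.0}]
segments = [{'start': 10.0, 'end': 20.0, 'category': 'selfpromo'}]
print(add_labels_to_words(words, segments))
# [{'text': 'check', 'start': 12.0, 'category': 'selfpromo'},
#  {'text': 'out', 'start': 12.4, 'category': 'selfpromo'},
#  {'text': 'unrelated', 'start': 50.0, 'category': None}]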
src/shared.py
CHANGED
@@ -7,16 +7,17 @@ from typing import Optional
 from dataclasses import dataclass, field
 from enum import Enum
 
-
 class CustomTokens(Enum):
+    EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
+
     URL = 'URL_TOKEN'
     HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
     NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
     NUMBER = 'NUMBER_TOKEN'
 
-
-
-
+    START_SEGMENT = 'START_SEGMENT_TOKEN'
+    END_SEGMENT = 'END_SEGMENT_TOKEN'
+    NO_SEGMENT = 'NO_SEGMENT_FOUND'
 
     SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
     LONG_WORD = 'LONG_WORD_TOKEN'
@@ -26,8 +27,6 @@ class CustomTokens(Enum):
     APPLAUSE = '[Applause]'
     LAUGHTER = '[Laughter]'
 
-    PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
-    PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
     PROFANITY = 'PROFANITY_TOKEN'
 
     @classmethod
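These token strings only behave as single units if the tokenizer knows about them. The classmethod that follows is outside this hunk, but registration typically looks something like the sketch below (not code from this commit; add_tokens and resize_token_embeddings are standard transformers APIs):

from transformers import AutoTokenizer

# A sketch: register the new custom tokens with a T5 tokenizer so they are
# not split into subword pieces.
tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer.add_tokens(['START_SEGMENT_TOKEN', 'END_SEGMENT_TOKEN',
                      'NO_SEGMENT_FOUND', 'PROFANITY_TOKEN',
                      'URL_TOKEN', 'NUMBER_TOKEN'])
# The model side would then need: model.resize_token_embeddings(len(tokenizer))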
src/train.py
CHANGED
@@ -1,9 +1,8 @@
 from preprocess import load_datasets, DatasetArguments
-from predict import ClassifierArguments, SPONSOR_MATCH_RE
-from shared import device, GeneralArguments, OutputArguments
-from model import ModelArguments
+from predict import ClassifierArguments, SPONSOR_MATCH_RE
+from shared import CustomTokens, device, GeneralArguments, OutputArguments
+from model import ModelArguments, get_model, get_tokenizer
 import transformers
-from model import get_model, get_tokenizer
 import logging
 import os
 import sys
@@ -22,7 +21,7 @@ from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from sklearn.linear_model import LogisticRegression
 from sklearn.feature_extraction.text import TfidfVectorizer
-
+from utils import re_findall
 import re
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -117,7 +116,7 @@ class DataTrainingArguments:
         },
     )
     source_prefix: Optional[str] = field(
-        default=
+        default=CustomTokens.EXTRACT_SEGMENTS_PREFIX.value, metadata={
             'help': 'A prefix to add before every source text (useful for T5 models).'}
     )
 
@@ -135,11 +134,11 @@ class SequenceTrainingArguments(OutputArguments, Seq2SeqTrainingArguments):
     num_train_epochs: float = field(
         default=1, metadata={'help': 'Total number of training epochs to perform.'})
 
-    save_steps: int = field(default=
+    save_steps: int = field(default=5000, metadata={
         'help': 'Save checkpoint every X updates steps.'})
-    eval_steps: int = field(default=
+    eval_steps: int = field(default=5000, metadata={
         'help': 'Run an evaluation every X steps.'})
-    logging_steps: int = field(default=
+    logging_steps: int = field(default=5000, metadata={
         'help': 'Log every X updates steps.'})
 
     skip_train_transformer: bool = field(default=False, metadata={
@@ -257,8 +256,8 @@ def main():
 
         ngram_range=(1, 2),  # best so far
         # max_features=8000  # remove for higher accuracy?
-        max_features=50000
-
+        # max_features=50000
+        max_features=10000
     )
 
     train_test_data = {
@@ -277,11 +276,12 @@ def main():
         dataset = raw_datasets[ds_type]
 
         for row in dataset:
-
             # Get matches:
-
-
-
+            matches = re_findall(SPONSOR_MATCH_RE, row['extracted'])
+
+            return  # TODO fix
+
+            if not matches:
                 matches = [row['text']]
 
             for match in matches:
src/utils.py
CHANGED
@@ -1,6 +1,8 @@
+import re
 import asyncio
 import os
 
+
 class Job:
     def __init__(self, function, *args, **kwargs) -> None:
         self.function = function
@@ -84,3 +86,7 @@ class InterruptibleThreadPool:
         self.loop.close()
 
         return self.jobs
+
+
+def re_findall(pattern, string):
+    return [m.groupdict() for m in re.finditer(pattern, string)]