Joshua Lochner committed on
Commit
320a2ba
1 Parent(s): 3af0cd0

Change to multiclass classifier

Browse files
Files changed (1) hide show
  1. src/train.py +32 -22
src/train.py CHANGED
@@ -1,7 +1,7 @@
1
  from preprocess import load_datasets, DatasetArguments
2
- from predict import ClassifierArguments, SPONSOR_MATCH_RE
3
  from shared import CustomTokens, device, GeneralArguments, OutputArguments
4
- from model import ModelArguments, get_model, get_tokenizer
5
  import transformers
6
  import logging
7
  import os
@@ -14,15 +14,17 @@ from transformers import (
14
  DataCollatorForSeq2Seq,
15
  HfArgumentParser,
16
  Seq2SeqTrainer,
17
- Seq2SeqTrainingArguments
 
 
18
  )
 
19
  from transformers.trainer_utils import get_last_checkpoint
20
  from transformers.utils import check_min_version
21
  from transformers.utils.versions import require_version
22
  from sklearn.linear_model import LogisticRegression
23
  from sklearn.feature_extraction.text import TfidfVectorizer
24
  from utils import re_findall
25
- import re
26
 
27
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
28
  check_min_version('4.13.0.dev0')
@@ -256,8 +258,9 @@ def main():
256
 
257
  ngram_range=(1, 2), # best so far
258
  # max_features=8000 # remove for higher accuracy?
259
- # max_features=50000
260
- max_features=10000
 
261
  )
262
 
263
  train_test_data = {
@@ -276,17 +279,17 @@ def main():
276
  dataset = raw_datasets[ds_type]
277
 
278
  for row in dataset:
279
- # Get matches:
280
- matches = re_findall(SPONSOR_MATCH_RE, row['extracted'])
281
-
282
- return # TODO fix
283
 
284
- if not matches:
285
- matches = [row['text']]
286
 
287
- for match in matches:
288
- train_test_data[ds_type]['X'].append(match)
289
- train_test_data[ds_type]['y'].append(row['sponsor'])
290
 
291
  print('Fitting')
292
  _X_train = vectorizer.fit_transform(train_test_data['train']['X'])
@@ -296,10 +299,10 @@ def main():
296
  y_test = train_test_data['test']['y']
297
 
298
  # 2. Create classifier
299
- classifier = LogisticRegression(max_iter=500)
300
 
301
  # 3. Fit data
302
- print('fit classifier')
303
  classifier.fit(_X_train, y_train)
304
 
305
  # 4. Measure accuracy
@@ -336,9 +339,15 @@ def main():
336
  )
337
 
338
  # Load pretrained model and tokenizer
339
- tokenizer = get_tokenizer(model_args)
340
- model = get_model(model_args)
341
  model.to(device())
 
 
 
 
 
 
342
  model.resize_token_embeddings(len(tokenizer))
343
 
344
  if model.config.decoder_start_token_id is None:
@@ -479,9 +488,10 @@ def main():
479
  train_result = trainer.train(resume_from_checkpoint=checkpoint)
480
  trainer.save_model() # Saves the tokenizer too for easy upload
481
  except KeyboardInterrupt:
482
- print('Saving model')
483
- trainer.save_model(os.path.join(
484
- training_args.output_dir, 'checkpoint-latest')) # TODO use dir
 
485
  raise
486
 
487
  metrics = train_result.metrics
 
1
  from preprocess import load_datasets, DatasetArguments
2
+ from predict import ClassifierArguments, SEGMENT_MATCH_RE, CATEGORIES
3
  from shared import CustomTokens, device, GeneralArguments, OutputArguments
4
+ from model import ModelArguments
5
  import transformers
6
  import logging
7
  import os
 
14
  DataCollatorForSeq2Seq,
15
  HfArgumentParser,
16
  Seq2SeqTrainer,
17
+ Seq2SeqTrainingArguments,
18
+ AutoTokenizer,
19
+ AutoModelForSeq2SeqLM
20
  )
21
+
22
  from transformers.trainer_utils import get_last_checkpoint
23
  from transformers.utils import check_min_version
24
  from transformers.utils.versions import require_version
25
  from sklearn.linear_model import LogisticRegression
26
  from sklearn.feature_extraction.text import TfidfVectorizer
27
  from utils import re_findall
 
28
 
29
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
30
  check_min_version('4.13.0.dev0')
 
258
 
259
  ngram_range=(1, 2), # best so far
260
  # max_features=8000 # remove for higher accuracy?
261
+ max_features=20000
262
+ # max_features=10000
263
+ # max_features=1000
264
  )
265
 
266
  train_test_data = {
 
279
  dataset = raw_datasets[ds_type]
280
 
281
  for row in dataset:
282
+ matches = re_findall(SEGMENT_MATCH_RE, row['extracted'])
283
+ if matches:
284
+ for match in matches:
285
+ train_test_data[ds_type]['X'].append(match['text'])
286
 
287
+ class_index = CATEGORIES.index(match['category'])
288
+ train_test_data[ds_type]['y'].append(class_index)
289
 
290
+ else:
291
+ train_test_data[ds_type]['X'].append(row['text'])
292
+ train_test_data[ds_type]['y'].append(0)
293
 
294
  print('Fitting')
295
  _X_train = vectorizer.fit_transform(train_test_data['train']['X'])
 
299
  y_test = train_test_data['test']['y']
300
 
301
  # 2. Create classifier
302
+ classifier = LogisticRegression(max_iter=2000, class_weight='balanced')
303
 
304
  # 3. Fit data
305
+ print('Fit classifier')
306
  classifier.fit(_X_train, y_train)
307
 
308
  # 4. Measure accuracy
 
339
  )
340
 
341
  # Load pretrained model and tokenizer
342
+ model = AutoModelForSeq2SeqLM.from_pretrained(
343
+ model_args.model_name_or_path)
344
  model.to(device())
345
+
346
+ tokenizer = AutoTokenizer.from_pretrained(
347
+ model_args.model_name_or_path)
348
+
349
+ # Ensure model and tokenizer contain the custom tokens
350
+ CustomTokens.add_custom_tokens(tokenizer)
351
  model.resize_token_embeddings(len(tokenizer))
352
 
353
  if model.config.decoder_start_token_id is None:
 
488
  train_result = trainer.train(resume_from_checkpoint=checkpoint)
489
  trainer.save_model() # Saves the tokenizer too for easy upload
490
  except KeyboardInterrupt:
491
+ # TODO add option to save model on interrupt?
492
+ # print('Saving model')
493
+ # trainer.save_model(os.path.join(
494
+ # training_args.output_dir, 'checkpoint-latest')) # TODO use dir
495
  raise
496
 
497
  metrics = train_result.metrics