import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
import argparse
from evaluate import load
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from augments import RandAug, RandRotate

parser = argparse.ArgumentParser('arguments for the code')

parser.add_argument('--root_path', type=str, default="",
                    help='Root path to data files.')
parser.add_argument('--tr_data_path', type=str, default="/path/to/train_data.csv",
                    help='Path to .csv file containing the training data.')
parser.add_argument('--val_data_path', type=str, default="/path/to/val_data.csv",
                    help='Path to .csv file containing the validation data.')
parser.add_argument('--output_path', type=str, default="./output/path/",
                    help='Path for saving training results.')
parser.add_argument('--model_path', type=str, default="/model/path/",
                    help='Path to trocr model.')
parser.add_argument('--processor_path', type=str, default="/processor/path/",
                    help='Path to trocr processor.')
parser.add_argument('--epochs', type=int, default=15,
                    help='Training epochs.')
parser.add_argument('--batch_size', type=int, default=16,
                    help='Training batch size.')
parser.add_argument('--device', type=str, default="cuda:0",
                    help='Device used for training.')
parser.add_argument('--augment', type=int, default=0,
                    help='Defines if image augmentations are used during training.')

args = parser.parse_args()

# Initialize processor and model
processor = TrOCRProcessor.from_pretrained(args.processor_path)
model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
model.to(args.device)

# Initialize metrics
cer_metric = load("cer")
wer_metric = load("wer")

# Load train and validation data to dataframes
train_df = pd.read_csv(args.tr_data_path)
val_df = pd.read_csv(args.val_data_path)

# Reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)


# Torch dataset
class TextlineDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128, augment=False):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.augment = augment
        self.augmentator = RandAug()
        self.rotator = RandRotate()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")
        if self.augment:
            image = self.augmentator(image)
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          truncation=True,
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding


# Create train and validation datasets
train_dataset = TextlineDataset(root_dir=args.root_path,
                                df=train_df,
                                processor=processor,
                                augment=args.augment)

eval_dataset = TextlineDataset(root_dir=args.root_path,
                               df=val_df,
                               processor=processor,
                               augment=False)

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

# Define model configuration
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 1

# Set arguments for model training
# For all arguments see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    load_best_model_at_end=True,
    metric_for_best_model='cer',
    greater_is_better=False,
    fp16=True,
    num_train_epochs=args.epochs,
    save_total_limit=1,
    output_dir=args.output_path,
    optim='adamw_torch'
)


# Function for computing CER and WER metrics for the prediction results
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # replace the -100 ignore index with the PAD token id before decoding the references
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer, "wer": wer}


# Instantiate trainer
# For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.image_processor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

# Train the model
trainer.train()
# To continue training from the latest checkpoint in output_dir, use instead:
# trainer.train(resume_from_checkpoint=True)

model.save_pretrained(args.output_path)
processor.save_pretrained(args.output_path + "/processor")
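
# Example usage (illustrative only; the script name, paths, and model identifier below are
# placeholders/assumptions, not part of this repository):
#
#   python train_trocr.py \
#       --root_path ./data/images/ \
#       --tr_data_path ./data/train_data.csv \
#       --val_data_path ./data/val_data.csv \
#       --model_path microsoft/trocr-base-handwritten \
#       --processor_path microsoft/trocr-base-handwritten \
#       --output_path ./output/ \
#       --epochs 15 --batch_size 16 --device cuda:0 --augment 1
#
# The .csv files are expected to contain the two columns read by TextlineDataset,
# e.g. (hypothetical contents):
#
#   file_name,text
#   line_0001.jpg,"transcription of the first text line"
#   line_0002.jpg,"transcription of the second text line"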