MikkoLipsanen committed on
Commit 8713ab2
Parent: 8113bdf

Upload 2 files

Files changed (2)
  1. augments.py +57 -0
  2. train_trocr.py +182 -0
augments.py ADDED
import random

import numpy as np
import torch
import torchvision.transforms as T


class RandAug:
    """Applies one randomly chosen image augmentation."""

    def __init__(self):
        # Augmentation options
        self.trans = ['identity', 'color', 'sharpness', 'blur']

    def __call__(self, img):
        # Pick one augmentation, all options equally likely
        choice = random.choices(self.trans, weights=(25, 25, 25, 25))[0]

        if choice == 'color':
            # Jitter brightness, contrast, saturation and hue by random amounts
            rand_brightness = random.uniform(0, 0.3)
            rand_hue = random.uniform(0, 0.5)
            rand_contrast = random.uniform(0, 0.5)
            rand_saturation = random.uniform(0, 0.5)
            trans = T.ColorJitter(brightness=rand_brightness, contrast=rand_contrast,
                                  saturation=rand_saturation, hue=rand_hue)
            img = trans(img)

        elif choice == 'sharpness':
            # Sharpness factor > 1 sharpens the image
            sharpness = 1 + (np.random.exponential() / 2)
            trans = T.RandomAdjustSharpness(sharpness, p=1)
            img = trans(img)

        elif choice == 'blur':
            # Gaussian blur with a randomly chosen odd kernel size
            kernel = random.choice([1, 3, 5])
            trans = T.GaussianBlur(kernel, sigma=(0.1, 2.0))
            img = trans(img)

        # 'identity' returns the image unchanged
        return img


class RandRotate:
    """Randomly applied rotation of an image and its mask by a shared angle."""

    def __init__(self, low=0, high=180):
        # Range of rotation angles in degrees
        self.low = low
        self.high = high
        self.trans = ['identity', 'rotation']

    def __call__(self, img, mask):
        # Rotate with 50% probability
        choice = random.choices(self.trans, weights=(50, 50))[0]

        if choice == 'rotation':
            # Sample the angle per call so that each image gets its own rotation
            rotation = torch.randint(low=self.low, high=self.high, size=(1,)).item()
            rotated_img = T.functional.rotate(img=img, angle=rotation, expand=False)
            rotated_mask = T.functional.rotate(img=mask, angle=rotation, expand=False)
            return rotated_img, rotated_mask

        return img, mask
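For reference, a minimal usage sketch of the two classes above; the file name line.png and the dummy mask are illustrative only. RandAug transforms a single PIL image, while RandRotate rotates an image and its mask together.

from PIL import Image
from augments import RandAug, RandRotate

img = Image.open("line.png").convert("RGB")  # hypothetical input image

aug = RandAug()
augmented = aug(img)  # identity, color jitter, sharpness or blur, chosen at random

rot = RandRotate(low=0, high=180)
mask = Image.new("L", img.size)  # dummy mask, for illustration only
rotated_img, rotated_mask = rot(img, mask)  # both rotated by the same random angle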
train_trocr.py ADDED
import argparse

import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
from evaluate import load
from transformers import (TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments, default_data_collator)
from augments import RandAug, RandRotate
#import torch_optimizer as optim

parser = argparse.ArgumentParser(description='Arguments for TrOCR model training.')

parser.add_argument('--root_path', type=str, default="",
                    help='Root path to data files.')
parser.add_argument('--tr_data_path', type=str, default="/data/htr/taulukkosolut/train/trocr/train_new_anno.csv",
                    help='Path to .csv file containing the training data.')
parser.add_argument('--val_data_path', type=str, default="/data/htr/taulukkosolut/val/trocr/val_new_anno.csv",
                    help='Path to .csv file containing the validation data.')
parser.add_argument('--output_path', type=str, default="./output/no_aug_1beam_07092024/",
                    help='Path for saving training results.')
parser.add_argument('--model_path', type=str, default="/4tb_01/models/htr/supermalli/202405_fp16/",
                    help='Path to TrOCR model.')
parser.add_argument('--processor_path', type=str, default="/4tb_01/models/htr/supermalli/202405_fp16/processor",
                    help='Path to TrOCR processor.')
parser.add_argument('--epochs', type=int, default=20,
                    help='Number of training epochs.')
parser.add_argument('--batch_size', type=int, default=16,
                    help='Training batch size.')
parser.add_argument('--device', type=str, default="cuda:0",
                    help='Device used for training.')
parser.add_argument('--augment', type=int, default=0,
                    help='Whether image augmentations are used during training (0 or 1).')

args = parser.parse_args()

# Run in the background and log the output:
# nohup python train_trocr.py > logs/taulukkosolut_no_aug_1beam_07092024.txt 2>&1 &
# echo $! > logs/save_pid.txt

# Run using 2 GPUs: torchrun --nproc_per_node=2 train_trocr.py

# Initialize processor and model
processor = TrOCRProcessor.from_pretrained(args.processor_path)
model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
model.to(args.device)

# Initialize CER and WER metrics
cer_metric = load("cer")
wer_metric = load("wer")

# Load train and validation data into dataframes
train_df = pd.read_csv(args.tr_data_path)
val_df = pd.read_csv(args.val_data_path)
# Uncomment to train on a small subset for debugging:
#train_df = train_df.iloc[:10]
#val_df = val_df.iloc[:5]

# Reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)

# Torch dataset for text line images and their transcriptions
class TextlineDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128, augment=False):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.augment = augment
        self.augmentator = RandAug()
        self.rotator = RandRotate()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]

        # Prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")

        if self.augment:
            image = self.augmentator(image)

        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # Add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length", truncation=True,
                                          max_length=self.max_target_length).input_ids
        # Important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

# Create train and validation datasets
train_dataset = TextlineDataset(root_dir=args.root_path,
                                df=train_df,
                                processor=processor,
                                augment=bool(args.augment))

eval_dataset = TextlineDataset(root_dir=args.root_path,
                               df=val_df,
                               processor=processor,
                               augment=False)

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

# Define model configuration

# Set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# Set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 1

# Set arguments for model training
# For all arguments see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    learning_rate=2.779367469510554e-05,
    predict_with_generate=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    load_best_model_at_end=True,
    metric_for_best_model='cer',
    greater_is_better=False,
    fp16=True,
    num_train_epochs=args.epochs,
    save_total_limit=1,
    output_dir=args.output_path,
    optim='adamw_torch'
)

# Compute CER and WER metrics for the prediction results
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # Decode predictions and labels back to strings,
    # replacing the -100 loss mask with real PAD tokens first
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer, "wer": wer}


# Instantiate trainer
# For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.image_processor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

# Train the model, then save the final model and processor
trainer.train()
#trainer.train(resume_from_checkpoint=True)
model.save_pretrained(args.output_path)
processor.save_pretrained(args.output_path + "/processor")
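For completeness, a minimal inference sketch using the model and processor saved above. This assumes the default --output_path from the training script; line.png is a hypothetical text line image.

from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load the fine-tuned model and processor saved at the end of training
processor = TrOCRProcessor.from_pretrained("./output/no_aug_1beam_07092024/processor")
model = VisionEncoderDecoderModel.from_pretrained("./output/no_aug_1beam_07092024/")

# Hypothetical text line image
image = Image.open("line.png").convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values

# generate() uses the beam search settings stored in model.config
generated_ids = model.generate(pixel_values)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text)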