MikkoLipsanen committed
Commit 375fd17
Parent(s): 0976156
Upload 5 files
- augment.py +89 -0
- requirements.txt +10 -0
- test.py +192 -0
- train.py +332 -0
- utils.py +107 -0
augment.py
ADDED
@@ -0,0 +1,89 @@
from torchvision import transforms
import random
import numpy as np


class RandAug:
    """Randomly chosen image augmentations."""

    def __init__(self, img_size, choice=None):
        # Augmentation options
        self.trans = ['identity', 'rotate', 'color', 'sharpness', 'blur', 'padding', 'perspective']
        self.img_size = img_size
        self.choice = choice

    def __call__(self, img):
        choice = self.choice
        if choice is None:
            # Weights give a 40% probability to the 'identity' option.
            # The choice is drawn into a local variable on every call;
            # assigning it back to self.choice would lock in the first
            # random pick for all subsequent images.
            choice = random.choices(self.trans, weights=(40, 10, 10, 10, 10, 10, 10))[0]

        if choice == 'identity':
            trans = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor()
            ])
        elif choice == 'rotate':
            degrees = random.uniform(0, 180)
            rand_fill = random.choice([0, 1])
            trans = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.RandomRotation(degrees, expand=True, fill=rand_fill),
                transforms.Resize((self.img_size, self.img_size))
            ])
        elif choice == 'color':
            rand_brightness = random.uniform(0, 0.3)
            rand_hue = random.uniform(0, 0.5)
            rand_contrast = random.uniform(0, 0.5)
            rand_saturation = random.uniform(0, 0.5)
            trans = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.ColorJitter(brightness=rand_brightness, contrast=rand_contrast,
                                       saturation=rand_saturation, hue=rand_hue)
            ])
        elif choice == 'sharpness':
            sharpness = 1 + (np.random.exponential() / 2)
            trans = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.RandomAdjustSharpness(sharpness, p=1)
            ])
        elif choice == 'blur':
            kernel = random.choice([1, 3, 5])
            trans = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.GaussianBlur(kernel, sigma=(0.1, 2.0))
            ])
        elif choice == 'padding':
            pad = random.choice([3, 10, 25])
            rand_fill = random.choice([0, 1])
            trans = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.Pad(pad, fill=rand_fill, padding_mode='constant'),
                transforms.Resize((self.img_size, self.img_size))
            ])
        elif choice == 'perspective':
            scale = random.uniform(0.1, 0.5)
            rand_fill = random.choice([0, 1])
            trans = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.RandomPerspective(distortion_scale=scale, p=1.0, fill=rand_fill),
                transforms.Resize((self.img_size, self.img_size))
            ])
        else:
            raise ValueError(f'Unknown augmentation choice: {choice}')

        return trans(img)
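For reference, a minimal usage sketch of RandAug on a single image (not part of the commit; sample.jpg is a placeholder path):

from PIL import Image
from augment import RandAug

aug = RandAug(img_size=224)                        # a random augmentation is drawn per call
tensor = aug(Image.open('sample.jpg').convert('RGB'))   # torch.Tensor of shape [3, 224, 224]
fixed = RandAug(img_size=224, choice='blur')       # always applies the Gaussian blur branch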
requirements.txt
ADDED
@@ -0,0 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cu116
torch==1.12.1+cu116
torchvision==0.13.1+cu116
scikit-learn==1.0.2
numpy==1.21.6
pillow==9.3.0
matplotlib==3.5.3
onnx==1.13.0
onnxruntime==1.13.1
tqdm==4.64.1
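Note: test.py additionally imports pandas and seaborn, which are not pinned above. An assumed manual install step alongside the pinned set:

pip install -r requirements.txt pandas seaborn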
test.py
ADDED
@@ -0,0 +1,192 @@
from __future__ import print_function
from __future__ import division
import torch
import onnxruntime
import numpy as np
import pandas as pd
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import seaborn as sn
import random
import time
import json
from PIL import Image
from PIL import ImageFile
from pathlib import Path
import argparse

print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

parser = argparse.ArgumentParser('arguments for testing the model')

parser.add_argument('--ts_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/test/",
                    help='path to test data with empty cells')
parser.add_argument('--ts_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/test/",
                    help='path to test data without empty cells')
parser.add_argument('--results_folder', type=str, default="./results/aug_28022024/",
                    help='folder for saving results')
parser.add_argument('--model_path', type=str, default="/koodit/table_segmentation/empty_cell_detection/train/models/aug_b32_lr0001_28022024.onnx",
                    help='path to load model file from')
parser.add_argument('--batch_size', type=int, default=16,
                    help='batch size')
parser.add_argument('--num_classes', type=int, default=2,
                    help='number of classes for classification')
parser.add_argument('--name', type=str, default='empty_cell_augment_28022024',
                    help='name given to result files')

start = time.time()

# nohup python test.py > logs/aug_test_28022024.txt 2>&1 &
# echo $! > output/save_pid.txt

torch.manual_seed(67)
random.seed(67)

args = parser.parse_args()

# PIL settings to avoid errors caused by truncated and large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# Much of the code is a modified version of the code available at
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html


def get_data():
    """Collects test image paths and builds the corresponding labels."""
    empty_path = Path(args.ts_empty_folder)
    ok_path = Path(args.ts_ok_folder)

    empty_files = list(empty_path.glob('*.jpg'))
    ok_files = list(ok_path.glob('*.jpg'))

    # Label 0 = empty cell, label 1 = ok cell
    empty_labels = np.zeros(len(empty_files))
    ok_labels = np.ones(len(ok_files))

    # Development leftovers, intentionally commented out:
    #ts_data_files = ts_data_files[:20]
    #ts_data_labels = ts_data_labels[:20]
    #ts_ok_files = ts_ok_files[:20]
    #ts_ok_labels = ts_ok_labels[:20]

    ts_files = empty_files + ok_files
    ts_labels = np.concatenate((empty_labels, ok_labels))

    print('Test data with empty cells: ', len(empty_files))
    print('Test data without empty cells: ', len(ok_files))

    return ts_files, ts_labels


def initialize_model():
    """Loads the exported ONNX model for inference."""
    model = onnxruntime.InferenceSession(args.model_path)
    input_size = 224
    return model, input_size


# Function for getting precision, recall and F-score metrics
def get_precision_recall(y_true, y_pred):
    precision_recall_fscore = precision_recall_fscore_support(y_true, y_pred, average=None)

    prec_0 = precision_recall_fscore[0][0]
    rec_0 = precision_recall_fscore[1][0]
    F_0 = precision_recall_fscore[2][0]

    prec_1 = precision_recall_fscore[0][1]
    rec_1 = precision_recall_fscore[1][1]
    F_1 = precision_recall_fscore[2][1]

    print('\nPrecision for ok: %.2f' % prec_1)
    print('Recall for ok: %.2f' % rec_1)
    print('F-score for ok: %.2f' % F_1)

    print('Precision for empty: %.2f' % prec_0)
    print('Recall for empty: %.2f' % rec_0)
    print('F-score for empty: %.2f' % F_0)


def createConfusionMatrix(y_true, y_pred):
    """Builds a confusion matrix heatmap figure from the predictions."""
    classes = np.array(['empty', 'ok'])

    # Build confusion matrix
    cf_matrix = confusion_matrix(y_true, y_pred)
    print(cf_matrix)
    df_cm = pd.DataFrame(cf_matrix, index=classes, columns=classes)
    plt.figure(figsize=(12, 7))
    return sn.heatmap(df_cm, annot=True).get_figure()


def save_preds(y_true, y_pred, paths):
    """Saves the file names and true labels of misclassified images."""
    # Identify images that were not classified correctly
    incorrect_indices = np.where(y_true != y_pred)
    incorrectly_predicted_images = paths[incorrect_indices]
    correct_labels = y_true[incorrect_indices].astype(str)
    incorrect_preds = dict(zip(incorrectly_predicted_images, correct_labels))

    print(f'{len(incorrect_preds)} incorrect predictions')

    # Save file names and labels of incorrectly classified images
    with open(args.results_folder + args.name + '_incorrect_preds', "w") as fp:
        json.dump(incorrect_preds, fp)


# Initialize the model for this run
model, input_size = initialize_model()

# Print the model we just instantiated
#print(model)

data_transforms = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor()
])

print("Initializing Datasets and Dataloaders...")

ts_files, ts_labels = get_data()


# Function for getting model predictions on test data
def test_model(model, ts_files, ts_labels):
    since = time.time()
    label_preds = []
    true_labels = []
    paths = []
    n = len(ts_files)
    # Iterate over data
    for i in range(n):
        print(f'{i}/{n}')
        image = Image.open(ts_files[i])
        label = ts_labels[i]
        image = data_transforms(image.convert("RGB")).unsqueeze(0)
        # Transform tensor to numpy array
        img = image.detach().cpu().numpy()
        # 'inputs' renamed from 'input' to avoid shadowing the Python builtin
        ort_inputs = {model.get_inputs()[0].name: img}
        # Run model prediction
        output = model.run(None, ort_inputs)
        # Get predicted class
        pred = np.argmax(output[0], 1)
        pred_class = pred.item()
        label_preds.append(pred_class)
        true_labels.append(label)
        paths.append(str(ts_files[i]))

    time_elapsed = time.time() - since
    print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return np.array(label_preds), np.array(true_labels), np.array(paths)


ts_labels = np.array(ts_labels)

# Test model
y_pred, y_true, paths = test_model(model, ts_files, ts_labels)
# Save information on incorrect predictions
save_preds(y_true, y_pred, paths)
# Calculate and print precision, recall and F-score metrics
get_precision_recall(y_true, y_pred)

# Save confusion matrix to Tensorboard
#cm = createConfusionMatrix(y_true, y_pred)
#writer.add_figure("Confusion matrix", cm)
# Create and save confusion matrix of the predictions and true labels
conf_matrix = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true', display_labels=np.array(['empty', 'ok']))
plt.savefig(args.results_folder + args.name + '_conf_matrix.jpg', bbox_inches='tight')

end = time.time()
time_in_mins = (end - start) / 60
print('Time: %.2f minutes' % time_in_mins)
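For quick sanity checks outside the full test run, a minimal single-image ONNX inference sketch in the same style as test_model; model.onnx and cell.jpg are placeholder names, not files from this commit:

import numpy as np
import onnxruntime
from PIL import Image
from torchvision import transforms

session = onnxruntime.InferenceSession('model.onnx')      # placeholder model path
prep = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
x = prep(Image.open('cell.jpg').convert('RGB')).unsqueeze(0).numpy()
logits = session.run(None, {session.get_inputs()[0].name: x})[0]
print('empty' if np.argmax(logits, 1).item() == 0 else 'ok')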
train.py
ADDED
@@ -0,0 +1,332 @@
from __future__ import print_function
from __future__ import division
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.utils import class_weight
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import time
import argparse
from tqdm import tqdm
from PIL import Image, ImageFile
from pathlib import Path

from augment import RandAug
import utils

print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Much of the code is a modified version of the code available at
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

# nohup python train.py > logs/empty_cell_aug_28032024.txt 2>&1 &
# echo $! > logs/save_pid.txt

parser = argparse.ArgumentParser('arguments for training')

parser.add_argument('--tr_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/train/",
                    help='path to training data with empty images')
parser.add_argument('--val_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/val/",
                    help='path to validation data with empty images')
parser.add_argument('--tr_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/train/",
                    help='path to training data with ok images')
parser.add_argument('--val_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/val/",
                    help='path to validation data with ok images')
parser.add_argument('--results_folder', type=str, default="results/28032024_aug/",
                    help='Folder for saving training results.')
parser.add_argument('--save_model_path', type=str, default="./models/",
                    help='Path for saving model file.')
parser.add_argument('--batch_size', type=int, default=32,
                    help='Batch size used for model training.')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='Base learning rate.')
parser.add_argument('--device', type=str, default='cpu',
                    help='Defines whether the model is trained using cpu or gpu.')
parser.add_argument('--num_classes', type=int, default=2,
                    help='Number of classes used in classification.')
parser.add_argument('--num_epochs', type=int, default=15,
                    help='Number of training epochs.')
parser.add_argument('--random_seed', type=int, default=8765,
                    help='Number used for initializing random number generation.')
parser.add_argument('--early_stop_threshold', type=int, default=3,
                    help='Number of epochs without validation F1 improvement after which training stops.')
parser.add_argument('--save_model_format', type=str, default='torch',
                    help='Defines the format for saving the model.')
parser.add_argument('--augment_choice', type=str, default=None,
                    help='Defines which image augmentation(s) are used. Defaults to randomly selected augmentations.')
parser.add_argument('--model_name', type=str, default='aug_b32_lr0001',
                    help='Name used for the saved model file.')
parser.add_argument('--date', type=str, default=time.strftime("%d%m%Y"),
                    help='Current date.')

args = parser.parse_args()

# PIL settings to avoid errors caused by truncated and large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# List for saving the names of damaged images
damaged_images = []

def get_datapaths():
    """Function for loading train and validation data."""
    tr_empty_files = list(Path(args.tr_empty_folder).glob('*'))
    tr_ok_files = list(Path(args.tr_ok_folder).glob('*'))
    val_empty_files = list(Path(args.val_empty_folder).glob('*'))
    val_ok_files = list(Path(args.val_ok_folder).glob('*'))
    # Create labels for train and validation data: 0 = empty, 1 = ok
    tr_labels = np.concatenate((np.zeros(len(tr_empty_files)), np.ones(len(tr_ok_files))))
    val_labels = np.concatenate((np.zeros(len(val_empty_files)), np.ones(len(val_ok_files))))
    # Combine the empty and ok images into one file list per split
    tr_files = tr_empty_files + tr_ok_files
    val_files = val_empty_files + val_ok_files

    print('\nTraining data with empty cells: ', len(tr_empty_files))
    print('Training data without empty cells: ', len(tr_ok_files))

    print('Validation data with empty cells: ', len(val_empty_files))
    print('Validation data without empty cells: ', len(val_ok_files))

    data_dict = {'tr_data': tr_files, 'tr_labels': tr_labels,
                 'val_data': val_files, 'val_labels': val_labels}

    return data_dict

class ImageDataset(Dataset):
    """PyTorch Dataset class used for generating training and validation datasets."""
    def __init__(self, img_paths, img_labels, transform=None, target_transform=None):
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            label = self.img_labels[idx]
        except Exception:
            # Image is considered damaged if reading the image fails
            damaged_images.append(img_path)
            return None
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

def initialize_model():
    """Function for initializing a pretrained neural network model (DenseNet121)."""
    model_ft = models.densenet121(weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Linear(num_ftrs, args.num_classes)
    input_size = 224

    return model_ft, input_size

def collate_fn(batch):
    """Helper function for creating data batches; drops samples that failed to load."""
    batch = list(filter(lambda x: x is not None, batch))

    return torch.utils.data.dataloader.default_collate(batch)

def initialize_dataloaders(data_dict, input_size):
    """Function for initializing datasets and dataloaders."""
    # Train and validation datasets
    train_dataset = ImageDataset(img_paths=data_dict['tr_data'], img_labels=data_dict['tr_labels'], transform=RandAug(input_size, args.augment_choice))
    validation_dataset = ImageDataset(img_paths=data_dict['val_data'], img_labels=data_dict['val_labels'], transform=RandAug(input_size, 'identity'))
    # Train and validation dataloaders
    train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=args.batch_size, shuffle=True, num_workers=4)
    validation_dataloader = DataLoader(validation_dataset, collate_fn=collate_fn, batch_size=args.batch_size, shuffle=True, num_workers=4)

    return {'train': train_dataloader, 'val': validation_dataloader}

def get_criterion(data_dict):
    """Function for generating class weights and initializing the loss function."""
    y = np.asarray(data_dict['tr_labels'])
    # Class weights are used for compensating the imbalance
    # in the amount of training data from the two classes
    class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(args.device)
    print('\nClass weights: ', class_weights.tolist())
    # Cross Entropy Loss function
    criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

    return criterion

def get_optimizer(model):
    """Function for initializing the optimizer."""
    # Model parameters are split into two groups: parameters of the classifier
    # layer and the other model parameters
    params_1 = [param for name, param in model.named_parameters()
                if name not in ["classifier.weight", "classifier.bias"]]
    params_2 = model.classifier.parameters()
    # A 10x larger learning rate is used when training the parameters
    # of the classification layer
    params_to_update = [
        {'params': params_1, 'lr': args.lr},
        {'params': params_2, 'lr': args.lr * 10}
    ]
    # Adam optimizer
    optimizer = torch.optim.Adam(params_to_update, args.lr)
    # Scheduler reduces learning rate when validation F1 score does not improve for an epoch
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0, verbose=True)

    return optimizer, scheduler

def train_model(model, dataloaders, criterion, optimizer, scheduler=None):
    """Function for model training and validation."""
    since = time.time()
    # Lists for saving train and validation metrics for each epoch
    tr_loss_history = []
    tr_acc_history = []
    tr_f1_history = []
    val_loss_history = []
    val_acc_history = []
    val_f1_history = []
    # Lists for saving learning rates for the 2 parameter groups
    lr1_history = []
    lr2_history = []

    # Best F1 value and best epoch are saved in variables
    best_f1 = 0
    best_epoch = 0
    early_stop = False

    # Train / validation loop
    for epoch in tqdm(range(args.num_epochs)):
        # Save learning rates for the epoch
        lr1_history.append(optimizer.param_groups[0]["lr"])
        lr2_history.append(optimizer.param_groups[1]["lr"])

        print('Epoch {}/{}'.format(epoch + 1, args.num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            running_f1 = 0.0

            # Iterate over the data in batches
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(args.device)
                labels = labels.long().to(args.device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Track history only in training phase
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    # Model predictions of the image labels for the batch
                    _, preds = torch.max(outputs, 1)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Get weighted F1 score for the batch
                precision_recall_fscore = precision_recall_fscore_support(
                    labels.data.detach().cpu().numpy(), preds.detach().cpu().numpy(),
                    average='weighted', zero_division=0)
                f1_score = precision_recall_fscore[2]

                # Update statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data).cpu()
                running_f1 += f1_score

            # Calculate loss, accuracy and F1 score for the epoch
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            epoch_f1 = running_f1 / len(dataloaders[phase])

            print('\nEpoch {} - {} - Loss: {:.4f} Acc: {:.4f} F1: {:.4f}\n'.format(
                epoch + 1, phase, epoch_loss, epoch_acc, epoch_f1))

            # Validation step
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                val_loss_history.append(epoch_loss)
                val_f1_history.append(epoch_f1)
                if epoch_f1 > best_f1:
                    print('\nF1 score {:.4f} improved from {:.4f}. Saving the model.\n'.format(epoch_f1, best_f1))
                    # Model with the best F1 score is saved
                    utils.save_model(model, 224, args.save_model_format, args.save_model_path, args.model_name, args.date)
                    model = model.to(args.device)
                    best_f1 = epoch_f1
                    best_epoch = epoch
                elif epoch - best_epoch > args.early_stop_threshold:
                    # Terminate the training loop if validation F1 score has not improved
                    print("Early stopped training at epoch %d" % epoch)
                    # Set early stopping condition
                    early_stop = True
                    break
            elif phase == 'train':
                tr_acc_history.append(epoch_acc)
                tr_loss_history.append(epoch_loss)
                tr_f1_history.append(epoch_f1)

        # Break outer loop if early stopping condition is activated
        if early_stop:
            break
        # Take scheduler step using the latest validation F1 score
        if scheduler:
            scheduler.step(val_f1_history[-1])

    time_elapsed = time.time() - since
    print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation F1 score: {:.4f}'.format(best_f1))
    # The weights from the best epoch (based on validation F1 score) are saved
    # to disk during training; the training history is returned for plotting
    hist_dict = {'tr_acc': tr_acc_history,
                 'val_acc': val_acc_history,
                 'val_loss': val_loss_history,
                 'val_f1': val_f1_history,
                 'tr_loss': tr_loss_history,
                 'tr_f1': tr_f1_history,
                 'lr1': lr1_history,
                 'lr2': lr2_history}

    return hist_dict

def main():
    # Set random seed(s)
    utils.set_seed(args.random_seed)
    # Load image paths and labels
    data_dict = get_datapaths()
    # Initialize the model
    model, input_size = initialize_model()
    # Print the model architecture
    #print(model)
    # Send the model to GPU (if available)
    model = model.to(args.device)
    print("\nInitializing Datasets and Dataloaders...")
    dataloaders_dict = initialize_dataloaders(data_dict, input_size)
    criterion = get_criterion(data_dict)
    optimizer, scheduler = get_optimizer(model)
    # Train and evaluate model
    hist_dict = train_model(model, dataloaders_dict, criterion, optimizer, scheduler)
    print('Damaged images: ', damaged_images)
    utils.plot_metrics(hist_dict, args.results_folder, args.date)

if __name__ == '__main__':
    main()
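A typical launch, mirroring the nohup comment at the top of the file; all flags shown exist in the argparse setup above, the values are the script defaults except for --device and --save_model_format, and the data paths are assumed to follow the repository defaults:

python train.py --device cuda --batch_size 32 --lr 0.0001 --num_epochs 15 --save_model_format onnx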
utils.py
ADDED
@@ -0,0 +1,107 @@
import torch
import onnx
import onnxruntime
import os
import matplotlib.pyplot as plt
import numpy as np
import random

def set_seed(random_seed):
    """Function for setting the random seed for the relevant libraries."""
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    print(f"Random seed set as {random_seed}")

def to_numpy(tensor):
    """Detaches a tensor if needed and converts it to a numpy array."""
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def save_model(model, input_size, save_model_format, save_model_path, model_name, date):
    """Function for saving the model in .pth or .onnx format.
    Code modified from
    https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html"""
    if save_model_format == 'onnx':
        onnx_model_path = os.path.join(save_model_path, model_name + '_' + date + '.onnx')
        # Export batch size of 1; the batch axis is marked dynamic below
        batch_size = 1
        # Random input to the model (with correct dimensions)
        x = torch.randn(batch_size, 3, input_size, input_size, requires_grad=True)
        model = model.to('cpu')
        torch_out = model(x)

        # Export the model
        torch.onnx.export(model,               # model being run
                          x,                   # model input (or a tuple for multiple inputs)
                          onnx_model_path,     # where to save the model (can be a file or file-like object)
                          export_params=True,  # store the trained parameter weights inside the model file
                          opset_version=10,    # the ONNX version to export the model to
                          do_constant_folding=True,  # whether to execute constant folding for optimization
                          input_names=['input'],     # the model's input names
                          output_names=['output'],   # the model's output names
                          dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                        'output': {0: 'batch_size'}})

        print('ONNX model saved to ', onnx_model_path)
        # Test the exported model
        onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(onnx_model)
        print('ONNX model checked.')

        onnx_session = onnxruntime.InferenceSession(onnx_model_path)
        # Compute ONNX Runtime output prediction
        onnx_inputs = {onnx_session.get_inputs()[0].name: to_numpy(x)}
        onnx_out = onnx_session.run(None, onnx_inputs)
        # Compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(to_numpy(torch_out), onnx_out[0], rtol=1e-03, atol=1e-05)
        print("Exported model has been tested with ONNXRuntime, and the result looks good!\n")

    else:
        pytorch_model_path = os.path.join(save_model_path, 'densenet_' + date + '.pth')
        torch.save(model, pytorch_model_path)
        print('Pytorch model saved to ', pytorch_model_path)

def plot_metrics(hist_dict, results_folder, date):
    """Function for plotting the training and validation results."""
    epochs = range(1, len(hist_dict['tr_loss']) + 1)
    plt.plot(epochs, hist_dict['tr_loss'], 'g', label='Training loss')
    plt.plot(epochs, hist_dict['val_loss'], 'b', label='Validation loss')
    plt.title('Training and Validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(results_folder + date + '_tr_val_loss.jpg', bbox_inches='tight')
    plt.close()

    plt.plot(epochs, hist_dict['tr_acc'], 'g', label='Training accuracy')
    plt.plot(epochs, hist_dict['val_acc'], 'b', label='Validation accuracy')
    plt.title('Training and Validation accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig(results_folder + date + '_tr_val_acc.jpg', bbox_inches='tight')
    plt.close()

    plt.plot(epochs, hist_dict['tr_f1'], 'g', label='Training F1 score')
    plt.plot(epochs, hist_dict['val_f1'], 'b', label='Validation F1 score')
    plt.title('Training and Validation F1 score')
    plt.xlabel('Epochs')
    plt.ylabel('F1 score')
    plt.legend()
    plt.savefig(results_folder + date + '_tr_val_f1.jpg', bbox_inches='tight')
    plt.close()

    plt.plot(epochs, hist_dict['lr1'], 'g', label='Backbone learning rate')
    plt.plot(epochs, hist_dict['lr2'], 'b', label='Classifier learning rate')
    plt.title('Learning rate')
    plt.xlabel('Epochs')
    plt.ylabel('Learning rate')
    plt.legend()
    plt.savefig(results_folder + date + '_learning_rate.jpg', bbox_inches='tight')
    plt.close()
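A minimal sketch of exporting an untrained two-class classifier with save_model, mirroring initialize_model in train.py; the untrained backbone is for illustration only, and the ./models/ output directory is an assumption that must exist beforehand:

import time
import torch.nn as nn
from torchvision import models
import utils

model = models.densenet121()                                   # untrained backbone, illustration only
model.classifier = nn.Linear(model.classifier.in_features, 2)  # two-class head, as in train.py
model.eval()                                                   # match inference behavior before export
utils.save_model(model, 224, 'onnx', './models/', 'demo', time.strftime('%d%m%Y'))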