srijaydeshpande committed on
Commit
4473449
1 Parent(s): 9c3c6ce

Upload 4 files

Assistance/SingleImageCropper.py ADDED
@@ -0,0 +1,79 @@
1
+ import glob
2
+ import os
3
+ from PIL import Image
4
+ import numpy as np
5
+ import time
6
+ import PIL
7
+ import matplotlib.pyplot as plt
8
+ import argparse
9
+
10
+ patchsize = 296
11
+ stride = 236
12
+ pad=0
13
+
14
+ PIL.Image.MAX_IMAGE_PIXELS = 933120000
15
+
16
+ def CropImage(image_path,output_dir,pad):
17
+ image_name = os.path.split(image_path)[1].split('.')[0]
18
+ im = Image.open(image_path)
19
+
20
+ size = im.size
21
+ # new_size = (size[0]+40,size[1]+40)
22
+ new_size = size
23
+
24
+ if(pad):
25
+ new_im = Image.new("RGB", new_size) ## luckily, this is already black!
26
+ new_im.paste(im, ((new_size[0] - size[0]) // 2,
27
+ (new_size[1] - size[1]) // 2))
28
+ #plt.imshow(new_im)
29
+ #plt.show()
30
+ else:
31
+ new_im = im
32
+
33
+ width, height = new_im.size
34
+
35
+ x = 0
36
+ y = 0
37
+ right = 0
38
+ bottom = 0
39
+
40
+ while (bottom < height):
41
+ while (right < width):
42
+ left = x
43
+ top = y
44
+ right = left + patchsize
45
+ bottom = top + patchsize
46
+ if (right > width):
47
+ offset = right - width
48
+ right -= offset
49
+ left -= offset
50
+ if (bottom > height):
51
+ offset = bottom - height
52
+ bottom -= offset
53
+ top -= offset
54
+ im_crop = new_im.crop((left, top, right, bottom))
55
+ im_crop_name = image_name + "_" + str(left) + "_" + str(top) + ".png"
56
+ output_path = os.path.join(output_dir, im_crop_name)
57
+ im_crop.save(output_path)
58
+ x += stride
59
+ x = 0
60
+ right = 0
61
+ y += stride
62
+
63
+ start_time = time.time()
64
+
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--image_path", help="path to image to crop",
67
+ default=r"F:\Datasets\DigestPath\safron\Benign\test\3\ablation_study_neg_46\neg_46.png")
68
+ parser.add_argument("--output_dir", help="path to output folder",
69
+ default=r"F:\Datasets\DigestPath\safron\Benign\test\3\ablation_study_neg_46\cropped_safron_patchadv")
70
+ parser.add_argument("--pad", type=int, default=0, help="pad the image borders")
71
+
72
+ args = parser.parse_args()
73
+
74
+ if not os.path.exists(args.output_dir):
75
+ os.makedirs(args.output_dir)
76
+
77
+ CropImage(args.image_path,args.output_dir,args.pad)
78
+
79
+ print("--- %s seconds ---" % (time.time() - start_time))
Assistance/join_images.py ADDED
@@ -0,0 +1,61 @@
1
+ import os
2
+ import glob
3
+ import numpy as np
4
+ from PIL import Image
5
+ import PIL
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib
8
+ import time
9
+ import argparse
10
+
11
+ PIL.Image.MAX_IMAGE_PIXELS = 933120000
12
+
13
+ def join_images(input_dir,output_file,height,width):
14
+
15
+ start_time = time.time()
16
+
17
+ paths = glob.glob(os.path.join(input_dir,"*.png"))
18
+
19
+ patch = 256
20
+
21
+ image = np.zeros((height,width,3))
22
+ count_masks = np.zeros((height,width,3))
23
+ k=0
24
+ for path in paths:
25
+ if('outputs' in path):
26
+ imname = os.path.split(path)[1].replace("-outputs","").split(".")[0]
27
+ imname = imname.split("_")
28
+ y,x = int(imname[-2]),int(imname[-1])
29
+ img = Image.open(path)
30
+ img = np.asarray(img)
31
+ #print("X => ",x," Y => ",y)
32
+ image[x:x+patch,y:y+patch,:] += img
33
+ count_masks[x:x+patch,y:y+patch,:]+=1.0
34
+ k+=1
35
+
36
+ count_masks = count_masks.clip(min=1)
37
+
38
+ image = image/count_masks
39
+
40
+ image = image/255.0
41
+
42
+ # im = Image.fromarray(image)
43
+ # im.save(output_file)
44
+
45
+ matplotlib.image.imsave(output_file, image)
46
+
47
+ print("--- %s seconds ---" % (time.time() - start_time))
48
+
49
+ print("Done")
50
+
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument("--patches_dir", help="path to generated patches to join",
53
+ default="./tmp/results/images")
54
+ parser.add_argument("--output_file", help="path to output file",
55
+ default="F:/Datasets/DigestPath/safron/test/single/outputs/sample.jpeg")
56
+ parser.add_argument("--im_height", type=int, help="image height")
57
+ parser.add_argument("--im_width", type=int, help="image width")
58
+
59
+ args = parser.parse_args()
60
+
61
+ join_images(args.patches_dir,args.output_file,args.im_height,args.im_width)
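The stitching above is a plain accumulate-and-average scheme: every 256x256 output patch is added into a full-size canvas at the offset encoded in its filename, a parallel count array records how many patches touched each pixel, and dividing the two averages the overlapping regions. A minimal, self-contained sketch of that idea with toy sizes (not tied to the script's file naming):

    import numpy as np

    # Toy illustration of the overlap-averaging used in join_images above.
    canvas = np.zeros((8, 8))
    counts = np.zeros((8, 8))
    patch = np.ones((4, 4))

    for top, left in [(0, 0), (0, 4), (2, 2)]:   # hypothetical patch positions
        canvas[top:top + 4, left:left + 4] += patch
        counts[top:top + 4, left:left + 4] += 1.0

    counts = counts.clip(min=1)                  # avoid dividing by zero where nothing was written
    average = canvas / counts                    # overlapping pixels become the mean of their patches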
segment2tissue_safron_media.py ADDED
@@ -0,0 +1,683 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import tensorflow.compat.v1 as tf
6
+ import numpy as np
7
+ import argparse
8
+ import os
9
+ import json
10
+ import glob
11
+ import random
12
+ import collections
13
+ import math
14
+ import time
15
+ # visualize image
16
+ import matplotlib.pyplot as plt
17
+
18
+ # disable v2 behavior
19
+ tf.disable_v2_behavior()
20
+
21
+ # enable eager execution
22
+ # tf.compat.v1.enable_eager_execution()
23
+
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--input_dir", help="path to folder containing images")
26
+ parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
27
+ parser.add_argument("--output_dir", required=True, help="where to put output files")
28
+ parser.add_argument("--seed", type=int)
29
+ parser.add_argument("--checkpoint", default=None,
30
+ help="directory with checkpoint to resume training from or use for testing")
31
+
32
+ parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
33
+ parser.add_argument("--max_epochs", type=int, help="number of training epochs")
34
+ parser.add_argument("--summary_freq", type=int, default=10000, help="update summaries every summary_freq steps")
35
+ parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
36
+ parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
37
+ parser.add_argument("--display_freq", type=int, default=1000, help="write current training images every display_freq steps")
38
+ parser.add_argument("--save_freq", type=int, default=1000, help="save model every save_freq steps, 0 to disable")
39
+ parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
40
+ parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
41
+ parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
42
+ parser.add_argument("--which_direction", type=str, default="BtoA", choices=["AtoB", "BtoA"])
43
+ parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
44
+ parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
45
+ parser.add_argument("--scale_size", type=int, default=728, help="scale images to this size before cropping to 256x256")
46
+ parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
47
+ parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
48
+ # parser.set_defaults(flip=True)
49
+ parser.add_argument("--lr", type=float, default=0.0001, help="initial learning rate for adam")
50
+ parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
51
+ parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
52
+ parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")
53
+
54
+ # export options
55
+ parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])
56
+ a = parser.parse_args()
57
+
58
+ EPS = 1e-12
59
+
60
+ Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
61
+ Model = collections.namedtuple("Model",
62
+ "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")
63
+
64
+ def preprocess(image):
65
+ with tf.name_scope("preprocess"):
66
+ # [0, 1] => [-1, 1]
67
+ return image * 2 - 1
68
+
69
+
70
+ def deprocess(image):
71
+ with tf.name_scope("deprocess"):
72
+ # [-1, 1] => [0, 1]
73
+ return (image + 1) / 2
74
+
75
+
76
+ def discrim_conv(batch_input, out_channels, stride):
77
+ padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
78
+ return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid",
79
+ kernel_initializer=tf.random_normal_initializer(0, 0.02))
80
+
81
+
82
+ def gen_conv(batch_input, out_channels):
83
+ # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
84
+ initializer = tf.random_normal_initializer(0, 0.02)
85
+ if a.separable_conv:
86
+ return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
87
+ depthwise_initializer=initializer, pointwise_initializer=initializer)
88
+ else:
89
+ return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
90
+ kernel_initializer=initializer)
91
+
92
+
93
+ def gen_deconv(batch_input, out_channels):
94
+ # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
95
+ initializer = tf.random_normal_initializer(0, 0.02)
96
+ if a.separable_conv:
97
+ _b, h, w, _c = batch_input.shape
98
+ resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2],
99
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
100
+ return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same",
101
+ depthwise_initializer=initializer, pointwise_initializer=initializer)
102
+ else:
103
+ return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
104
+ kernel_initializer=initializer)
105
+
106
+
107
+ def lrelu(x, a):
108
+ with tf.name_scope("lrelu"):
109
+ # adding these together creates the leak part and linear part
110
+ # then cancels them out by subtracting/adding an absolute value term
111
+ # leak: a*x/2 - a*abs(x)/2
112
+ # linear: x/2 + abs(x)/2
113
+
114
+ # this block looks like it has 2 inputs on the graph unless we do this
115
+ x = tf.identity(x)
116
+ return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
117
+
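The comment in lrelu above writes leaky ReLU as a sum of a linear part and a leak part; the quick check below (illustrative only, outside TensorFlow) confirms that 0.5*(1+a)*x + 0.5*(1-a)*|x| equals x for positive inputs and a*x for negative ones.

    # Numeric check of the closed form used by lrelu above.
    def lrelu_closed_form(x, a):
        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * abs(x)

    assert abs(lrelu_closed_form(3.0, 0.2) - 3.0) < 1e-12    # positive inputs pass through
    assert abs(lrelu_closed_form(-3.0, 0.2) + 0.6) < 1e-12   # negative inputs are scaled by a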
118
+
119
+ def batchnorm(inputs):
120
+ return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True,
121
+ gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
122
+
123
+
124
+ def check_image(image):
125
+ assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels")
126
+ with tf.control_dependencies([assertion]):
127
+ image = tf.identity(image)
128
+
129
+ if image.get_shape().ndims not in (3, 4):
130
+ raise ValueError("image must be either 3 or 4 dimensions")
131
+
132
+ # make the last dimension 3 so that you can unstack the colors
133
+ shape = list(image.get_shape())
134
+ shape[-1] = 3
135
+ image.set_shape(shape)
136
+ return image
137
+
138
+
139
+ def load_examples():
140
+
141
+ if a.input_dir is None or not os.path.exists(a.input_dir):
142
+ raise Exception("input_dir does not exist")
143
+
144
+ input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
145
+ decode = tf.image.decode_jpeg
146
+ if len(input_paths) == 0:
147
+ input_paths = glob.glob(os.path.join(a.input_dir, "*.png"))
148
+ decode = tf.image.decode_png
149
+
150
+ if len(input_paths) == 0:
151
+ raise Exception("input_dir contains no image files")
152
+
153
+ def get_name(path):
154
+ name, _ = os.path.splitext(os.path.basename(path))
155
+ return name
156
+
157
+ # if the image names are numbers, sort by the value rather than asciibetically
158
+ # having sorted inputs means that the outputs are sorted in test mode
159
+ if all(get_name(path).isdigit() for path in input_paths):
160
+ input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
161
+ else:
162
+ input_paths = sorted(input_paths)
163
+
164
+ with tf.name_scope("load_images"):
165
+ path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
166
+ reader = tf.WholeFileReader()
167
+ paths, contents = reader.read(path_queue)
168
+ raw_input = decode(contents)
169
+ raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)
170
+
171
+ assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels")
172
+
173
+ with tf.control_dependencies([assertion]):
174
+ raw_input = tf.identity(raw_input)
175
+
176
+ raw_input.set_shape([None, None, 3])
177
+
178
+ # break apart image pair and move to range [-1, 1]
179
+ width = tf.shape(raw_input)[1] # [height, width, channels]
180
+
181
+ a_images = preprocess(raw_input[:, :width // 2, :])
182
+ b_images = preprocess(raw_input[:, width // 2:, :])
183
+
184
+ if a.which_direction == "AtoB":
185
+ inputs, targets = [a_images, b_images]
186
+ elif a.which_direction == "BtoA":
187
+ inputs, targets = [b_images, a_images]
188
+ else:
189
+ raise Exception("invalid direction")
190
+
191
+ # synchronize seed for image operations so that we do the same operations to both
192
+ # input and output images
193
+ def transform(image):
194
+ r = image
195
+ r.set_shape([a.scale_size,a.scale_size,3])
196
+ #r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)
197
+ return r
198
+
199
+ with tf.name_scope("input_images"):
200
+ input_images = transform(inputs)
201
+
202
+ with tf.name_scope("target_images"):
203
+ target_images = transform(targets)
204
+
205
+ paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images],
206
+ batch_size=a.batch_size)
207
+
208
+ steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))
209
+
210
+ return Examples(
211
+ paths=paths_batch,
212
+ inputs=inputs_batch,
213
+ targets=targets_batch,
214
+ count=len(input_paths),
215
+ steps_per_epoch=steps_per_epoch,
216
+ )
217
+
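load_examples expects each input file to be a side-by-side A|B pair (the format produced by tools/process.py --operation combine later in this commit): the image is split at width // 2 and each half is mapped from [0, 1] to [-1, 1] by preprocess. A minimal NumPy sketch of that split, using a toy array rather than a real pair:

    import numpy as np

    # Toy illustration of splitting a combined A|B image, as in load_examples above.
    pair = np.random.rand(256, 512, 3)            # hypothetical side-by-side pair in [0, 1]
    width = pair.shape[1]

    a_image = pair[:, :width // 2, :]             # left half  (A)
    b_image = pair[:, width // 2:, :]             # right half (B)

    a_image, b_image = a_image * 2 - 1, b_image * 2 - 1   # [0, 1] -> [-1, 1], as in preprocess()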
218
+
219
+ def create_generator(generator_inputs, generator_outputs_channels):
220
+ layers = []
221
+
222
+ #Add Filter to detect edges
223
+ filter_shape = [41,41,3,a.ngf]
224
+ with tf.variable_scope("encoder_1"):
225
+ filter = tf.get_variable('edge_detector', filter_shape, initializer=tf.random_normal_initializer(stddev=0.02))
226
+ strides = [1, 1, 1, 1]
227
+ output = tf.nn.conv2d(generator_inputs, filter, strides=strides, padding='VALID')
228
+ output = lrelu(output, 0.2)
229
+ layers.append(output)
230
+
231
+ # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
232
+ with tf.variable_scope("encoder_1"):
233
+ output = gen_conv(output, a.ngf)
234
+ layers.append(output)
235
+
236
+ layer_specs = [
237
+ a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
238
+ a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
239
+ a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
240
+ a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
241
+ a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
242
+ a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
243
+ a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
244
+ ]
245
+
246
+ for out_channels in layer_specs:
247
+ with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
248
+ rectified = lrelu(layers[-1], 0.2)
249
+ # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
250
+ convolved = gen_conv(rectified, out_channels)
251
+ output = batchnorm(convolved)
252
+ layers.append(output)
253
+
254
+ layer_specs = [
255
+ (a.ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
256
+ (a.ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
257
+ (a.ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
258
+ (a.ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
259
+ (a.ngf * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
260
+ (a.ngf * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
261
+ (a.ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
262
+ ]
263
+
264
+ num_encoder_layers = len(layers)
265
+ for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
266
+ skip_layer = num_encoder_layers - decoder_layer - 1
267
+ with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
268
+ if decoder_layer == 0:
269
+ # first decoder layer doesn't have skip connections
270
+ # since it is directly connected to the skip_layer
271
+ input = layers[-1]
272
+ else:
273
+ input = tf.concat([layers[-1], layers[skip_layer]], axis=3)
274
+ #input = layers[-1]
275
+
276
+ rectified = tf.nn.relu(input)
277
+ # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
278
+ output = gen_deconv(rectified, out_channels)
279
+ output = batchnorm(output)
280
+
281
+ if dropout > 0.0:
282
+ output = tf.nn.dropout(output, keep_prob=1 - dropout)
283
+
284
+ layers.append(output)
285
+
286
+ # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
287
+ with tf.variable_scope("decoder_1"):
288
+ input = tf.concat([layers[-1], layers[1]], axis=3)
289
+ rectified = tf.nn.relu(input)
290
+ output = gen_deconv(rectified, generator_outputs_channels)
291
+ output = tf.tanh(output)
292
+ layers.append(output)
293
+
294
+ return layers[-1]
295
+
296
+
297
+ ksize_rows = 296
298
+ ksize_cols = 296
299
+ strides_rows = 236
300
+ strides_cols = 236
301
+ num_patches = int(a.scale_size/strides_rows)
302
+ ksizes = [1, ksize_rows, ksize_cols, 1]
303
+ ksizes_output = [1, 256, 256, 1]
304
+ strides = [1, strides_rows, strides_cols, 1]
305
+ rates = [1, 1, 1, 1]
306
+ padding='VALID'
307
+
308
+
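With the default --scale_size of 728 these constants fit together: the padded input used later in create_model is 728 + 40 = 768 px, a VALID 296 px window at stride 236 fits three times along each axis (matching num_patches = int(728 / 236) = 3), and each 296 px patch shrinks to 256 px after the generator's 41x41 VALID edge-detector convolution. The lines below are only a sanity check of that arithmetic, not repository code.

    # Sanity check of the patch geometry for the default --scale_size of 728.
    scale_size, pad = 728, 20
    ksize, stride, edge_kernel = 296, 236, 41

    padded = scale_size + 2 * pad                         # 768, as in pad_to_bounding_box below
    patches_per_axis = (padded - ksize) // stride + 1     # VALID extraction: 3
    generator_output = ksize - (edge_kernel - 1)          # 41x41 VALID conv: 296 -> 256
    print(patches_per_axis, int(scale_size / stride), generator_output)   # 3 3 256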
309
+ def extract_patches(x, ksizes, strides, rates):
310
+ return tf.extract_image_patches(
311
+ x,
312
+ ksizes, strides, rates,
313
+ padding="VALID"
314
+ )
315
+
316
+
317
+ def extract_patches_inverse(x, y):
318
+ _x = tf.zeros_like(x)
319
+ _y = extract_patches(_x, ksizes_output, strides, rates)
320
+ grad = tf.gradients(_y, _x)[0]
321
+ # Divide by grad, to "average" together the overlapping patches
322
+ # otherwise they would simply sum up
323
+ return tf.gradients(_y, _x, grad_ys=y)[0] / grad
324
+
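extract_patches_inverse above scatters patches back onto the full canvas by backpropagating through extract_image_patches; because the gradient of patch extraction with respect to the input simply counts how many patches cover each pixel, dividing by it turns the summed overlaps into an average. A rough NumPy analogue of that idea, with toy sizes and no TensorFlow:

    import numpy as np

    # Toy analogue of extract_patches_inverse: scatter-add patches, then divide by
    # how many patches covered each pixel (the role played by `grad` above).
    length, ksize, stride = 10, 4, 3
    starts = list(range(0, length - ksize + 1, stride))   # VALID patch starts: 0, 3, 6

    coverage = np.zeros(length)
    summed = np.zeros(length)
    for i, s in enumerate(starts):
        patch = np.full(ksize, float(i))                  # stand-in for a generated patch
        summed[s:s + ksize] += patch
        coverage[s:s + ksize] += 1.0                      # same counts tf.gradients(_y, _x) yields

    averaged = summed / np.maximum(coverage, 1.0)         # overlapping pixels average their patches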
325
+
326
+ def create_model(inputs, targets):
327
+
328
+ out_channels = int(targets.get_shape()[-1])
329
+
330
+ if(a.mode == "train" or a.scale_size != 296): #scale_size = 296 while testing
331
+
332
+ def create_discriminator(discrim_inputs, discrim_targets):
333
+ n_layers = 5
334
+ layers = []
335
+
336
+ # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
337
+ input = tf.concat([discrim_inputs, discrim_targets], axis=3)
338
+
339
+ # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
340
+ with tf.variable_scope("layer_1"):
341
+ convolved = discrim_conv(input, a.ndf, stride=2)
342
+ rectified = lrelu(convolved, 0.2)
343
+ layers.append(rectified)
344
+
345
+ # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
346
+ # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
347
+ # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
348
+ for i in range(n_layers):
349
+ with tf.variable_scope("layer_%d" % (len(layers) + 1)):
350
+ out_channels = a.ndf * min(2 ** (i + 1), 8)
351
+ stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1
352
+ convolved = discrim_conv(layers[-1], out_channels, stride=stride)
353
+ normalized = batchnorm(convolved)
354
+ rectified = lrelu(normalized, 0.2)
355
+ layers.append(rectified)
356
+
357
+ # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
358
+ with tf.variable_scope("layer_%d" % (len(layers) + 1)):
359
+ convolved = discrim_conv(rectified, out_channels=1, stride=1)
360
+ output = tf.sigmoid(convolved)
361
+ layers.append(output)
362
+ print(layers)
363
+ return layers[-1]
364
+
365
+ # Pad the inputs so that edge patches also get surrounding context as input to the generator (context around the target patch)
366
+ inputs_bounded = tf.image.pad_to_bounding_box(inputs, 20, 20, a.scale_size + 40, a.scale_size + 40)
367
+
368
+ #Extract Patches
369
+ with tf.variable_scope("extract_patches"):
370
+ inputs_patches = extract_patches(inputs_bounded, ksizes, strides, rates)
371
+
372
+ with tf.name_scope("patches_generator"):
373
+ with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
374
+ output_patches = []
375
+ patch_size_len = 256*256*3
376
+ #Get output from each patch via same generator
377
+ for i in range(0, num_patches):
378
+ for j in range(0, num_patches):
379
+ patch = inputs_patches[0, i, j,]
380
+ #patch_size_len = int(patch.get_shape()[0]) #defined above
381
+ # reshape
382
+ patch = tf.reshape(patch, [ksize_rows, ksize_cols, 3])
383
+ patch = tf.expand_dims(patch,0)
384
+ patch_output = create_generator(patch, out_channels)
385
+ output_patches.append(tf.reshape(patch_output,[patch_size_len]))
386
+ output_patches = tf.stack(output_patches)
387
+ output_patches = tf.reshape(output_patches, [1, num_patches, num_patches, patch_size_len])
388
+
389
+ #Stitch all patches back
390
+ k = tf.constant(0.1, shape=[1, a.scale_size, a.scale_size, 3])
391
+ outputs = extract_patches_inverse(k, output_patches)
392
+
393
+ # create two copies of discriminator, one for real pairs and one for fake pairs
394
+ # they share the same underlying variables
395
+ with tf.name_scope("real_discriminator"):
396
+ with tf.variable_scope("discriminator"):
397
+ # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
398
+ predict_real = create_discriminator(inputs, targets)
399
+
400
+ with tf.name_scope("fake_discriminator"):
401
+ with tf.variable_scope("discriminator", reuse=True):
402
+ # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
403
+ predict_fake = create_discriminator(inputs, outputs)
404
+
405
+ with tf.name_scope("discriminator_loss"):
406
+ # minimizing -tf.log will try to get inputs to 1
407
+ # predict_real => 1
408
+ # predict_fake => 0
409
+ discrim_loss = tf.reduce_mean(-(tf.log(tf.clip_by_value((predict_real + EPS),1e-12,1.0)) + tf.log(tf.clip_by_value((1 - predict_fake + EPS),1e-12,1.0))))
410
+
411
+ with tf.name_scope("generator_loss"):
412
+ # predict_fake => 1
413
+ # abs(targets - outputs) => 0
414
+ gen_loss_GAN = tf.reduce_mean(-tf.log(tf.clip_by_value((predict_fake + EPS),1e-12,1.0)))
415
+ gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
416
+ gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight
417
+
418
+ with tf.name_scope("discriminator_train"):
419
+ discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
420
+ discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
421
+ discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
422
+ discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
423
+
424
+ with tf.name_scope("generator_train"):
425
+ with tf.control_dependencies([discrim_train]):
426
+ gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
427
+ gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
428
+ gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
429
+ gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
430
+
431
+
432
+ ema = tf.train.ExponentialMovingAverage(decay=0.99)
433
+ update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
434
+ global_step = tf.train.get_or_create_global_step()
435
+ incr_global_step = tf.assign(global_step, global_step + 1)
436
+ discrim_loss = ema.average(discrim_loss)
437
+ gen_loss_GAN=ema.average(gen_loss_GAN)
438
+ gen_loss_L1 = ema.average(gen_loss_L1)
439
+ update_ops = [update_losses, incr_global_step, gen_train]
440
+ train = tf.group(update_ops)
441
+
442
+ elif (a.mode == "test"):
443
+ predict_real = None
444
+ predict_fake = None
445
+ discrim_loss = None
446
+ discrim_grads_and_vars = None
447
+ gen_loss_GAN = None
448
+ gen_loss_L1 = None
449
+ gen_grads_and_vars = None
450
+ train = None
451
+ with tf.name_scope("patches_generator"):
452
+ with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
453
+ outputs = create_generator(inputs, out_channels)
454
+ else:
455
+ print("Give correct mode")
456
+ exit(0)
457
+
458
+ return Model(
459
+ predict_real=predict_real,
460
+ predict_fake=predict_fake,
461
+ discrim_loss=discrim_loss,
462
+ discrim_grads_and_vars=discrim_grads_and_vars,
463
+ gen_loss_GAN=gen_loss_GAN,
464
+ gen_loss_L1=gen_loss_L1,
465
+ gen_grads_and_vars=gen_grads_and_vars,
466
+ outputs=outputs,
467
+ train=train
468
+ )
469
+
470
+
471
+ def save_images(fetches, step=None):
472
+ image_dir = os.path.join(a.output_dir, "images")
473
+ if not os.path.exists(image_dir):
474
+ os.makedirs(image_dir)
475
+ filesets = []
476
+ for i, in_path in enumerate(fetches["paths"]):
477
+ name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
478
+ fileset = {"name": name, "step": step}
479
+ for kind in ["inputs", "outputs", "targets"]:
480
+ filename = name + "-" + kind + ".png"
481
+ if step is not None:
482
+ filename = "%08d-%s" % (step, filename)
483
+ fileset[kind] = filename
484
+ out_path = os.path.join(image_dir, filename)
485
+ contents = fetches[kind][i]
486
+ with open(out_path, "wb") as f:
487
+ f.write(contents)
488
+ filesets.append(fileset)
489
+ return filesets
490
+
491
+
492
+ def append_index(filesets, step=False):
493
+ index_path = os.path.join(a.output_dir, "index.html")
494
+ if os.path.exists(index_path):
495
+ index = open(index_path, "a")
496
+ else:
497
+ index = open(index_path, "w")
498
+ index.write("<html><body><table><tr>")
499
+ if step:
500
+ index.write("<th>step</th>")
501
+ index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>")
502
+
503
+ for fileset in filesets:
504
+ index.write("<tr>")
505
+
506
+ if step:
507
+ index.write("<td>%d</td>" % fileset["step"])
508
+ index.write("<td>%s</td>" % fileset["name"])
509
+
510
+ for kind in ["inputs", "outputs", "targets"]:
511
+ index.write("<td><img src='images/%s'></td>" % fileset[kind])
512
+
513
+ index.write("</tr>")
514
+ return index_path
515
+
516
+
517
+ def main():
518
+ if a.seed is None:
519
+ a.seed = random.randint(0, 2 ** 31 - 1)
520
+
521
+ tf.set_random_seed(a.seed)
522
+ np.random.seed(a.seed)
523
+ random.seed(a.seed)
524
+
525
+ if not os.path.exists(a.output_dir):
526
+ os.makedirs(a.output_dir)
527
+
528
+ if a.mode == "test" or a.mode == "export":
529
+ if a.checkpoint is None:
530
+ raise Exception("checkpoint required for test mode")
531
+
532
+ # load some options from the checkpoint
533
+ options = {"which_direction", "ngf", "ndf"}
534
+ with open(os.path.join(a.checkpoint, "options.json")) as f:
535
+ for key, val in json.loads(f.read()).items():
536
+ if key in options:
537
+ print("loaded", key, "=", val)
538
+ setattr(a, key, val)
539
+ # disable these features in test mode
540
+ #a.scale_size = CROP_SIZE
541
+ a.flip = False
542
+
543
+ for k, v in a._get_kwargs():
544
+ print(k, "=", v)
545
+
546
+ with open(os.path.join(a.output_dir, "options.json"), "w") as f:
547
+ f.write(json.dumps(vars(a), sort_keys=True, indent=4))
548
+
549
+ examples = load_examples()
550
+
551
+ # inputs and targets are [batch_size, height, width, channels]
552
+ model = create_model(examples.inputs, examples.targets)
553
+
554
+ inputs = deprocess(examples.inputs)
555
+ targets = deprocess(examples.targets)
556
+
557
+ outputs = deprocess(model.outputs)
558
+
559
+ def convert(image):
560
+ return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)
561
+
562
+ # reverse any processing on images so they can be written to disk or displayed to user
563
+ with tf.name_scope("convert_inputs"):
564
+ converted_inputs = convert(inputs)
565
+
566
+ with tf.name_scope("convert_targets"):
567
+ converted_targets = convert(targets)
568
+
569
+ with tf.name_scope("convert_outputs"):
570
+ converted_outputs = convert(outputs)
571
+
572
+ with tf.name_scope("encode_images"):
573
+ display_fetches = {
574
+ "paths": examples.paths,
575
+ "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
576
+ "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
577
+ "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
578
+ }
579
+
580
+
581
+ with tf.name_scope("parameter_count"):
582
+ parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
583
+
584
+ saver = tf.train.Saver(max_to_keep=1)
585
+
586
+ logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
587
+ sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
588
+
589
+ with sv.managed_session() as sess:
590
+ print("parameter_count =", sess.run(parameter_count))
591
+
592
+ if a.checkpoint is not None:
593
+ print("loading model from checkpoint")
594
+ print(a.checkpoint)
595
+ checkpoint = tf.train.latest_checkpoint(a.checkpoint)
596
+ saver.restore(sess, checkpoint)
597
+
598
+ max_steps = 2 ** 32
599
+
600
+ if a.max_epochs is not None:
601
+ max_steps = examples.steps_per_epoch * a.max_epochs
602
+ if a.max_steps is not None:
603
+ max_steps = a.max_steps
604
+
605
+ if a.mode == "test":
606
+ # testing
607
+ # at most, process the test data once
608
+ start = time.time()
609
+ max_steps = min(examples.steps_per_epoch, max_steps)
610
+ for step in range(max_steps):
611
+ results = sess.run(display_fetches)
612
+ filesets = save_images(results)
613
+ for i, f in enumerate(filesets):
614
+ print("evaluated image", f["name"])
615
+ index_path = append_index(filesets)
616
+ print("wrote index at", index_path)
617
+ print("rate", (time.time() - start) / max_steps)
618
+ else:
619
+ # training
620
+ start = time.time()
621
+
622
+ for step in range(max_steps):
623
+ def should(freq):
624
+ return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)
625
+
626
+ options = None
627
+ run_metadata = None
628
+ if should(a.trace_freq):
629
+ options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
630
+ run_metadata = tf.RunMetadata()
631
+
632
+ fetches = {
633
+ "train": model.train,
634
+ "global_step": sv.global_step,
635
+ }
636
+
637
+ if should(a.progress_freq):
638
+ fetches["discrim_loss"] = model.discrim_loss
639
+ fetches["gen_loss_GAN"] = model.gen_loss_GAN
640
+ fetches["gen_loss_L1"] = model.gen_loss_L1
641
+
642
+ if should(a.summary_freq):
643
+ fetches["summary"] = sv.summary_op
644
+
645
+ if should(a.display_freq):
646
+ fetches["display"] = display_fetches
647
+
648
+ results = sess.run(fetches, options=options, run_metadata=run_metadata)
649
+
650
+ if should(a.summary_freq):
651
+ print("recording summary")
652
+ sv.summary_writer.add_summary(results["summary"], results["global_step"])
653
+
654
+ if should(a.display_freq):
655
+ print("saving display images")
656
+ filesets = save_images(results["display"], step=results["global_step"])
657
+ append_index(filesets, step=True)
658
+
659
+ if should(a.trace_freq):
660
+ print("recording trace")
661
+ sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])
662
+
663
+ if should(a.progress_freq):
664
+ # global_step will have the correct step count if we resume from a checkpoint
665
+ train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch)
666
+ train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1
667
+ rate = (step + 1) * a.batch_size / (time.time() - start)
668
+ remaining = (max_steps - step) * a.batch_size / rate
669
+ print("progress epoch %d step %d image/sec %0.1f remaining %dm" % (
670
+ train_epoch, train_step, rate, remaining / 60))
671
+ print("discrim_loss", results["discrim_loss"])
672
+ print("gen_loss_GAN", results["gen_loss_GAN"])
673
+ print("gen_loss_L1", results["gen_loss_L1"])
674
+
675
+ if should(a.save_freq):
676
+ print("saving model")
677
+ saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step)
678
+
679
+ if sv.should_stop():
680
+ break
681
+
682
+
683
+ main()
tools/process.py ADDED
@@ -0,0 +1,310 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import tensorflow.compat.v1 as tf
6
+ tf.disable_v2_behavior()
7
+
8
+ import argparse
9
+ import os
10
+ import tempfile
11
+ import subprocess
12
+ #import tensorflow as tf
13
+ import numpy as np
14
+ import tfimage as im
15
+ import threading
16
+ import time
17
+ import multiprocessing
18
+
19
+
20
+
21
+ edge_pool = None
22
+
23
+
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--input_dir", required=True, help="path to folder containing images")
26
+ parser.add_argument("--output_dir", required=True, help="output path")
27
+ parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine", "edges"])
28
+ parser.add_argument("--workers", type=int, default=1, help="number of workers")
29
+ # resize
30
+ parser.add_argument("--pad", action="store_true", help="pad instead of crop for resize operation")
31
+ parser.add_argument("--size", type=int, default=256, help="size to use for resize operation")
32
+ # combine
33
+ parser.add_argument("--b_dir", type=str, help="path to folder containing B images for combine operation")
34
+ a = parser.parse_args()
35
+
36
+
37
+ def resize(src):
38
+ height, width, _ = src.shape
39
+ dst = src
40
+ if height != width:
41
+ if a.pad:
42
+ size = max(height, width)
43
+ # pad to correct ratio
44
+ oh = (size - height) // 2
45
+ ow = (size - width) // 2
46
+ dst = im.pad(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)
47
+ else:
48
+ # crop to correct ratio
49
+ size = min(height, width)
50
+ oh = (height - size) // 2
51
+ ow = (width - size) // 2
52
+ dst = im.crop(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)
53
+
54
+ assert(dst.shape[0] == dst.shape[1])
55
+
56
+ size, _, _ = dst.shape
57
+ if size > a.size:
58
+ dst = im.downscale(images=dst, size=[a.size, a.size])
59
+ elif size < a.size:
60
+ dst = im.upscale(images=dst, size=[a.size, a.size])
61
+ return dst
62
+
63
+
64
+ def blank(src):
65
+ height, width, _ = src.shape
66
+ if height != width:
67
+ raise Exception("non-square image")
68
+
69
+ image_size = width
70
+ size = int(image_size * 0.3)
71
+ offset = int(image_size / 2 - size / 2)
72
+
73
+ dst = src
74
+ dst[offset:offset + size,offset:offset + size,:] = np.ones([size, size, 3])
75
+ return dst
76
+
77
+
78
+ def combine(src, src_path):
79
+ if a.b_dir is None:
80
+ raise Exception("missing b_dir")
81
+
82
+ # find corresponding file in b_dir, could have a different extension
83
+ basename, _ = os.path.splitext(os.path.basename(src_path))
84
+ for ext in [".png", ".jpg"]:
85
+ sibling_path = os.path.join(a.b_dir, basename + ext)
86
+ if os.path.exists(sibling_path):
87
+ sibling = im.load(sibling_path)
88
+ break
89
+ else:
90
+ raise Exception("could not find sibling image for " + src_path)
91
+
92
+ # make sure that dimensions are correct
93
+ height, width, _ = src.shape
94
+ if height != sibling.shape[0] or width != sibling.shape[1]:
95
+ raise Exception("differing sizes")
96
+
97
+ # convert both images to RGB if necessary
98
+ if src.shape[2] == 1:
99
+ src = im.grayscale_to_rgb(images=src)
100
+
101
+ if sibling.shape[2] == 1:
102
+ sibling = im.grayscale_to_rgb(images=sibling)
103
+
104
+ # remove alpha channel
105
+ if src.shape[2] == 4:
106
+ src = src[:,:,:3]
107
+
108
+ if sibling.shape[2] == 4:
109
+ sibling = sibling[:,:,:3]
110
+
111
+ return np.concatenate([src, sibling], axis=1)
112
+
113
+
114
+ def grayscale(src):
115
+ return im.grayscale_to_rgb(images=im.rgb_to_grayscale(images=src))
116
+
117
+
118
+ net = None
119
+ def run_caffe(src):
120
+ # lazy load caffe and create net
121
+ global net
122
+ if net is None:
123
+ # don't require caffe unless we are doing edge detection
124
+ os.environ["GLOG_minloglevel"] = "2" # disable logging from caffe
125
+ import caffe
126
+ # using this requires using the docker image or assembling a bunch of dependencies
127
+ # and then changing these hardcoded paths
128
+ net = caffe.Net("/opt/caffe/examples/hed/deploy.prototxt", "/opt/caffe/hed_pretrained_bsds.caffemodel", caffe.TEST)
129
+
130
+ net.blobs["data"].reshape(1, *src.shape)
131
+ net.blobs["data"].data[...] = src
132
+ net.forward()
133
+ return net.blobs["sigmoid-fuse"].data[0][0,:,:]
134
+
135
+
136
+ def edges(src):
137
+ # based on https://github.com/phillipi/pix2pix/blob/master/scripts/edges/batch_hed.py
138
+ # and https://github.com/phillipi/pix2pix/blob/master/scripts/edges/PostprocessHED.m
139
+ import scipy.io
140
+ src = src * 255
141
+ border = 128 # put a padding around images since edge detection seems to detect edge of image
142
+ src = src[:,:,:3] # remove alpha channel if present
143
+ src = np.pad(src, ((border, border), (border, border), (0,0)), "reflect")
144
+ src = src[:,:,::-1]
145
+ src -= np.array((104.00698793,116.66876762,122.67891434))
146
+ src = src.transpose((2, 0, 1))
147
+
148
+ # [height, width, channels] => [batch, channel, height, width]
149
+ fuse = edge_pool.apply(run_caffe, [src])
150
+ fuse = fuse[border:-border, border:-border]
151
+
152
+ with tempfile.NamedTemporaryFile(suffix=".png") as png_file, tempfile.NamedTemporaryFile(suffix=".mat") as mat_file:
153
+ scipy.io.savemat(mat_file.name, {"input": fuse})
154
+
155
+ octave_code = r"""
156
+ E = 1-load(input_path).input;
157
+ E = imresize(E, [image_width,image_width]);
158
+ E = 1 - E;
159
+ E = single(E);
160
+ [Ox, Oy] = gradient(convTri(E, 4), 1);
161
+ [Oxx, ~] = gradient(Ox, 1);
162
+ [Oxy, Oyy] = gradient(Oy, 1);
163
+ O = mod(atan(Oyy .* sign(-Oxy) ./ (Oxx + 1e-5)), pi);
164
+ E = edgesNmsMex(E, O, 1, 5, 1.01, 1);
165
+ E = double(E >= max(eps, threshold));
166
+ E = bwmorph(E, 'thin', inf);
167
+ E = bwareaopen(E, small_edge);
168
+ E = 1 - E;
169
+ E = uint8(E * 255);
170
+ imwrite(E, output_path);
171
+ """
172
+
173
+ config = dict(
174
+ input_path="'%s'" % mat_file.name,
175
+ output_path="'%s'" % png_file.name,
176
+ image_width=256,
177
+ threshold=25.0/255.0,
178
+ small_edge=5,
179
+ )
180
+
181
+ args = ["octave"]
182
+ for k, v in config.items():
183
+ args.extend(["--eval", "%s=%s;" % (k, v)])
184
+
185
+ args.extend(["--eval", octave_code])
186
+ try:
187
+ subprocess.check_output(args, stderr=subprocess.STDOUT)
188
+ except subprocess.CalledProcessError as e:
189
+ print("octave failed")
190
+ print("returncode:", e.returncode)
191
+ print("output:", e.output)
192
+ raise
193
+ return im.load(png_file.name)
194
+
195
+
196
+ def process(src_path, dst_path):
197
+ src = im.load(src_path)
198
+
199
+ if a.operation == "grayscale":
200
+ dst = grayscale(src)
201
+ elif a.operation == "resize":
202
+ dst = resize(src)
203
+ elif a.operation == "blank":
204
+ dst = blank(src)
205
+ elif a.operation == "combine":
206
+ dst = combine(src, src_path)
207
+ elif a.operation == "edges":
208
+ dst = edges(src)
209
+ else:
210
+ raise Exception("invalid operation")
211
+
212
+ im.save(dst, dst_path)
213
+
214
+
215
+ complete_lock = threading.Lock()
216
+ start = None
217
+ num_complete = 0
218
+ total = 0
219
+
220
+ def complete():
221
+ global num_complete, rate, last_complete
222
+
223
+ with complete_lock:
224
+ num_complete += 1
225
+ now = time.time()
226
+ elapsed = now - start
227
+ rate = num_complete / elapsed
228
+ if rate > 0:
229
+ remaining = (total - num_complete) / rate
230
+ else:
231
+ remaining = 0
232
+
233
+ print("%d/%d complete %0.2f images/sec %dm%ds elapsed %dm%ds remaining" % (num_complete, total, rate, elapsed // 60, elapsed % 60, remaining // 60, remaining % 60))
234
+
235
+ last_complete = now
236
+
237
+
238
+ def main():
239
+ if not os.path.exists(a.output_dir):
240
+ os.makedirs(a.output_dir)
241
+
242
+ src_paths = []
243
+ dst_paths = []
244
+
245
+ skipped = 0
246
+ for src_path in im.find(a.input_dir):
247
+ name, _ = os.path.splitext(os.path.basename(src_path))
248
+ dst_path = os.path.join(a.output_dir, name + ".png")
249
+ if os.path.exists(dst_path):
250
+ skipped += 1
251
+ else:
252
+ src_paths.append(src_path)
253
+ dst_paths.append(dst_path)
254
+
255
+ print("skipping %d files that already exist" % skipped)
256
+
257
+ global total
258
+ total = len(src_paths)
259
+
260
+ print("processing %d files" % total)
261
+
262
+ global start
263
+ start = time.time()
264
+
265
+ if a.operation == "edges":
266
+ # use a multiprocessing pool for this operation so it can use multiple CPUs
267
+ # create the pool before we launch processing threads
268
+ global edge_pool
269
+ edge_pool = multiprocessing.Pool(a.workers)
270
+
271
+ if a.workers == 1:
272
+ with tf.Session() as sess:
273
+ for src_path, dst_path in zip(src_paths, dst_paths):
274
+ process(src_path, dst_path)
275
+ complete()
276
+ else:
277
+ queue = tf.train.input_producer(zip(src_paths, dst_paths), shuffle=False, num_epochs=1)
278
+ dequeue_op = queue.dequeue()
279
+
280
+ def worker(coord):
281
+ with sess.as_default():
282
+ while not coord.should_stop():
283
+ try:
284
+ src_path, dst_path = sess.run(dequeue_op)
285
+ except tf.errors.OutOfRangeError:
286
+ coord.request_stop()
287
+ break
288
+
289
+ process(src_path, dst_path)
290
+ complete()
291
+
292
+ # init epoch counter for the queue
293
+ local_init_op = tf.local_variables_initializer()
294
+ with tf.Session() as sess:
295
+ sess.run(local_init_op)
296
+
297
+ coord = tf.train.Coordinator()
298
+ threads = tf.train.start_queue_runners(coord=coord)
299
+ for i in range(a.workers):
300
+ t = threading.Thread(target=worker, args=(coord,))
301
+ t.start()
302
+ threads.append(t)
303
+
304
+ try:
305
+ coord.join(threads)
306
+ except KeyboardInterrupt:
307
+ coord.request_stop()
308
+ coord.join(threads)
309
+
310
+ main()
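For reference, the resize operation above first squares a non-square image, either padding out to max(height, width) when --pad is given or centre-cropping to min(height, width) otherwise, and then scales the result to --size. The toy arithmetic below just traces those offsets with made-up dimensions; it is not repository code.

    # Toy check of the pad/crop offsets used by resize() above for a 300x500 image.
    height, width = 300, 500

    size_pad = max(height, width)                                        # --pad: grow to a square
    oh_pad, ow_pad = (size_pad - height) // 2, (size_pad - width) // 2   # 100, 0

    size_crop = min(height, width)                                       # default: centre-crop
    oh_crop, ow_crop = (height - size_crop) // 2, (width - size_crop) // 2   # 0, 100

A typical combine invocation in this setup would look something like python tools/process.py --input_dir A --b_dir B --output_dir AB --operation combine (the directory names here are placeholders), which produces the side-by-side pairs that load_examples in segment2tissue_safron_media.py splits back apart.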