srijaydeshpande committed
Commit 4473449 • Parent(s): 9c3c6ce

Upload 4 files

Files changed:
- Assistance/SingleImageCropper.py (+79, -0)
- Assistance/join_images.py (+61, -0)
- segment2tissue_safron_media.py (+683, -0)
- tools/process.py (+310, -0)
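Taken together, the four scripts form a patch-based pipeline: Assistance/SingleImageCropper.py tiles a large image into overlapping 296x296 patches, tools/process.py prepares data (including side-by-side A|B pairs), segment2tissue_safron_media.py trains or runs the generator, and Assistance/join_images.py stitches generated patches back into a full image. The sketch below is a hypothetical driver showing how the scripts might be chained from Python; every path and size in it is a placeholder, not a value taken from this commit.

# Hypothetical end-to-end driver for the uploaded scripts (all paths are placeholders).
import subprocess

def run(args):
    # Each script is a stand-alone CLI, so we simply shell out to it.
    subprocess.run(["python"] + args, check=True)

# 1. Tile a large test image into 296x296 patches (stride 236).
run(["Assistance/SingleImageCropper.py",
     "--image_path", "data/test/slide.png",
     "--output_dir", "work/patches"])

# 2. Run the trained generator over the patches in test mode.  Depending on how
#    the model was trained, the patches may first need to be combined into
#    paired A|B images with tools/process.py --operation combine.
run(["segment2tissue_safron_media.py",
     "--mode", "test",
     "--input_dir", "work/patches",
     "--output_dir", "work/results",
     "--checkpoint", "checkpoints/safron"])

# 3. Average the overlapping "-outputs" patches back into one image.
run(["Assistance/join_images.py",
     "--patches_dir", "work/results/images",
     "--output_file", "work/stitched.png",
     "--im_height", "2960",
     "--im_width", "2960"])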
Assistance/SingleImageCropper.py
ADDED
@@ -0,0 +1,79 @@
import glob
import os
from PIL import Image
import numpy as np
import time
import PIL
import matplotlib.pyplot as plt
import argparse

patchsize = 296
stride = 236
pad = 0

PIL.Image.MAX_IMAGE_PIXELS = 933120000

def CropImage(image_path, output_dir, pad):
    image_name = os.path.split(image_path)[1].split('.')[0]
    im = Image.open(image_path)

    size = im.size
    # new_size = (size[0]+40, size[1]+40)
    new_size = size

    if (pad):
        new_im = Image.new("RGB", new_size)  # luckily, this is already black!
        new_im.paste(im, ((new_size[0] - size[0]) // 2,
                          (new_size[1] - size[1]) // 2))
        # plt.imshow(new_im)
        # plt.show()
    else:
        new_im = im

    width, height = new_im.size

    x = 0
    y = 0
    right = 0
    bottom = 0

    while (bottom < height):
        while (right < width):
            left = x
            top = y
            right = left + patchsize
            bottom = top + patchsize
            if (right > width):
                offset = right - width
                right -= offset
                left -= offset
            if (bottom > height):
                offset = bottom - height
                bottom -= offset
                top -= offset
            im_crop = new_im.crop((left, top, right, bottom))
            im_crop_name = image_name + "_" + str(left) + "_" + str(top) + ".png"
            output_path = os.path.join(output_dir, im_crop_name)
            im_crop.save(output_path)
            x += stride
        x = 0
        right = 0
        y += stride

start_time = time.time()

parser = argparse.ArgumentParser()
parser.add_argument("--image_path", help="path to image to crop",
                    default=r"F:\Datasets\DigestPath\safron\Benign\test\3\ablation_study_neg_46\neg_46.png")
parser.add_argument("--output_dir", help="path to output folder",
                    default=r"F:\Datasets\DigestPath\safron\Benign\test\3\ablation_study_neg_46\cropped_safron_patchadv")
parser.add_argument("--pad", type=int, default=0, help="pad the image borders")

args = parser.parse_args()

if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)

CropImage(args.image_path, args.output_dir, args.pad)

print("--- %s seconds ---" % (time.time() - start_time))
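One detail worth noting in CropImage: when a tile would run past the right or bottom edge of the image, it is shifted back inside the image rather than padded, so the last tile in each row or column overlaps its neighbour by more than the usual patchsize - stride = 60 pixels. A small stand-alone sketch of the tile origins this produces for a hypothetical 600-pixel-wide image:

# Tile origins produced by the shift-back logic for a hypothetical 600-px-wide image
# (patchsize 296, stride 236): the last tile is clamped to width - patchsize.
patchsize, stride, width = 296, 236, 600

xs, x = [], 0
while True:
    left = min(x, width - patchsize)   # shift the last tile back inside the image
    xs.append(left)
    if x + patchsize >= width:
        break
    x += stride

print(xs)   # [0, 236, 304]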
Assistance/join_images.py
ADDED
@@ -0,0 +1,61 @@
import os
import glob
import numpy as np
from PIL import Image
import PIL
import matplotlib.pyplot as plt
import matplotlib
import time
import argparse

PIL.Image.MAX_IMAGE_PIXELS = 933120000

def join_images(input_dir, output_file, height, width):

    start_time = time.time()

    paths = glob.glob(os.path.join(input_dir, "*.png"))

    patch = 256

    image = np.zeros((height, width, 3))
    count_masks = np.zeros((height, width, 3))
    k = 0
    for path in paths:
        if ('outputs' in path):
            imname = os.path.split(path)[1].replace("-outputs", "").split(".")[0]
            imname = imname.split("_")
            y, x = int(imname[-2]), int(imname[-1])
            img = Image.open(path)
            img = np.asarray(img)
            # print("X => ", x, " Y => ", y)
            image[x:x + patch, y:y + patch, :] += img
            count_masks[x:x + patch, y:y + patch, :] += 1.0
            k += 1

    count_masks = count_masks.clip(min=1)

    image = image / count_masks

    image = image / 255.0

    # im = Image.fromarray(image)
    # im.save(output_file)

    matplotlib.image.imsave(output_file, image)

    print("--- %s seconds ---" % (time.time() - start_time))

    print("Done")

parser = argparse.ArgumentParser()
parser.add_argument("--patches_dir", help="path to generated patches to join",
                    default="./tmp/results/images")
parser.add_argument("--output_file", help="path to output file",
                    default="F:/Datasets/DigestPath/safron/test/single/outputs/sample.jpeg")
parser.add_argument("--im_height", type=int, help="image height")
parser.add_argument("--im_width", type=int, help="image width")

args = parser.parse_args()

join_images(args.patches_dir, args.output_file, args.im_height, args.im_width)
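The sizes used across the scripts are consistent, assuming the defaults: SingleImageCropper.py saves 296x296 tiles at stride 236, the generator's initial 41x41 VALID convolution in segment2tissue_safron_media.py reduces each tile to 256x256, and join_images.py pastes those 256x256 outputs back at the saved tile origins, averaging the 20-pixel overlaps via count_masks. A quick arithmetic check of that relationship:

# Sanity check of the default patch geometry shared by the three scripts.
ksize, stride, kernel = 296, 236, 41

gen_out = ksize - kernel + 1      # VALID conv: 296 - 41 + 1 = 256 (generator output per tile)
overlap = gen_out - stride        # 256 - 236 = 20 px shared between neighbouring outputs

assert gen_out == 256 and overlap == 20
print("output per tile:", gen_out, "px; overlap between tiles:", overlap, "px")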
segment2tissue_safron_media.py
ADDED
@@ -0,0 +1,683 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
import numpy as np
import argparse
import os
import json
import glob
import random
import collections
import math
import time
# visualize image
import matplotlib.pyplot as plt

# disable v2 behaviour
tf.disable_v2_behavior()

# enable eager execution
# tf.compat.v1.enable_eager_execution()

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", help="path to folder containing images")
parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
parser.add_argument("--output_dir", required=True, help="where to put output files")
parser.add_argument("--seed", type=int)
parser.add_argument("--checkpoint", default=None,
                    help="directory with checkpoint to resume training from or use for testing")

parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
parser.add_argument("--max_epochs", type=int, help="number of training epochs")
parser.add_argument("--summary_freq", type=int, default=10000, help="update summaries every summary_freq steps")
parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
parser.add_argument("--display_freq", type=int, default=1000, help="write current training images every display_freq steps")
parser.add_argument("--save_freq", type=int, default=1000, help="save model every save_freq steps, 0 to disable")
parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
parser.add_argument("--which_direction", type=str, default="BtoA", choices=["AtoB", "BtoA"])
parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
parser.add_argument("--scale_size", type=int, default=728, help="scale images to this size before cropping to 256x256")
parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
# parser.set_defaults(flip=True)
parser.add_argument("--lr", type=float, default=0.0001, help="initial learning rate for adam")
parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")

# export options
parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])
a = parser.parse_args()

EPS = 1e-12

Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
Model = collections.namedtuple("Model",
                               "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")

def preprocess(image):
    with tf.name_scope("preprocess"):
        # [0, 1] => [-1, 1]
        return image * 2 - 1


def deprocess(image):
    with tf.name_scope("deprocess"):
        # [-1, 1] => [0, 1]
        return (image + 1) / 2


def discrim_conv(batch_input, out_channels, stride):
    padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
    return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid",
                            kernel_initializer=tf.random_normal_initializer(0, 0.02))


def gen_conv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    if a.separable_conv:
        return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
                                          depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
                                kernel_initializer=initializer)


def gen_deconv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    if a.separable_conv:
        _b, h, w, _c = batch_input.shape
        resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2],
                                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same",
                                          depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same",
                                          kernel_initializer=initializer)


def lrelu(x, a):
    with tf.name_scope("lrelu"):
        # adding these together creates the leak part and linear part
        # then cancels them out by subtracting/adding an absolute value term
        # leak: a*x/2 - a*abs(x)/2
        # linear: x/2 + abs(x)/2

        # this block looks like it has 2 inputs on the graph unless we do this
        x = tf.identity(x)
        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)


def batchnorm(inputs):
    return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True,
                                         gamma_initializer=tf.random_normal_initializer(1.0, 0.02))


def check_image(image):
    assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels")
    with tf.control_dependencies([assertion]):
        image = tf.identity(image)

    if image.get_shape().ndims not in (3, 4):
        raise ValueError("image must be either 3 or 4 dimensions")

    # make the last dimension 3 so that you can unstack the colors
    shape = list(image.get_shape())
    shape[-1] = 3
    image.set_shape(shape)
    return image


def load_examples():

    if a.input_dir is None or not os.path.exists(a.input_dir):
        raise Exception("input_dir does not exist")

    input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
    decode = tf.image.decode_jpeg
    if len(input_paths) == 0:
        input_paths = glob.glob(os.path.join(a.input_dir, "*.png"))
        decode = tf.image.decode_png

    if len(input_paths) == 0:
        raise Exception("input_dir contains no image files")

    def get_name(path):
        name, _ = os.path.splitext(os.path.basename(path))
        return name

    # if the image names are numbers, sort by the value rather than asciibetically
    # having sorted inputs means that the outputs are sorted in test mode
    if all(get_name(path).isdigit() for path in input_paths):
        input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
    else:
        input_paths = sorted(input_paths)

    with tf.name_scope("load_images"):
        path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
        reader = tf.WholeFileReader()
        paths, contents = reader.read(path_queue)
        raw_input = decode(contents)
        raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)

        assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels")

        with tf.control_dependencies([assertion]):
            raw_input = tf.identity(raw_input)

        raw_input.set_shape([None, None, 3])

        # break apart image pair and move to range [-1, 1]
        width = tf.shape(raw_input)[1]  # [height, width, channels]

        a_images = preprocess(raw_input[:, :width // 2, :])
        b_images = preprocess(raw_input[:, width // 2:, :])

    if a.which_direction == "AtoB":
        inputs, targets = [a_images, b_images]
    elif a.which_direction == "BtoA":
        inputs, targets = [b_images, a_images]
    else:
        raise Exception("invalid direction")

    # synchronize seed for image operations so that we do the same operations to both
    # input and output images
    def transform(image):
        r = image
        r.set_shape([a.scale_size, a.scale_size, 3])
        # r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)
        return r

    with tf.name_scope("input_images"):
        input_images = transform(inputs)

    with tf.name_scope("target_images"):
        target_images = transform(targets)

    paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images],
                                                              batch_size=a.batch_size)

    steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))

    return Examples(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )


def create_generator(generator_inputs, generator_outputs_channels):
    layers = []

    # Add filter to detect edges
    filter_shape = [41, 41, 3, a.ngf]
    with tf.variable_scope("encoder_1"):
        filter = tf.get_variable('edge_detector', filter_shape, initializer=tf.random_normal_initializer(stddev=0.02))
        strides = [1, 1, 1, 1]
        output = tf.nn.conv2d(generator_inputs, filter, strides=strides, padding='VALID')
        output = lrelu(output, 0.2)
        layers.append(output)

    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
    with tf.variable_scope("encoder_1"):
        output = gen_conv(output, a.ngf)
        layers.append(output)

    layer_specs = [
        a.ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
        a.ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
        a.ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
        a.ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
        a.ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
        a.ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        a.ngf * 8,  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
    ]

    for out_channels in layer_specs:
        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            convolved = gen_conv(rectified, out_channels)
            output = batchnorm(convolved)
            layers.append(output)

    layer_specs = [
        (a.ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
        (a.ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
        (a.ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
        (a.ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
        (a.ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
        (a.ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
        (a.ngf, 0.0),      # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
    ]

    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
            if decoder_layer == 0:
                # first decoder layer doesn't have skip connections
                # since it is directly connected to the skip_layer
                input = layers[-1]
            else:
                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)
                # input = layers[-1]

            rectified = tf.nn.relu(input)
            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
            output = gen_deconv(rectified, out_channels)
            output = batchnorm(output)

            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)

            layers.append(output)

    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
    with tf.variable_scope("decoder_1"):
        input = tf.concat([layers[-1], layers[1]], axis=3)
        rectified = tf.nn.relu(input)
        output = gen_deconv(rectified, generator_outputs_channels)
        output = tf.tanh(output)
        layers.append(output)

    return layers[-1]


ksize_rows = 296
ksize_cols = 296
strides_rows = 236
strides_cols = 236
num_patches = int(a.scale_size / strides_rows)
ksizes = [1, ksize_rows, ksize_cols, 1]
ksizes_output = [1, 256, 256, 1]
strides = [1, strides_rows, strides_cols, 1]
rates = [1, 1, 1, 1]
padding = 'VALID'


def extract_patches(x, ksizes, strides, rates):
    return tf.extract_image_patches(
        x,
        ksizes, strides, rates,
        padding="VALID"
    )


def extract_patches_inverse(x, y):
    _x = tf.zeros_like(x)
    _y = extract_patches(_x, ksizes_output, strides, rates)
    grad = tf.gradients(_y, _x)[0]
    # Divide by grad, to "average" together the overlapping patches
    # otherwise they would simply sum up
    return tf.gradients(_y, _x, grad_ys=y)[0] / grad


def create_model(inputs, targets):

    out_channels = int(targets.get_shape()[-1])

    if (a.mode == "train" or a.scale_size != 296):  # scale_size = 296 while testing

        def create_discriminator(discrim_inputs, discrim_targets):
            n_layers = 5
            layers = []

            # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
            input = tf.concat([discrim_inputs, discrim_targets], axis=3)

            # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
            with tf.variable_scope("layer_1"):
                convolved = discrim_conv(input, a.ndf, stride=2)
                rectified = lrelu(convolved, 0.2)
                layers.append(rectified)

            # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
            # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
            # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
            for i in range(n_layers):
                with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                    out_channels = a.ndf * min(2 ** (i + 1), 8)
                    stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                    convolved = discrim_conv(layers[-1], out_channels, stride=stride)
                    normalized = batchnorm(convolved)
                    rectified = lrelu(normalized, 0.2)
                    layers.append(rectified)

            # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                convolved = discrim_conv(rectified, out_channels=1, stride=1)
                output = tf.sigmoid(convolved)
                layers.append(output)
            print(layers)
            return layers[-1]

        # Pad the inputs so that edge patches also get surrounding context as input to the generator
        # (context around the target patch)
        inputs_bounded = tf.image.pad_to_bounding_box(inputs, 20, 20, a.scale_size + 40, a.scale_size + 40)

        # Extract patches
        with tf.variable_scope("extract_patches"):
            inputs_patches = extract_patches(inputs_bounded, ksizes, strides, rates)

        with tf.name_scope("patches_generator"):
            with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
                output_patches = []
                patch_size_len = 256 * 256 * 3
                # Get the output for each patch via the same (shared) generator
                for i in range(0, num_patches):
                    for j in range(0, num_patches):
                        patch = inputs_patches[0, i, j]
                        # patch_size_len = int(patch.get_shape()[0])  # defined above
                        # reshape
                        patch = tf.reshape(patch, [ksize_rows, ksize_cols, 3])
                        patch = tf.expand_dims(patch, 0)
                        patch_output = create_generator(patch, out_channels)
                        output_patches.append(tf.reshape(patch_output, [patch_size_len]))
                output_patches = tf.stack(output_patches)
                output_patches = tf.reshape(output_patches, [1, num_patches, num_patches, patch_size_len])

        # Stitch all patches back
        k = tf.constant(0.1, shape=[1, a.scale_size, a.scale_size, 3])
        outputs = extract_patches_inverse(k, output_patches)

        # create two copies of discriminator, one for real pairs and one for fake pairs
        # they share the same underlying variables
        with tf.name_scope("real_discriminator"):
            with tf.variable_scope("discriminator"):
                # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
                predict_real = create_discriminator(inputs, targets)

        with tf.name_scope("fake_discriminator"):
            with tf.variable_scope("discriminator", reuse=True):
                # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
                predict_fake = create_discriminator(inputs, outputs)

        with tf.name_scope("discriminator_loss"):
            # minimizing -tf.log will try to get inputs to 1
            # predict_real => 1
            # predict_fake => 0
            discrim_loss = tf.reduce_mean(-(tf.log(tf.clip_by_value((predict_real + EPS), 1e-12, 1.0)) + tf.log(tf.clip_by_value((1 - predict_fake + EPS), 1e-12, 1.0))))

        with tf.name_scope("generator_loss"):
            # predict_fake => 1
            # abs(targets - outputs) => 0
            gen_loss_GAN = tf.reduce_mean(-tf.log(tf.clip_by_value((predict_fake + EPS), 1e-12, 1.0)))
            gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
            gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight

        with tf.name_scope("discriminator_train"):
            discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
            discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
            discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
            discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

        with tf.name_scope("generator_train"):
            with tf.control_dependencies([discrim_train]):
                gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
                gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
                gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
                gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

        ema = tf.train.ExponentialMovingAverage(decay=0.99)
        update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
        global_step = tf.train.get_or_create_global_step()
        incr_global_step = tf.assign(global_step, global_step + 1)
        discrim_loss = ema.average(discrim_loss)
        gen_loss_GAN = ema.average(gen_loss_GAN)
        gen_loss_L1 = ema.average(gen_loss_L1)
        update_ops = [update_losses, incr_global_step, gen_train]
        train = tf.group(update_ops)

    elif (a.mode == "test"):
        predict_real = None
        predict_fake = None
        discrim_loss = None
        discrim_grads_and_vars = None
        gen_loss_GAN = None
        gen_loss_L1 = None
        gen_grads_and_vars = None
        train = None
        with tf.name_scope("patches_generator"):
            with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
                outputs = create_generator(inputs, out_channels)
    else:
        print("Give correct mode")
        exit(0)

    return Model(
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=discrim_loss,
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=gen_loss_GAN,
        gen_loss_L1=gen_loss_L1,
        gen_grads_and_vars=gen_grads_and_vars,
        outputs=outputs,
        train=train
    )


def save_images(fetches, step=None):
    image_dir = os.path.join(a.output_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)
    filesets = []
    for i, in_path in enumerate(fetches["paths"]):
        name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
        fileset = {"name": name, "step": step}
        for kind in ["inputs", "outputs", "targets"]:
            filename = name + "-" + kind + ".png"
            if step is not None:
                filename = "%08d-%s" % (step, filename)
            fileset[kind] = filename
            out_path = os.path.join(image_dir, filename)
            contents = fetches[kind][i]
            with open(out_path, "wb") as f:
                f.write(contents)
        filesets.append(fileset)
    return filesets


def append_index(filesets, step=False):
    index_path = os.path.join(a.output_dir, "index.html")
    if os.path.exists(index_path):
        index = open(index_path, "a")
    else:
        index = open(index_path, "w")
        index.write("<html><body><table><tr>")
        if step:
            index.write("<th>step</th>")
        index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>")

    for fileset in filesets:
        index.write("<tr>")

        if step:
            index.write("<td>%d</td>" % fileset["step"])
        index.write("<td>%s</td>" % fileset["name"])

        for kind in ["inputs", "outputs", "targets"]:
            index.write("<td><img src='images/%s'></td>" % fileset[kind])

        index.write("</tr>")
    return index_path


def main():
    if a.seed is None:
        a.seed = random.randint(0, 2 ** 31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)

    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    if a.mode == "test" or a.mode == "export":
        if a.checkpoint is None:
            raise Exception("checkpoint required for test mode")

        # load some options from the checkpoint
        options = {"which_direction", "ngf", "ndf"}
        with open(os.path.join(a.checkpoint, "options.json")) as f:
            for key, val in json.loads(f.read()).items():
                if key in options:
                    print("loaded", key, "=", val)
                    setattr(a, key, val)
        # disable these features in test mode
        # a.scale_size = CROP_SIZE
        a.flip = False

    for k, v in a._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(a.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(a), sort_keys=True, indent=4))

    examples = load_examples()

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(examples.inputs, examples.targets)

    inputs = deprocess(examples.inputs)
    targets = deprocess(examples.targets)

    outputs = deprocess(model.outputs)

    def convert(image):
        return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_inputs"):
        converted_inputs = convert(inputs)

    with tf.name_scope("convert_targets"):
        converted_targets = convert(targets)

    with tf.name_scope("convert_outputs"):
        converted_outputs = convert(outputs)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "paths": examples.paths,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
        }

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=1)

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)

    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            print(a.checkpoint)
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            saver.restore(sess, checkpoint)

        max_steps = 2 ** 32

        if a.max_epochs is not None:
            max_steps = examples.steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        if a.mode == "test":
            # testing
            # at most, process the test data once
            start = time.time()
            max_steps = min(examples.steps_per_epoch, max_steps)
            for step in range(max_steps):
                results = sess.run(display_fetches)
                filesets = save_images(results)
                for i, f in enumerate(filesets):
                    print("evaluated image", f["name"])
                index_path = append_index(filesets)
            print("wrote index at", index_path)
            print("rate", (time.time() - start) / max_steps)
        else:
            # training
            start = time.time()

            for step in range(max_steps):
                def should(freq):
                    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

                options = None
                run_metadata = None
                if should(a.trace_freq):
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": model.train,
                    "global_step": sv.global_step,
                }

                if should(a.progress_freq):
                    fetches["discrim_loss"] = model.discrim_loss
                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
                    fetches["gen_loss_L1"] = model.gen_loss_L1

                if should(a.summary_freq):
                    fetches["summary"] = sv.summary_op

                if should(a.display_freq):
                    fetches["display"] = display_fetches

                results = sess.run(fetches, options=options, run_metadata=run_metadata)

                if should(a.summary_freq):
                    print("recording summary")
                    sv.summary_writer.add_summary(results["summary"], results["global_step"])

                if should(a.display_freq):
                    print("saving display images")
                    filesets = save_images(results["display"], step=results["global_step"])
                    append_index(filesets, step=True)

                if should(a.trace_freq):
                    print("recording trace")
                    sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])

                if should(a.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch)
                    train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1
                    rate = (step + 1) * a.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * a.batch_size / rate
                    print("progress epoch %d step %d image/sec %0.1f remaining %dm" % (
                        train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])

                if should(a.save_freq):
                    print("saving model")
                    saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step)

                if sv.should_stop():
                    break


main()
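For reference, the model script is driven entirely by its argparse flags. A hypothetical training invocation might look like the following; the paths and epoch count are placeholders, and the input directory must contain side-by-side A|B image pairs whose total width is twice --scale_size:

# Hypothetical training run; adjust paths and hyper-parameters to your data.
import subprocess

subprocess.run(["python", "segment2tissue_safron_media.py",
                "--mode", "train",
                "--input_dir", "data/train_pairs",      # paired A|B images (width = 2 * scale_size)
                "--output_dir", "checkpoints/safron",
                "--which_direction", "BtoA",            # BtoA: right half of each pair is the input, left half the target
                "--scale_size", "728",
                "--max_epochs", "200"],
               check=True)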
tools/process.py
ADDED
@@ -0,0 +1,310 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import argparse
import os
import tempfile
import subprocess
# import tensorflow as tf
import numpy as np
import tfimage as im
import threading
import time
import multiprocessing


edge_pool = None


parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", required=True, help="path to folder containing images")
parser.add_argument("--output_dir", required=True, help="output path")
parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine", "edges"])
parser.add_argument("--workers", type=int, default=1, help="number of workers")
# resize
parser.add_argument("--pad", action="store_true", help="pad instead of crop for resize operation")
parser.add_argument("--size", type=int, default=256, help="size to use for resize operation")
# combine
parser.add_argument("--b_dir", type=str, help="path to folder containing B images for combine operation")
a = parser.parse_args()


def resize(src):
    height, width, _ = src.shape
    dst = src
    if height != width:
        if a.pad:
            size = max(height, width)
            # pad to correct ratio
            oh = (size - height) // 2
            ow = (size - width) // 2
            dst = im.pad(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)
        else:
            # crop to correct ratio
            size = min(height, width)
            oh = (height - size) // 2
            ow = (width - size) // 2
            dst = im.crop(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)

    assert (dst.shape[0] == dst.shape[1])

    size, _, _ = dst.shape
    if size > a.size:
        dst = im.downscale(images=dst, size=[a.size, a.size])
    elif size < a.size:
        dst = im.upscale(images=dst, size=[a.size, a.size])
    return dst


def blank(src):
    height, width, _ = src.shape
    if height != width:
        raise Exception("non-square image")

    image_size = width
    size = int(image_size * 0.3)
    offset = int(image_size / 2 - size / 2)

    dst = src
    dst[offset:offset + size, offset:offset + size, :] = np.ones([size, size, 3])
    return dst


def combine(src, src_path):
    if a.b_dir is None:
        raise Exception("missing b_dir")

    # find corresponding file in b_dir, could have a different extension
    basename, _ = os.path.splitext(os.path.basename(src_path))
    for ext in [".png", ".jpg"]:
        sibling_path = os.path.join(a.b_dir, basename + ext)
        if os.path.exists(sibling_path):
            sibling = im.load(sibling_path)
            break
    else:
        raise Exception("could not find sibling image for " + src_path)

    # make sure that dimensions are correct
    height, width, _ = src.shape
    if height != sibling.shape[0] or width != sibling.shape[1]:
        raise Exception("differing sizes")

    # convert both images to RGB if necessary
    if src.shape[2] == 1:
        src = im.grayscale_to_rgb(images=src)

    if sibling.shape[2] == 1:
        sibling = im.grayscale_to_rgb(images=sibling)

    # remove alpha channel
    if src.shape[2] == 4:
        src = src[:, :, :3]

    if sibling.shape[2] == 4:
        sibling = sibling[:, :, :3]

    return np.concatenate([src, sibling], axis=1)


def grayscale(src):
    return im.grayscale_to_rgb(images=im.rgb_to_grayscale(images=src))


net = None
def run_caffe(src):
    # lazy load caffe and create net
    global net
    if net is None:
        # don't require caffe unless we are doing edge detection
        os.environ["GLOG_minloglevel"] = "2"  # disable logging from caffe
        import caffe
        # using this requires using the docker image or assembling a bunch of dependencies
        # and then changing these hardcoded paths
        net = caffe.Net("/opt/caffe/examples/hed/deploy.prototxt", "/opt/caffe/hed_pretrained_bsds.caffemodel", caffe.TEST)

    net.blobs["data"].reshape(1, *src.shape)
    net.blobs["data"].data[...] = src
    net.forward()
    return net.blobs["sigmoid-fuse"].data[0][0, :, :]


def edges(src):
    # based on https://github.com/phillipi/pix2pix/blob/master/scripts/edges/batch_hed.py
    # and https://github.com/phillipi/pix2pix/blob/master/scripts/edges/PostprocessHED.m
    import scipy.io
    src = src * 255
    border = 128  # put a padding around images since edge detection seems to detect edge of image
    src = src[:, :, :3]  # remove alpha channel if present
    src = np.pad(src, ((border, border), (border, border), (0, 0)), "reflect")
    src = src[:, :, ::-1]
    src -= np.array((104.00698793, 116.66876762, 122.67891434))
    src = src.transpose((2, 0, 1))

    # [height, width, channels] => [batch, channel, height, width]
    fuse = edge_pool.apply(run_caffe, [src])
    fuse = fuse[border:-border, border:-border]

    with tempfile.NamedTemporaryFile(suffix=".png") as png_file, tempfile.NamedTemporaryFile(suffix=".mat") as mat_file:
        scipy.io.savemat(mat_file.name, {"input": fuse})

        octave_code = r"""
E = 1-load(input_path).input;
E = imresize(E, [image_width,image_width]);
E = 1 - E;
E = single(E);
[Ox, Oy] = gradient(convTri(E, 4), 1);
[Oxx, ~] = gradient(Ox, 1);
[Oxy, Oyy] = gradient(Oy, 1);
O = mod(atan(Oyy .* sign(-Oxy) ./ (Oxx + 1e-5)), pi);
E = edgesNmsMex(E, O, 1, 5, 1.01, 1);
E = double(E >= max(eps, threshold));
E = bwmorph(E, 'thin', inf);
E = bwareaopen(E, small_edge);
E = 1 - E;
E = uint8(E * 255);
imwrite(E, output_path);
"""

        config = dict(
            input_path="'%s'" % mat_file.name,
            output_path="'%s'" % png_file.name,
            image_width=256,
            threshold=25.0 / 255.0,
            small_edge=5,
        )

        args = ["octave"]
        for k, v in config.items():
            args.extend(["--eval", "%s=%s;" % (k, v)])

        args.extend(["--eval", octave_code])
        try:
            subprocess.check_output(args, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            print("octave failed")
            print("returncode:", e.returncode)
            print("output:", e.output)
            raise
        return im.load(png_file.name)


def process(src_path, dst_path):
    src = im.load(src_path)

    if a.operation == "grayscale":
        dst = grayscale(src)
    elif a.operation == "resize":
        dst = resize(src)
    elif a.operation == "blank":
        dst = blank(src)
    elif a.operation == "combine":
        dst = combine(src, src_path)
    elif a.operation == "edges":
        dst = edges(src)
    else:
        raise Exception("invalid operation")

    im.save(dst, dst_path)


complete_lock = threading.Lock()
start = None
num_complete = 0
total = 0

def complete():
    global num_complete, rate, last_complete

    with complete_lock:
        num_complete += 1
        now = time.time()
        elapsed = now - start
        rate = num_complete / elapsed
        if rate > 0:
            remaining = (total - num_complete) / rate
        else:
            remaining = 0

        print("%d/%d complete %0.2f images/sec %dm%ds elapsed %dm%ds remaining" % (num_complete, total, rate, elapsed // 60, elapsed % 60, remaining // 60, remaining % 60))

        last_complete = now


def main():
    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    src_paths = []
    dst_paths = []

    skipped = 0
    for src_path in im.find(a.input_dir):
        name, _ = os.path.splitext(os.path.basename(src_path))
        dst_path = os.path.join(a.output_dir, name + ".png")
        if os.path.exists(dst_path):
            skipped += 1
        else:
            src_paths.append(src_path)
            dst_paths.append(dst_path)

    print("skipping %d files that already exist" % skipped)

    global total
    total = len(src_paths)

    print("processing %d files" % total)

    global start
    start = time.time()

    if a.operation == "edges":
        # use a multiprocessing pool for this operation so it can use multiple CPUs
        # create the pool before we launch processing threads
        global edge_pool
        edge_pool = multiprocessing.Pool(a.workers)

    if a.workers == 1:
        with tf.Session() as sess:
            for src_path, dst_path in zip(src_paths, dst_paths):
                process(src_path, dst_path)
                complete()
    else:
        queue = tf.train.input_producer(zip(src_paths, dst_paths), shuffle=False, num_epochs=1)
        dequeue_op = queue.dequeue()

        def worker(coord):
            with sess.as_default():
                while not coord.should_stop():
                    try:
                        src_path, dst_path = sess.run(dequeue_op)
                    except tf.errors.OutOfRangeError:
                        coord.request_stop()
                        break

                    process(src_path, dst_path)
                    complete()

        # init epoch counter for the queue
        local_init_op = tf.local_variables_initializer()
        with tf.Session() as sess:
            sess.run(local_init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            for i in range(a.workers):
                t = threading.Thread(target=worker, args=(coord,))
                t.start()
                threads.append(t)

            try:
                coord.join(threads)
            except KeyboardInterrupt:
                coord.request_stop()
                coord.join(threads)

main()
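tools/process.py is the pix2pix-style data-preparation helper; its combine operation builds the side-by-side A|B pairs that load_examples() in the main script later splits down the middle. A minimal sketch of how it might be invoked (folder names are placeholders; the files in --input_dir and --b_dir must share base names and dimensions):

# Hypothetical data-preparation step: concatenate matching images from two
# folders into single A|B pairs for training.
import subprocess

subprocess.run(["python", "tools/process.py",
                "--operation", "combine",
                "--input_dir", "data/A",          # left half of each output pair
                "--b_dir", "data/B",              # right half, matched by file name
                "--output_dir", "data/train_pairs"],
               check=True)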