from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
import numpy as np
import argparse
import os
import json
import glob
import random
import collections
import math
import time

# visualize images
import matplotlib.pyplot as plt

# disable v2 behaviours so the TF1-style graph/session code below keeps working
tf.disable_v2_behavior()
# enable eager execution
# tf.compat.v1.enable_eager_execution()

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", help="path to folder containing images")
parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
parser.add_argument("--output_dir", required=True, help="where to put output files")
parser.add_argument("--seed", type=int)
parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")

parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
parser.add_argument("--max_epochs", type=int, help="number of training epochs")
parser.add_argument("--summary_freq", type=int, default=10000, help="update summaries every summary_freq steps")
parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
parser.add_argument("--display_freq", type=int, default=1000, help="write current training images every display_freq steps")
parser.add_argument("--save_freq", type=int, default=1000, help="save model every save_freq steps, 0 to disable")

parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
parser.add_argument("--which_direction", type=str, default="BtoA", choices=["AtoB", "BtoA"])
parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
parser.add_argument("--scale_size", type=int, default=728, help="scale images to this size before cropping to 256x256")
parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
# parser.set_defaults(flip=True)
parser.add_argument("--lr", type=float, default=0.0001, help="initial learning rate for adam")
parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")

# export options
parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])

a = parser.parse_args()

EPS = 1e-12

Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")


def preprocess(image):
    with tf.name_scope("preprocess"):
        # [0, 1] => [-1, 1]
        return image * 2 - 1


def deprocess(image):
    with tf.name_scope("deprocess"):
        # [-1, 1] => [0, 1]
        return (image + 1) / 2
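# Example invocation (a sketch only; the script filename "pix2pix_patches.py",
# the paths and the epoch count are illustrative placeholders, not part of this
# repository):
#
#   python pix2pix_patches.py \
#       --mode train \
#       --input_dir data/train \
#       --output_dir runs/experiment1 \
#       --max_epochs 200 \
#       --which_direction BtoA \
#       --scale_size 728
#
# For testing, pass --mode test with --checkpoint pointing at the training
# output_dir; the single-pass test branch in create_model() expects
# --scale_size 296. Paired images are expected side by side (A|B) in one file,
# as split in load_examples() below.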
tf.name_scope("deprocess"): # [-1, 1] => [0, 1] return (image + 1) / 2 def discrim_conv(batch_input, out_channels, stride): padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid", kernel_initializer=tf.random_normal_initializer(0, 0.02)) def gen_conv(batch_input, out_channels): # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels] initializer = tf.random_normal_initializer(0, 0.02) if a.separable_conv: return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer) else: return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer) def gen_deconv(batch_input, out_channels): # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels] initializer = tf.random_normal_initializer(0, 0.02) if a.separable_conv: _b, h, w, _c = batch_input.shape resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer) else: return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer) def lrelu(x, a): with tf.name_scope("lrelu"): # adding these together creates the leak part and linear part # then cancels them out by subtracting/adding an absolute value term # leak: a*x/2 - a*abs(x)/2 # linear: x/2 + abs(x)/2 # this block looks like it has 2 inputs on the graph unless we do this x = tf.identity(x) return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x) def batchnorm(inputs): return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02)) def check_image(image): assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels") with tf.control_dependencies([assertion]): image = tf.identity(image) if image.get_shape().ndims not in (3, 4): raise ValueError("image must be either 3 or 4 dimensions") # make the last dimension 3 so that you can unstack the colors shape = list(image.get_shape()) shape[-1] = 3 image.set_shape(shape) return image def load_examples(): if a.input_dir is None or not os.path.exists(a.input_dir): raise Exception("input_dir does not exist") input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg")) decode = tf.image.decode_jpeg if len(input_paths) == 0: input_paths = glob.glob(os.path.join(a.input_dir, "*.png")) decode = tf.image.decode_png if len(input_paths) == 0: raise Exception("input_dir contains no image files") def get_name(path): name, _ = os.path.splitext(os.path.basename(path)) return name # if the image names are numbers, sort by the value rather than asciibetically # having sorted inputs means that the outputs are sorted in test mode if all(get_name(path).isdigit() for path in input_paths): input_paths = sorted(input_paths, key=lambda path: int(get_name(path))) else: input_paths = sorted(input_paths) with tf.name_scope("load_images"): path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train") reader = tf.WholeFileReader() paths, contents = 
        raw_input = decode(contents)
        raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)

        assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels")
        with tf.control_dependencies([assertion]):
            raw_input = tf.identity(raw_input)

        raw_input.set_shape([None, None, 3])

        # break apart image pair and move to range [-1, 1]
        width = tf.shape(raw_input)[1]  # [height, width, channels]
        a_images = preprocess(raw_input[:, :width // 2, :])
        b_images = preprocess(raw_input[:, width // 2:, :])

    if a.which_direction == "AtoB":
        inputs, targets = [a_images, b_images]
    elif a.which_direction == "BtoA":
        inputs, targets = [b_images, a_images]
    else:
        raise Exception("invalid direction")

    # synchronize seed for image operations so that we do the same operations to both
    # input and output images
    def transform(image):
        r = image
        r.set_shape([a.scale_size, a.scale_size, 3])
        # r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)
        return r

    with tf.name_scope("input_images"):
        input_images = transform(inputs)

    with tf.name_scope("target_images"):
        target_images = transform(targets)

    paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
    steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))

    return Examples(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )


def create_generator(generator_inputs, generator_outputs_channels):
    layers = []

    # Add a learned 41x41 filter bank intended to act as an edge detector on the input patch
    filter_shape = [41, 41, 3, a.ngf]
    with tf.variable_scope("encoder_1"):
        filter = tf.get_variable('edge_detector', filter_shape, initializer=tf.random_normal_initializer(stddev=0.02))
        strides = [1, 1, 1, 1]
        output = tf.nn.conv2d(generator_inputs, filter, strides=strides, padding='VALID')
        output = lrelu(output, 0.2)
        layers.append(output)

    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
    # (re-entering the "encoder_1" variable scope is safe here: the conv kernel
    # gets a different variable name than edge_detector, so nothing collides)
    with tf.variable_scope("encoder_1"):
        output = gen_conv(output, a.ngf)
        layers.append(output)

    layer_specs = [
        a.ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
        a.ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
        a.ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
        a.ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
        a.ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
        a.ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        a.ngf * 8,  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
    ]

    for out_channels in layer_specs:
        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            convolved = gen_conv(rectified, out_channels)
            output = batchnorm(convolved)
            layers.append(output)

    layer_specs = [
        (a.ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
        (a.ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
        (a.ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
        (a.ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
        (a.ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
        (a.ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
        (a.ngf, 0.0),      # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
    ]
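    # U-Net skip connections: except for the innermost decoder layer, each
    # decoder input is the previous decoder output concatenated with the
    # mirrored encoder activation, which is why the shape comments above show
    # doubled channel counts ("* 2") on the decoder side.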
* 2] => [batch, 128, 128, ngf * 2] ] num_encoder_layers = len(layers) for decoder_layer, (out_channels, dropout) in enumerate(layer_specs): skip_layer = num_encoder_layers - decoder_layer - 1 with tf.variable_scope("decoder_%d" % (skip_layer + 1)): if decoder_layer == 0: # first decoder layer doesn't have skip connections # since it is directly connected to the skip_layer input = layers[-1] else: input = tf.concat([layers[-1], layers[skip_layer]], axis=3) #input = layers[-1] rectified = tf.nn.relu(input) # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels] output = gen_deconv(rectified, out_channels) output = batchnorm(output) if dropout > 0.0: output = tf.nn.dropout(output, keep_prob=1 - dropout) layers.append(output) # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels] with tf.variable_scope("decoder_1"): input = tf.concat([layers[-1], layers[1]], axis=3) rectified = tf.nn.relu(input) output = gen_deconv(rectified, generator_outputs_channels) output = tf.tanh(output) layers.append(output) return layers[-1] ksize_rows = 296 ksize_cols = 296 strides_rows = 236 strides_cols = 236 num_patches = int(a.scale_size/strides_rows) ksizes = [1, ksize_rows, ksize_cols, 1] ksizes_output = [1, 256, 256, 1] strides = [1, strides_rows, strides_cols, 1] rates = [1, 1, 1, 1] padding='VALID' def extract_patches(x, ksizes, strides, rates): return tf.extract_image_patches( x, ksizes, strides, rates, padding="VALID" ) def extract_patches_inverse(x, y): _x = tf.zeros_like(x) _y = extract_patches(_x, ksizes_output, strides, rates) grad = tf.gradients(_y, _x)[0] # Divide by grad, to "average" together the overlapping patches # otherwise they would simply sum up return tf.gradients(_y, _x, grad_ys=y)[0] / grad def create_model(inputs, targets): out_channels = int(targets.get_shape()[-1]) if(a.mode == "train" or a.scale_size != 296): #scale_size = 296 while testing def create_discriminator(discrim_inputs, discrim_targets): n_layers = 5 layers = [] # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2] input = tf.concat([discrim_inputs, discrim_targets], axis=3) # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf] with tf.variable_scope("layer_1"): convolved = discrim_conv(input, a.ndf, stride=2) rectified = lrelu(convolved, 0.2) layers.append(rectified) # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2] # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4] # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8] for i in range(n_layers): with tf.variable_scope("layer_%d" % (len(layers) + 1)): out_channels = a.ndf * min(2 ** (i + 1), 8) stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1 convolved = discrim_conv(layers[-1], out_channels, stride=stride) normalized = batchnorm(convolved) rectified = lrelu(normalized, 0.2) layers.append(rectified) # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1] with tf.variable_scope("layer_%d" % (len(layers) + 1)): convolved = discrim_conv(rectified, out_channels=1, stride=1) output = tf.sigmoid(convolved) layers.append(output) print(layers) return layers[-1] # Pad inputs to make shape to handle edge patches, so as we can give context as input to generator (context around target patch) inputs_bounded = tf.image.pad_to_bounding_box(inputs, 20, 20, a.scale_size + 40, a.scale_size + 40) #Extract Patches with tf.variable_scope("extract_patches"): inputs_patches = extract_patches(inputs_bounded, ksizes, 
        with tf.name_scope("patches_generator"):
            with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
                output_patches = []
                patch_size_len = 256 * 256 * 3
                # Get the output for each patch from the same (weight-shared) generator
                for i in range(0, num_patches):
                    for j in range(0, num_patches):
                        patch = inputs_patches[0, i, j]
                        # patch_size_len = int(patch.get_shape()[0])  # defined above
                        # reshape
                        patch = tf.reshape(patch, [ksize_rows, ksize_cols, 3])
                        patch = tf.expand_dims(patch, 0)
                        patch_output = create_generator(patch, out_channels)
                        output_patches.append(tf.reshape(patch_output, [patch_size_len]))

                output_patches = tf.stack(output_patches)
                output_patches = tf.reshape(output_patches, [1, num_patches, num_patches, patch_size_len])

                # Stitch all patches back together into a full-resolution output;
                # k is only a shape template for extract_patches_inverse, its values are irrelevant
                k = tf.constant(0.1, shape=[1, a.scale_size, a.scale_size, 3])
                outputs = extract_patches_inverse(k, output_patches)

        # create two copies of discriminator, one for real pairs and one for fake pairs
        # they share the same underlying variables
        with tf.name_scope("real_discriminator"):
            with tf.variable_scope("discriminator"):
                # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
                predict_real = create_discriminator(inputs, targets)

        with tf.name_scope("fake_discriminator"):
            with tf.variable_scope("discriminator", reuse=True):
                # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
                predict_fake = create_discriminator(inputs, outputs)

        with tf.name_scope("discriminator_loss"):
            # minimizing -tf.log will try to get inputs to 1
            # predict_real => 1
            # predict_fake => 0
            discrim_loss = tf.reduce_mean(-(tf.log(tf.clip_by_value(predict_real + EPS, 1e-12, 1.0)) +
                                            tf.log(tf.clip_by_value(1 - predict_fake + EPS, 1e-12, 1.0))))

        with tf.name_scope("generator_loss"):
            # predict_fake => 1
            # abs(targets - outputs) => 0
            gen_loss_GAN = tf.reduce_mean(-tf.log(tf.clip_by_value(predict_fake + EPS, 1e-12, 1.0)))
            gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
            gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight

        with tf.name_scope("discriminator_train"):
            discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
            discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
            discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
            discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

        with tf.name_scope("generator_train"):
            with tf.control_dependencies([discrim_train]):
                gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
                gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
                gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
                gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

        ema = tf.train.ExponentialMovingAverage(decay=0.99)
        update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])

        global_step = tf.train.get_or_create_global_step()
        incr_global_step = tf.assign(global_step, global_step + 1)

        discrim_loss = ema.average(discrim_loss)
        gen_loss_GAN = ema.average(gen_loss_GAN)
        gen_loss_L1 = ema.average(gen_loss_L1)

        update_ops = [update_losses, incr_global_step, gen_train]
        train = tf.group(update_ops)

    elif (a.mode == "test"):
        # at test time (with scale_size == 296) the generator runs on the whole
        # image in a single pass, so no discriminator or training ops are built
        predict_real = None
        predict_fake = None
        discrim_loss = None
        discrim_grads_and_vars = None
        gen_loss_GAN = None
        gen_loss_L1 = None
        gen_grads_and_vars = None
        train = None

        with tf.name_scope("patches_generator"):
            with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
                outputs = create_generator(inputs, out_channels)

    else:
        print("Give a valid mode")
        exit(0)
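    # Note: in test mode every field except `outputs` is None, so callers must
    # guard on a.mode before fetching the loss/training tensors from this Model.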
    return Model(
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=discrim_loss,
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=gen_loss_GAN,
        gen_loss_L1=gen_loss_L1,
        gen_grads_and_vars=gen_grads_and_vars,
        outputs=outputs,
        train=train,
    )


def save_images(fetches, step=None):
    image_dir = os.path.join(a.output_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    filesets = []
    for i, in_path in enumerate(fetches["paths"]):
        name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
        fileset = {"name": name, "step": step}
        for kind in ["inputs", "outputs", "targets"]:
            filename = name + "-" + kind + ".png"
            if step is not None:
                filename = "%08d-%s" % (step, filename)
            fileset[kind] = filename
            out_path = os.path.join(image_dir, filename)
            contents = fetches[kind][i]
            with open(out_path, "wb") as f:
                f.write(contents)
        filesets.append(fileset)
    return filesets
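# For reference, each fileset entry produced above looks roughly like this
# (assuming step=1000 and an input file named "42.png"):
#   {"name": "42", "step": 1000,
#    "inputs": "00001000-42-inputs.png",
#    "outputs": "00001000-42-outputs.png",
#    "targets": "00001000-42-targets.png"}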
step | ") index.write("name | input | output | target |
---|---|---|---|---|
%d | " % fileset["step"]) index.write("%s | " % fileset["name"]) for kind in ["inputs", "outputs", "targets"]: index.write("" % fileset[kind]) index.write(" |