|
@@ -0,0 +1,249 @@
|
|
|
|
+from __future__ import print_function
|
|
|
|
+
|
|
|
|
+import os
|
|
|
|
+import time
|
|
|
|
+import random
|
|
|
|
+
|
|
|
|
+from PIL import Image
|
|
|
|
+import tensorflow as tf
|
|
|
|
+import numpy as np
|
|
|
|
+
|
|
|
|
+from utils import *
|
|
|
|
+
|
|
|
|
+def concat(layers):
|
|
|
|
+ return tf.concat(layers, axis=3)
|
|
|
|
+
|
|
|
|
+def DecomNet(input_im, layer_num, channel=64, kernel_size=3):
|
|
|
|
+ input_max = tf.reduce_max(input_im, axis=3, keepdims=True)
|
|
|
|
+ input_im = concat([input_max, input_im])
|
|
|
|
+ with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
|
|
|
|
+ conv = tf.layers.conv2d(input_im, channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction")
|
|
|
|
+ for idx in range(layer_num):
|
|
|
|
+ conv = tf.layers.conv2d(conv, channel, kernel_size, padding='same', activation=tf.nn.relu, name='activated_layer_%d' % idx)
|
|
|
|
+ conv = tf.layers.conv2d(conv, 4, kernel_size, padding='same', activation=None, name='recon_layer')
|
|
|
|
+
|
|
|
|
+ R = tf.sigmoid(conv[:,:,:,0:3])
|
|
|
|
+ L = tf.sigmoid(conv[:,:,:,3:4])
|
|
|
|
+
|
|
|
|
+ return R, L
|
|
|
|
+
|
|
|
|
+def RelightNet(input_L, input_R, channel=64, kernel_size=3):
|
|
|
|
+ input_im = concat([input_R, input_L])
|
|
|
|
+ with tf.variable_scope('RelightNet'):
|
|
|
|
+ conv0 = tf.layers.conv2d(input_im, channel, kernel_size, padding='same', activation=None)
|
|
|
|
+ conv1 = tf.layers.conv2d(conv0, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
|
|
|
|
+ conv2 = tf.layers.conv2d(conv1, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
|
|
|
|
+ conv3 = tf.layers.conv2d(conv2, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
|
|
|
|
+
|
|
|
|
+ up1 = tf.image.resize_nearest_neighbor(conv3, (tf.shape(conv2)[1], tf.shape(conv2)[2]))
|
|
|
|
+ deconv1 = tf.layers.conv2d(up1, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv2
|
|
|
|
+ up2 = tf.image.resize_nearest_neighbor(deconv1, (tf.shape(conv1)[1], tf.shape(conv1)[2]))
|
|
|
|
+ deconv2= tf.layers.conv2d(up2, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv1
|
|
|
|
+ up3 = tf.image.resize_nearest_neighbor(deconv2, (tf.shape(conv0)[1], tf.shape(conv0)[2]))
|
|
|
|
+ deconv3 = tf.layers.conv2d(up3, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv0
|
|
|
|
+
|
|
|
|
+ deconv1_resize = tf.image.resize_nearest_neighbor(deconv1, (tf.shape(deconv3)[1], tf.shape(deconv3)[2]))
|
|
|
|
+ deconv2_resize = tf.image.resize_nearest_neighbor(deconv2, (tf.shape(deconv3)[1], tf.shape(deconv3)[2]))
|
|
|
|
+ feature_gather = concat([deconv1_resize, deconv2_resize, deconv3])
|
|
|
|
+ feature_fusion = tf.layers.conv2d(feature_gather, channel, 1, padding='same', activation=None)
|
|
|
|
+ output = tf.layers.conv2d(feature_fusion, 1, 3, padding='same', activation=None)
|
|
|
|
+ return output
|
|
|
|
+
|
|
|
|
+class lowlight_enhance(object):
|
|
|
|
+ def __init__(self, sess):
|
|
|
|
+ self.sess = sess
|
|
|
|
+ self.DecomNet_layer_num = 5
|
|
|
|
+
|
|
|
|
+ # build the model
|
|
|
|
+ self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
|
|
|
|
+ self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')
|
|
|
|
+
|
|
|
|
+ [R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
|
|
|
|
+ [R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)
|
|
|
|
+
|
|
|
|
+ I_delta = RelightNet(I_low, R_low)
|
|
|
|
+
|
|
|
|
+ I_low_3 = concat([I_low, I_low, I_low])
|
|
|
|
+ I_high_3 = concat([I_high, I_high, I_high])
|
|
|
|
+ I_delta_3 = concat([I_delta, I_delta, I_delta])
|
|
|
|
+
|
|
|
|
+ self.output_R_low = R_low
|
|
|
|
+ self.output_I_low = I_low_3
|
|
|
|
+ self.output_I_delta = I_delta_3
|
|
|
|
+ self.output_S = R_low * I_delta_3
|
|
|
|
+
|
|
|
|
+ # loss
|
|
|
|
+ self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.input_low))
|
|
|
|
+ self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high))
|
|
|
|
+ self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low))
|
|
|
|
+ self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high))
|
|
|
|
+ self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
|
|
|
|
+ self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high))
|
|
|
|
+
|
|
|
|
+ self.Ismooth_loss_low = self.smooth(I_low, R_low)
|
|
|
|
+ self.Ismooth_loss_high = self.smooth(I_high, R_high)
|
|
|
|
+ self.Ismooth_loss_delta = self.smooth(I_delta, R_low)
|
|
|
|
+
|
|
|
|
+ self.loss_Decom = self.recon_loss_low + self.recon_loss_high + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high + 0.01 * self.equal_R_loss
|
|
|
|
+ self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta
|
|
|
|
+
|
|
|
|
+ self.lr = tf.placeholder(tf.float32, name='learning_rate')
|
|
|
|
+ optimizer = tf.train.AdamOptimizer(self.lr, name='AdamOptimizer')
|
|
|
|
+
|
|
|
|
+ self.var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
|
|
|
|
+ self.var_Relight = [var for var in tf.trainable_variables() if 'RelightNet' in var.name]
|
|
|
|
+
|
|
|
|
+ self.train_op_Decom = optimizer.minimize(self.loss_Decom, var_list = self.var_Decom)
|
|
|
|
+ self.train_op_Relight = optimizer.minimize(self.loss_Relight, var_list = self.var_Relight)
|
|
|
|
+
|
|
|
|
+ self.sess.run(tf.global_variables_initializer())
|
|
|
|
+
|
|
|
|
+ self.saver_Decom = tf.train.Saver(var_list = self.var_Decom)
|
|
|
|
+ self.saver_Relight = tf.train.Saver(var_list = self.var_Relight)
|
|
|
|
+
|
|
|
|
+ print("[*] Initialize model successfully...")
|
|
|
|
+
|
|
|
|
+ def gradient(self, input_tensor, direction):
|
|
|
|
+ self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])
|
|
|
|
+ self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])
|
|
|
|
+
|
|
|
|
+ if direction == "x":
|
|
|
|
+ kernel = self.smooth_kernel_x
|
|
|
|
+ elif direction == "y":
|
|
|
|
+ kernel = self.smooth_kernel_y
|
|
|
|
+ return tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))
|
|
|
|
+
|
|
|
|
+ def ave_gradient(self, input_tensor, direction):
|
|
|
|
+ return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')
|
|
|
|
+
|
|
|
|
+ def smooth(self, input_I, input_R):
|
|
|
|
+ input_R = tf.image.rgb_to_grayscale(input_R)
|
|
|
|
+ return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))
|
|
|
|
+
|
|
|
|
+ def evaluate(self, epoch_num, eval_low_data, sample_dir, train_phase):
|
|
|
|
+ print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch_num))
|
|
|
|
+
|
|
|
|
+ for idx in range(len(eval_low_data)):
|
|
|
|
+ input_low_eval = np.expand_dims(eval_low_data[idx], axis=0)
|
|
|
|
+
|
|
|
|
+ if train_phase == "Decom":
|
|
|
|
+ result_1, result_2 = self.sess.run([self.output_R_low, self.output_I_low], feed_dict={self.input_low: input_low_eval})
|
|
|
|
+ if train_phase == "Relight":
|
|
|
|
+ result_1, result_2 = self.sess.run([self.output_S, self.output_I_delta], feed_dict={self.input_low: input_low_eval})
|
|
|
|
+
|
|
|
|
+ save_images(os.path.join(sample_dir, 'eval_%s_%d_%d.png' % (train_phase, idx + 1, epoch_num)), result_1, result_2)
|
|
|
|
+
|
|
|
|
+ def train(self, train_low_data, train_high_data, eval_low_data, batch_size, patch_size, epoch, lr, sample_dir, ckpt_dir, eval_every_epoch, train_phase):
|
|
|
|
+ assert len(train_low_data) == len(train_high_data)
|
|
|
|
+ numBatch = len(train_low_data) // int(batch_size)
|
|
|
|
+
|
|
|
|
+ # load pretrained model
|
|
|
|
+ if train_phase == "Decom":
|
|
|
|
+ train_op = self.train_op_Decom
|
|
|
|
+ train_loss = self.loss_Decom
|
|
|
|
+ saver = self.saver_Decom
|
|
|
|
+ elif train_phase == "Relight":
|
|
|
|
+ train_op = self.train_op_Relight
|
|
|
|
+ train_loss = self.loss_Relight
|
|
|
|
+ saver = self.saver_Relight
|
|
|
|
+
|
|
|
|
+ load_model_status, global_step = self.load(saver, ckpt_dir)
|
|
|
|
+ if load_model_status:
|
|
|
|
+ iter_num = global_step
|
|
|
|
+ start_epoch = global_step // numBatch
|
|
|
|
+ start_step = global_step % numBatch
|
|
|
|
+ print("[*] Model restore success!")
|
|
|
|
+ else:
|
|
|
|
+ iter_num = 0
|
|
|
|
+ start_epoch = 0
|
|
|
|
+ start_step = 0
|
|
|
|
+ print("[*] Not find pretrained model!")
|
|
|
|
+
|
|
|
|
+ print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))
|
|
|
|
+
|
|
|
|
+ start_time = time.time()
|
|
|
|
+ image_id = 0
|
|
|
|
+
|
|
|
|
+ for epoch in range(start_epoch, epoch):
|
|
|
|
+ for batch_id in range(start_step, numBatch):
|
|
|
|
+ # generate data for a batch
|
|
|
|
+ batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
|
|
|
|
+ batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
|
|
|
|
+ for patch_id in range(batch_size):
|
|
|
|
+ h, w, _ = train_low_data[image_id].shape
|
|
|
|
+ x = random.randint(0, h - patch_size)
|
|
|
|
+ y = random.randint(0, w - patch_size)
|
|
|
|
+
|
|
|
|
+ rand_mode = random.randint(0, 7)
|
|
|
|
+ batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
|
|
|
|
+ batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
|
|
|
|
+
|
|
|
|
+ image_id = (image_id + 1) % len(train_low_data)
|
|
|
|
+ if image_id == 0:
|
|
|
|
+ tmp = list(zip(train_low_data, train_high_data))
|
|
|
|
+ random.shuffle(list(tmp))
|
|
|
|
+ train_low_data, train_high_data = zip(*tmp)
|
|
|
|
+
|
|
|
|
+ # train
|
|
|
|
+ _, loss = self.sess.run([train_op, train_loss], feed_dict={self.input_low: batch_input_low, \
|
|
|
|
+ self.input_high: batch_input_high, \
|
|
|
|
+ self.lr: lr[epoch]})
|
|
|
|
+
|
|
|
|
+ print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \
|
|
|
|
+ % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
|
|
|
|
+ iter_num += 1
|
|
|
|
+
|
|
|
|
+ # evalutate the model and save a checkpoint file for it
|
|
|
|
+ if (epoch + 1) % eval_every_epoch == 0:
|
|
|
|
+ self.evaluate(epoch + 1, eval_low_data, sample_dir=sample_dir, train_phase=train_phase)
|
|
|
|
+ self.save(saver, iter_num, ckpt_dir, "RetinexNet-%s" % train_phase)
|
|
|
|
+
|
|
|
|
+ print("[*] Finish training for phase %s." % train_phase)
|
|
|
|
+
|
|
|
|
+ def save(self, saver, iter_num, ckpt_dir, model_name):
|
|
|
|
+ if not os.path.exists(ckpt_dir):
|
|
|
|
+ os.makedirs(ckpt_dir)
|
|
|
|
+ print("[*] Saving model %s" % model_name)
|
|
|
|
+ saver.save(self.sess, \
|
|
|
|
+ os.path.join(ckpt_dir, model_name), \
|
|
|
|
+ global_step=iter_num)
|
|
|
|
+
|
|
|
|
+ def load(self, saver, ckpt_dir):
|
|
|
|
+ ckpt = tf.train.get_checkpoint_state(ckpt_dir)
|
|
|
|
+ if ckpt and ckpt.model_checkpoint_path:
|
|
|
|
+ full_path = tf.train.latest_checkpoint(ckpt_dir)
|
|
|
|
+ try:
|
|
|
|
+ global_step = int(full_path.split('/')[-1].split('-')[-1])
|
|
|
|
+ except ValueError:
|
|
|
|
+ global_step = None
|
|
|
|
+ saver.restore(self.sess, full_path)
|
|
|
|
+ return True, global_step
|
|
|
|
+ else:
|
|
|
|
+ print("[*] Failed to load model from %s" % ckpt_dir)
|
|
|
|
+ return False, 0
|
|
|
|
+
|
|
|
|
+ def test(self, test_low_data, test_high_data, test_low_data_names, save_dir, decom_flag):
|
|
|
|
+ tf.global_variables_initializer().run()
|
|
|
|
+
|
|
|
|
+ print("[*] Reading checkpoint...")
|
|
|
|
+ load_model_status_Decom, _ = self.load(self.saver_Decom, './model/Decom')
|
|
|
|
+ load_model_status_Relight, _ = self.load(self.saver_Relight, './model/Relight')
|
|
|
|
+ if load_model_status_Decom and load_model_status_Relight:
|
|
|
|
+ print("[*] Load weights successfully...")
|
|
|
|
+
|
|
|
|
+ print("[*] Testing...")
|
|
|
|
+ for idx in range(len(test_low_data)):
|
|
|
|
+ print(test_low_data_names[idx])
|
|
|
|
+ [_, name] = os.path.split(test_low_data_names[idx])
|
|
|
|
+ suffix = name[name.find('.') + 1:]
|
|
|
|
+ name = name[:name.find('.')]
|
|
|
|
+
|
|
|
|
+ input_low_test = np.expand_dims(test_low_data[idx], axis=0)
|
|
|
|
+ [R_low, I_low, I_delta, S] = self.sess.run([self.output_R_low, self.output_I_low, self.output_I_delta, self.output_S], feed_dict = {self.input_low: input_low_test})
|
|
|
|
+
|
|
|
|
+ if decom_flag == 1:
|
|
|
|
+ save_images(os.path.join(save_dir, name + "_R_low." + suffix), R_low)
|
|
|
|
+ save_images(os.path.join(save_dir, name + "_I_low." + suffix), I_low)
|
|
|
|
+ save_images(os.path.join(save_dir, name + "_I_delta." + suffix), I_delta)
|
|
|
|
+ save_images(os.path.join(save_dir, name + "_S." + suffix), S)
|
|
|
|
+
|