"""Code for training SIFA."""
from datetime import datetime
import json
from skimage import transform
import random
import os
import cv2
import data_loader, losses, model
from PIL import Image
import tensorflow as tf
from scipy.misc import imsave
import numpy as np
# Default to GPU 0, but keep any device selection already present in the
# environment (e.g. `CUDA_VISIBLE_DEVICES=1 python ...`).
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

# Periodic-save interval. NOTE(review): not referenced in the visible part of
# this file — confirm where it is used.
save_interval = 300

class SIFA:
    """The SIFA module."""

    def __init__(self, config):
        """Store configuration values and allocate the fake-image history pools.

        Args:
            config: mapping that provides every key read below (data list
                paths, output/checkpoint dirs, loss weights, learning-rate
                and scheduling switches, ...). Values are coerced with
                ``int``/``float``/``bool``. Because ``bool("false")`` is
                ``True``, boolean options must be real booleans (as produced
                by ``json.load``), not strings.

        Side effects:
            Creates ``config['output_root_dir']`` if it does not exist. The
            timestamped run directory and its ``imgs``/``test`` sub-paths are
            only computed here, not created.
        """
        current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
        self._source_train_pth = config['source_train_pth']
        self._target_train_pth = config['target_train_pth']
        self._source_val_pth = config['source_val_pth']
        self._target_val_pth = config['target_val_pth']
        self._output_root_dir = config['output_root_dir']
        # exist_ok avoids the race between a separate existence check and the
        # creation; it still raises if the path exists as a regular file.
        os.makedirs(self._output_root_dir, exist_ok=True)
        self._output_dir = os.path.join(self._output_root_dir, current_time)
        self._images_dir = os.path.join(self._output_dir, 'imgs')
        self.images_test_dir = os.path.join(self._output_dir, 'test')
        # Number of batches written by save_images / save_images1.
        self._num_imgs_to_val_save = 649
        self._num_imgs_test_to_save = 649
        self._pool_size = int(config['pool_size'])
        self._lambda_a = float(config['_LAMBDA_A'])
        self._lambda_b = float(config['_LAMBDA_B'])
        self._skip = bool(config['skip'])
        self._num_cls = int(config['num_cls'])
        self._base_lr = float(config['base_lr'])
        self._max_step = int(config['max_step'])
        self._keep_rate_value = float(config['keep_rate_value'])
        self._rate_mafcn_value = float(config['rate_mafcn_value'])
        self._is_training_value = bool(config['is_training_value'])
        self._batch_size = int(config['batch_size'])
        self._lr_gan_decay = bool(config['lr_gan_decay'])
        self._lsgan_loss_p_scheduler = bool(config['lsgan_loss_p_scheduler'])
        self._to_restore = bool(config['to_restore'])
        self._checkpoint_dir = config['checkpoint_dir']
        self.label_weight = config['label_weight']

        # Fake-image history pools used by fake_image_pool. Each slot holds
        # one generated batch and must have the same layout as the
        # fake_pool_A / fake_pool_B placeholders it is fed into:
        # [batch, IMG_WIDTH, IMG_HEIGHT, 3] (WIDTH before HEIGHT, so that
        # non-square image sizes work). 3 channels (was 1 at some point).
        pool_shape = (self._pool_size, self._batch_size,
                      model.IMG_WIDTH, model.IMG_HEIGHT, 3)
        self.fake_images_A = np.zeros(pool_shape)
        self.fake_images_B = np.zeros(pool_shape)

    def model_setup(self):
        """Declare every graph input and wire up the SIFA network.

        Image and label placeholders use the layout
        ``[batch_size, IMG_WIDTH, IMG_HEIGHT, channels]``. ``gt_a_2`` and
        ``gt_a_3`` use 1/2 and 1/4 of that spatial size. Scalar placeholders
        carry the dropout/training switches, the two learning rates, and
        host-computed loss/accuracy values that are logged as summaries.

        Every tensor returned by ``model.get_outputs`` is stored on ``self``
        under the same key. The ``*_argmax`` attributes hold ``tf.argmax`` over
        axis 3 of selected masks and labels.
        """
        n_cls = self._num_cls
        width, height = model.IMG_WIDTH, model.IMG_HEIGHT

        def spatial_input(name, w, h, channels):
            """Float32 placeholder of shape [batch_size, w, h, channels]."""
            return tf.placeholder(
                tf.float32, [self._batch_size, w, h, channels], name=name)

        # Inputs are created in a fixed order so that default (unnamed) op
        # names stay stable between runs.
        self.input_a = spatial_input("input_A", width, height, 3)
        self.input_b = spatial_input("input_B", width, height, 3)
        self.fake_pool_A = spatial_input("fake_pool_A", width, height, 3)
        self.fake_pool_B = spatial_input("fake_pool_B", width, height, 3)
        self.gt_a = spatial_input("gt_A", width, height, n_cls)
        self.gt_b = spatial_input("gt_B", width, height, n_cls)
        self.gt_a_2 = spatial_input("gt_A_2", width // 2, height // 2, n_cls)
        self.gt_a_3 = spatial_input("gt_A_3", width // 4, height // 4, n_cls)

        self.keep_rate = tf.placeholder(tf.float32, shape=[])
        self.rate_mafcn = tf.placeholder(tf.float32, shape=[])
        self.is_training = tf.placeholder(tf.bool, shape=())

        self.global_step = tf.train.get_or_create_global_step()
        # Counter handed to fake_image_pool.
        self.num_fake_inputs = 0

        def scalar_input(name):
            """Named float32 scalar placeholder."""
            return tf.placeholder(tf.float32, shape=[], name=name)

        self.learning_rate_gan = scalar_input("lr_gan")
        self.learning_rate_seg = scalar_input("lr_seg")

        # Host-side metrics fed back into the graph for TensorBoard.
        self.train_output_loss = scalar_input("train_loss")
        self.train_output_acc = scalar_input("train_acc")
        self.val_output_loss = scalar_input("val_loss")
        self.val_output_acc = scalar_input("val_acc")

        self.train_output_loss_epoch = scalar_input("train_loss_epoch")
        self.train_output_acc_epoch = scalar_input("train_acc_epoch")
        self.val_output_loss_epoch = scalar_input("val_loss_epoch")
        self.val_output_acc_epoch = scalar_input("val_acc_epoch")

        outputs = model.get_outputs(
            {
                'images_a': self.input_a,
                'images_b': self.input_b,
                'fake_pool_a': self.fake_pool_A,
                'fake_pool_b': self.fake_pool_B,
            },
            skip=self._skip,
            is_training=self.is_training,
            keep_rate=self.keep_rate,
            rate_mafcn=self.rate_mafcn,
        )

        # Expose each model output as an attribute with the same name.
        for key in (
            'prob_real_a_is_real', 'prob_real_b_is_real',
            'fake_images_a', 'fake_images_b',
            'prob_fake_a_is_real', 'prob_fake_b_is_real',
            'cycle_images_a', 'cycle_images_b',
            'prob_fake_pool_a_is_real', 'prob_fake_pool_b_is_real',
            'pred_mask_a', 'pred_mask_b',
            'pred_mask_fake_a', 'pred_mask_fake_b',
            'prob_pred_mask_fake_b_is_real', 'prob_pred_mask_b_is_real',
            'prob_fake_a_aux_is_real', 'prob_fake_pool_a_aux_is_real',
            'prob_cycle_a_aux_is_real',
            'output_mask_b_1', 'output_mask_fake_b_1',
            'output_mask_cycle_images_b_1', 'output_mask_fake_a_1',
            'output_mask_a_1', 'output_mask_cycle_images_a_1',
            'output_mask_fake_b_2', 'output_mask_fake_b_3',
            'prob_output_mask_fake_b_1_is_real',
            'prob_output_mask_fake_b_2_is_real',
            'prob_output_mask_fake_b_3_is_real',
            'prob_output_mask_b_1_is_real',
            'prob_output_mask_b_2_is_real',
            'prob_output_mask_b_3_is_real',
        ):
            setattr(self, key, outputs[key])

        # Hard class maps (argmax over axis 3), created in a fixed order.
        self.pred_mask_b_argmax = tf.argmax(self.pred_mask_b, axis=3)
        self.pred_mask_fake_b_argmax = tf.argmax(self.pred_mask_fake_b, axis=3)
        self.output_mask_b_1_argmax = tf.argmax(self.output_mask_b_1, axis=3)
        self.output_mask_a_1_argmax = tf.argmax(self.output_mask_a_1, axis=3)
        self.output_mask_cycle_images_a_1_argmax = tf.argmax(self.output_mask_cycle_images_a_1, axis=3)
        self.output_mask_fake_b_1_argmax = tf.argmax(self.output_mask_fake_b_1, axis=3)
        self.gt_a_argmax = tf.argmax(self.gt_a, axis=3)
        self.gt_b_argmax = tf.argmax(self.gt_b, axis=3)

    def compute_losses(self):
        """Build the SIFA losses, per-module train ops and TensorBoard summaries.

        Reads the tensors stored on ``self`` by ``model_setup``, so call that
        first. Creates the ``lsgan_loss_p_weight`` placeholder, which scales
        the adversarial terms on ``pred_mask_b`` and on the three
        ``output_mask_*`` levels. Sets on ``self``:

        * ``softmax_loss_output``: segmentation loss of the fake-B outputs
          against ``gt_a``/``gt_a_2``/``gt_a_3``.
        * ``softmax_loss_output_val``: ``output_mask_b_1`` against ``gt_b``.
        * One ``*_trainer`` op per module: d_A, d_B, g_A, g_B (de_B vars),
          d_P, s_B (e_B + s_B vars), output (u_B vars), and d_U1..d_U3.
        * ``model_vars`` and the ``*_summ`` summary ops.

        Variables are routed to trainers by substring match on their names
        (``'/d_A/'``, ``'/u_B/'``, ...).

        NOTE(review): ``optimizer_gan`` and ``optimizer_seg`` are each reused
        for several ``minimize`` calls. In TF1 Adam the beta-power
        accumulators are believed to be shared per optimizer instance, so
        every trainer run would advance the shared bias correction. Confirm
        this is intended. The order of the ``minimize`` calls presumably also
        fixes the optimizer variable names, so keep it stable for checkpoint
        compatibility.
        """

        # Cycle-consistency between each input and its reconstruction,
        # weighted by lambda_a / lambda_b.
        cycle_consistency_loss_a = \
            self._lambda_a * losses.cycle_consistency_loss(
                real_images=self.input_a, generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * losses.cycle_consistency_loss(
                real_images=self.input_b, generated_images=self.cycle_images_b,
            )

        # Generator-side LSGAN terms for the image, prediction and auxiliary
        # discriminators.
        lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)
        lsgan_loss_p = losses.lsgan_loss_generator(self.prob_pred_mask_b_is_real)
        lsgan_loss_a_aux = losses.lsgan_loss_generator(self.prob_fake_a_aux_is_real)

        # Generator-side LSGAN terms for the three output-level discriminators,
        # evaluated on the output masks of real B.
        lsgan_loss_output1 = losses.lsgan_loss_generator(self.prob_output_mask_b_1_is_real)
        lsgan_loss_output2 = losses.lsgan_loss_generator(self.prob_output_mask_b_2_is_real)
        lsgan_loss_output3 = losses.lsgan_loss_generator(self.prob_output_mask_b_3_is_real)


        # Output-level discriminators: outputs on fake B play the "real" role,
        # outputs on real B play the "fake" role.
        d_loss_output1 = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_output_mask_fake_b_1_is_real,
            prob_fake_is_real=self.prob_output_mask_b_1_is_real,
        )
        d_loss_output2 = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_output_mask_fake_b_2_is_real,
            prob_fake_is_real=self.prob_output_mask_b_2_is_real,
        )

        d_loss_output3 = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_output_mask_fake_b_3_is_real,
            prob_fake_is_real=self.prob_output_mask_b_3_is_real,
        )


        # Consistency terms that reuse cycle_consistency_loss on output masks.
        # Each compares an image's output with the output of its translated
        # (per_loss_t / per_loss_s) or cycle-reconstructed (*_recon) version.
        per_loss_t = losses.cycle_consistency_loss(real_images=self.output_mask_b_1, generated_images=self.output_mask_fake_b_1)
        per_loss_t_recon = losses.cycle_consistency_loss(real_images=self.output_mask_b_1, generated_images=self.output_mask_cycle_images_b_1)

        per_loss_s = losses.cycle_consistency_loss(real_images=self.output_mask_a_1, generated_images=self.output_mask_fake_a_1)
        per_loss_s_recon = losses.cycle_consistency_loss(real_images=self.output_mask_a_1, generated_images=self.output_mask_cycle_images_a_1)


        # Segmentation loss on the fake-B outputs: full resolution (weight 2)
        # plus the 1/2 and 1/4 resolution label maps gt_a_2 / gt_a_3.
        softmax_loss_output = losses._softmax_loss(self.output_mask_fake_b_1, self.gt_a) * 2 + \
                            losses._softmax_loss(self.output_mask_fake_b_2, self.gt_a_2) + \
                            losses._softmax_loss(self.output_mask_fake_b_3, self.gt_a_3) 

        # Loss on real-B outputs vs. gt_b; stored for evaluation only, not
        # part of any optimised objective in this method.
        self.softmax_loss_output_val = losses._softmax_loss(self.output_mask_b_1, self.gt_b)
        self.softmax_loss_output = softmax_loss_output
        # Weight decay (1e-4 * tf.nn.l2_loss) over the u_B variables only.
        # NOTE(review): tf.add_n raises if no variable name contains '/u_B/'.
        l2_loss_output = tf.add_n([0.0001 * tf.nn.l2_loss(v) for v in tf.trainable_variables() if '/u_B/' in v.name])

        # Scalar weight fed on every step (see the lsgan_loss_p_scheduler
        # logic in train).
        self.lsgan_loss_p_weight = tf.placeholder(tf.float32, shape=[], name="lsgan_loss_p_weight")

        # Objective for the output branch (u_B variables).
        seg_loss_output = softmax_loss_output + l2_loss_output + self.lsgan_loss_p_weight * lsgan_loss_output1 + \
                        self.lsgan_loss_p_weight * lsgan_loss_output2 \
                        + self.lsgan_loss_p_weight * lsgan_loss_output3

        # Feature-level segmentation loss of pred_mask_fake_b vs. gt_a.
        # NOTE(review): this calls losses._softmax_losses (plural) while every
        # other term uses losses._softmax_loss — confirm the intended helper.
        softmax_loss_feature = losses._softmax_losses(self.pred_mask_fake_b, self.gt_a)
        # Weight decay over the s_B and e_B variables.
        l2_loss_feature = tf.add_n([0.0001 * tf.nn.l2_loss(v) for v in tf.trainable_variables() if '/s_B/' in v.name or '/e_B/' in v.name])

        # Generator objectives: both cycle terms, each direction's adversarial
        # term, and the output-mask consistency terms.
        g_loss_A = cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b +\
                   (lsgan_loss_output1 + lsgan_loss_output2 + lsgan_loss_output3) * self.lsgan_loss_p_weight + \
                   0.1 * per_loss_t + 10 * per_loss_t_recon + 10 * per_loss_s_recon
        g_loss_B = cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a + \
                   0.1 * per_loss_s + 10 * per_loss_t_recon + 10 * per_loss_s_recon

        # Objective for the e_B + s_B variables; includes 0.1 * g_loss_B.
        seg_loss_feature = softmax_loss_feature + l2_loss_feature + 0.1 * g_loss_B + self.lsgan_loss_p_weight * lsgan_loss_p + 0.1 * lsgan_loss_a_aux

        # Image discriminator A plus an auxiliary term in which the
        # cycle-reconstructed A plays the "real" role against pooled fakes.
        d_loss_A = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        d_loss_A_aux = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_cycle_a_aux_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_aux_is_real,
        )
        d_loss_A = d_loss_A + d_loss_A_aux
        d_loss_B = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )
        # Prediction-level discriminator: pred_mask_fake_b is "real" and
        # pred_mask_b is "fake".
        d_loss_P = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_pred_mask_fake_b_is_real,
            prob_fake_is_real=self.prob_pred_mask_b_is_real,
        )

        optimizer_gan = tf.train.AdamOptimizer(self.learning_rate_gan, beta1=0.5)
        optimizer_seg = tf.train.AdamOptimizer(self.learning_rate_seg)

        self.model_vars = tf.trainable_variables()

        # Partition the trainable variables by the scope substring in their names.
        d_A_vars = [var for var in self.model_vars if '/d_A/' in var.name]
        d_B_vars = [var for var in self.model_vars if '/d_B/' in var.name]
        g_A_vars = [var for var in self.model_vars if '/g_A/' in var.name]
        e_B_vars = [var for var in self.model_vars if '/e_B/' in var.name]
        de_B_vars = [var for var in self.model_vars if '/de_B/' in var.name]
        s_B_vars = [var for var in self.model_vars if '/s_B/' in var.name]
        d_P_vars = [var for var in self.model_vars if '/d_P/' in var.name]
        output_vars = [var for var in self.model_vars if '/u_B/' in var.name]
        d_U1_vars = [var for var in self.model_vars if '/d_output1/' in var.name]
        d_U2_vars = [var for var in self.model_vars if '/d_output2/' in var.name]
        d_U3_vars = [var for var in self.model_vars if '/d_output3/' in var.name]

        # One train op per module. Keep this order stable; see the NOTE in
        # the docstring.
        self.d_A_trainer = optimizer_gan.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer_gan.minimize(d_loss_B, var_list=d_B_vars)
        self.g_A_trainer = optimizer_gan.minimize(g_loss_A, var_list=g_A_vars) 
        self.g_B_trainer = optimizer_gan.minimize(g_loss_B, var_list=de_B_vars) 
        self.d_P_trainer = optimizer_gan.minimize(d_loss_P, var_list=d_P_vars)
        self.s_B_trainer = optimizer_seg.minimize(seg_loss_feature, var_list=e_B_vars + s_B_vars)
        self.output_trainer = optimizer_seg.minimize(seg_loss_output, var_list=output_vars)
        self.d_U1_trainer = optimizer_gan.minimize(d_loss_output1, var_list=d_U1_vars)
        self.d_U2_trainer = optimizer_gan.minimize(d_loss_output2, var_list=d_U2_vars)
        self.d_U3_trainer = optimizer_gan.minimize(d_loss_output3, var_list=d_U3_vars)


        # Debug output: list every trainable variable name.
        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard
        with tf.name_scope('feature'):
            self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
            self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
            self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
            self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
            self.softmax_feature_loss_summ = tf.summary.scalar('softmax_feature_loss', softmax_loss_feature)
            self.l2_feature_loss_summ = tf.summary.scalar("l2_feature_loss", l2_loss_feature)
            self.s_feature_loss_summ = tf.summary.scalar("s_feature_loss", seg_loss_feature)
            self.s_feature_loss_merge_summ = tf.summary.merge(
                [self.softmax_feature_loss_summ, self.l2_feature_loss_summ, self.s_feature_loss_summ])
            self.d_P_loss_summ = tf.summary.scalar("d_P_loss", d_loss_P)
            self.lsgan_loss_p_weight_summ = tf.summary.scalar("lsgan_loss_p_weight", self.lsgan_loss_p_weight)
            self.lr_gan_summ = tf.summary.scalar("lr_gan", self.learning_rate_gan)

        with tf.name_scope('output'):
            self.d_loss_output1_summ = tf.summary.scalar("d_loss_output1", d_loss_output1)
            self.d_loss_output2_summ = tf.summary.scalar("d_loss_output2", d_loss_output2)
            self.d_loss_output3_summ = tf.summary.scalar("d_loss_output3", d_loss_output3)
            self.softmax_output_loss_summ = tf.summary.scalar('softmax_loss_output', softmax_loss_output)
            self.seg_loss_output_summ = tf.summary.scalar("seg_loss_output", seg_loss_output)
            self.l2_loss_output_summ = tf.summary.scalar("l2_loss_output", l2_loss_output)
            self.output_loss_merge_summ = tf.summary.merge(
                [self.softmax_output_loss_summ, self.l2_loss_output_summ, self.seg_loss_output_summ])
            self.lr_seg_summ = tf.summary.scalar("lr_seg", self.learning_rate_seg)

        # Loss/accuracy summaries driven by host-computed values fed through
        # the *_output_* placeholders created in model_setup.
        with tf.name_scope('Loss_Acc'):
            self.train_loss_summ = tf.summary.scalar("train_loss", self.train_output_loss)
            self.train_acc_summ = tf.summary.scalar("train_acc", self.train_output_acc)
            self.val_loss_summ = tf.summary.scalar("val_loss", self.val_output_loss)
            self.val_acc_summ = tf.summary.scalar("val_acc", self.val_output_acc)

            self.train_loss_summ_epoch = tf.summary.scalar("train_loss_epoch", self.train_output_loss_epoch)
            self.train_acc_summ_epoch = tf.summary.scalar("train_acc_epoch", self.train_output_acc_epoch)
            self.val_loss_summ_epoch = tf.summary.scalar("val_loss_epoch", self.val_output_loss_epoch)
            self.val_acc_summ_epoch = tf.summary.scalar("val_acc_epoch", self.val_output_acc_epoch)


    def save_images(self, sess, epoch):
        """Translate batches A->B and save the first translated image of each.

        Runs ``self._num_imgs_to_val_save`` iterations. Each iteration pulls a
        batch from ``self.inputs`` and runs ``self.fake_images_b`` in
        inference mode (``is_training=False``, ``keep_rate=1``,
        ``rate_mafcn=0``). The first image of the resulting batch is written
        to ``self.images_test_dir``. An HTML index ``epoch_<epoch>.html``
        whose ``<img>`` links point at those files is written to
        ``self._output_dir``.

        Args:
            sess: active ``tf.Session`` holding the built graph.
            epoch: epoch number, used in the HTML and image file names.
        """
        if not os.path.exists(self.images_test_dir):
            os.makedirs(self.images_test_dir)
        # <img> links are resolved relative to the HTML file, which lives in
        # self._output_dir, so point them at the directory the images are
        # actually written to.
        rel_img_dir = os.path.relpath(self.images_test_dir, self._output_dir)

        names = ['fake_image_b']
        html_path = os.path.join(self._output_dir, 'epoch_' + str(epoch) + '.html')
        with open(html_path, 'w') as v_html:
            for i in range(0, self._num_imgs_to_val_save):
                images_i, images_j, gts_i, _ = sess.run(self.inputs)

                # Fetch the tensor itself (not wrapped in a list) so that
                # ``tensor[0]`` below selects the first image of the batch.
                fake_image_b = sess.run(
                    self.fake_images_b,
                    feed_dict={
                        self.input_a: images_i,
                        self.input_b: images_j,
                        self.gt_a: gts_i,
                        self.is_training: False,
                        self.keep_rate: 1,
                        self.rate_mafcn: 0,
                    })

                tensors = [fake_image_b]

                for name, tensor in zip(names, tensors):
                    image_name = name + str(epoch) + "_" + str(i) + ".tif"
                    image_path = os.path.join(self.images_test_dir, image_name)
                    if name in ('pred_mask', 'output_mask', 'label'):
                        cv2.imwrite(image_path, (tensor[0]).squeeze())
                    else:
                        # (x + 1) * 127.5 maps [-1, 1] to [0, 255]; assumes
                        # generator outputs lie in [-1, 1] — TODO confirm.
                        imsave(image_path, ((tensor[0] + 1) * 127.5).astype(np.uint8).squeeze())
                    v_html.write("<img src=\"" + os.path.join(rel_img_dir, image_name) + "\">")
                v_html.write("<br>")


    def save_images1(self, sess, epoch):
        """Save predicted and ground-truth class maps for domain-B batches.

        Runs ``self._num_imgs_test_to_save`` iterations. Each iteration pulls
        a batch from ``self.inputs`` and evaluates ``output_mask_b_1_argmax``
        (prediction) and ``gt_b_argmax`` (ground truth) in inference mode. The
        first map of each batch is written as a single-channel PNG to
        ``self.images_test_dir``, named
        ``output_mask_b_1_argmax_<i>.PNG`` and ``label_<i>.PNG``.

        NOTE(review): this reads ``self.inputs`` (the training pipeline built
        in ``train``) rather than ``self.inputs_val``. Confirm which split is
        intended.

        Args:
            sess: active ``tf.Session`` holding the built graph.
            epoch: kept for interface compatibility. No HTML index is written
                any more, because an empty ``epoch_<epoch>.html`` would
                clobber the index written by ``save_images``.
        """
        if not os.path.exists(self.images_test_dir):
            os.makedirs(self.images_test_dir)

        for i in range(self._num_imgs_test_to_save):
            print("Saving image {}/{}".format(i, self._num_imgs_test_to_save))
            _, images_j, _, gts_j = sess.run(self.inputs)
            output_mask_B_1_argmax, gt_B_argmax = sess.run(
                [self.output_mask_b_1_argmax, self.gt_b_argmax],
                feed_dict={
                    # Plain ASCII 'b' (previously written with a fullwidth
                    # U+FF42 character that only resolved via NFKC).
                    self.input_b: images_j,
                    self.gt_b: gts_j,
                    self.is_training: False,
                    self.keep_rate: 1,
                    self.rate_mafcn: 0,
                })

            for name, class_map in (('output_mask_b_1_argmax', output_mask_B_1_argmax),
                                    ('label', gt_B_argmax)):
                image_name = name + "_" + str(i) + ".PNG"
                # First map of the batch as an (H, W, 1) image. tf.argmax
                # yields int64, which OpenCV's PNG writer does not accept
                # directly, so cast explicitly (assumes num_cls <= 256).
                image = np.expand_dims(np.squeeze(class_map[0]), axis=-1).astype(np.uint8)
                cv2.imwrite(os.path.join(self.images_test_dir, image_name), image)


    def fake_image_pool(self, num_fakes, fake, fake_pool):
        if num_fakes < self._pool_size:
            fake_pool[num_fakes] = fake
            return fake
        else:
            p = random.random()
            if p > 0.5:
                random_id = random.randint(0, self._pool_size - 1)
                temp = fake_pool[random_id]
                fake_pool[random_id] = fake
                return temp
            else:
                return fake

    def train(self):

        # Load Dataset
        self.inputs = data_loader.load_data(self._source_train_pth, self._target_train_pth,self._batch_size, True)
        self.inputs_val = data_loader.load_data(self._source_val_pth, self._target_val_pth,self._batch_size, True)

        # Build the network
        self.model_setup()

        # Loss function calculations
        self.compute_losses() 

        # Initializing the global variables
        init = (tf.global_variables_initializer(),
                tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=70)

        with open(self._source_train_pth, 'r') as fp:
            rows_s = fp.readlines()
        with open(self._target_train_pth, 'r') as fp:
            rows_t = fp.readlines()
        with open(self._source_val_pth, 'r') as fp:
            rows_val_s = fp.readlines()
        with open(self._target_val_pth, 'r') as fp:
            rows_val_t = fp.readlines()

        max_images = max(len(rows_s), len(rows_t))

        max_val_images = len(rows_val_t)

        gpu_options = tf.GPUOptions(allow_growth=True)

        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            sess.run(init)


            # Restore the model to run the model from last checkpoint
            if self._to_restore:
                chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
                saver.restore(sess, chkpt_fname) 

            writer = tf.summary.FileWriter(self._output_dir + '/train')
            writer_val = tf.summary.FileWriter(self._output_dir + '/val')

            if not os.path.exists(self._output_dir):
                os.makedirs(self._output_dir)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            # Training Loop
            curr_lr_seg = 0.001
            curr_lr_output = 1e-3
            cnt = -1

            gts_i_2_i = np.zeros((self._batch_size, model.IMG_WIDTH // 2, model.IMG_HEIGHT // 2, self._num_cls))
            gts_i_3_i = np.zeros((self._batch_size, model.IMG_WIDTH // 4, model.IMG_HEIGHT // 4, self._num_cls))

            for epoch in range(sess.run(self.global_step), self._max_step):
                print("In the epoch ", epoch)
                if self._lr_gan_decay:
                    if epoch < (self._max_step/2):
                        curr_lr = self._base_lr
                    else:
                        curr_lr = self._base_lr - self._base_lr * (epoch - self._max_step/2) / (self._max_step/2)
                else:
                    curr_lr = self._base_lr

                if self._lsgan_loss_p_scheduler:
                    if epoch < 5:
                        lsgan_loss_p_weight_value = 0
                    elif epoch < 7:
                        lsgan_loss_p_weight_value = 0.1 * (epoch - 4.0) / (7.0 - 4.0)
                    else:
                        lsgan_loss_p_weight_value = 0.1
                else:
                    lsgan_loss_p_weight_value = 0.1

                if epoch > 0 and epoch % 2 == 0:
                    curr_lr_seg = np.multiply(curr_lr_seg, 0.9)
                    curr_lr_output = np.multiply(curr_lr_output, 0.9)

                max_inter = np.uint16(np.floor(max_images/self._batch_size))
                max_val_inter = np.uint16(np.floor(max_val_images/self._batch_size))

                Train_loss = 0
                Train_acc = 0
                Val_loss = 0
                Val_acc = 0
                cnt += 1

                for i in range(0, max_inter):
                    print("Processing batch {}/{}".format(i, max_inter))

                    images_i, images_j, gts_i, gts_j = sess.run(self.inputs)

                    batch_size, h, w = np.shape(gts_i)[0], np.shape(gts_i)[1], np.shape(gts_i)[2]

                    for j in range(batch_size):
                        gts_i_2_i[j] = cv2.resize(gts_i[j], (h // 2, w // 2), interpolation=cv2.INTER_NEAREST)
                        gts_i_3_i[j] = cv2.resize(gts_i[j], (h // 4, w // 4), interpolation=cv2.INTER_NEAREST)

                    inputs = {
                        'images_i': images_i,
                        'images_j': images_j,
                        'gts_i': gts_i,
                        'gts_i_2': gts_i_2_i,
                        'gts_i_3': gts_i_3_i,
                        'gts_j': gts_j,
                    }

                    # Optimizing the G_A network
                    _, fake_B_temp, summary_str = sess.run(
                        [self.g_A_trainer,
                         self.fake_images_b,
                         self.g_A_loss_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.gt_a:
                                inputs['gts_i'],
                            self.learning_rate_gan: curr_lr,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training:self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    fake_B_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_B_temp, self.fake_images_B)

                    # Optimizing the D_B network
                    _, summary_str = sess.run(
                        [self.d_B_trainer, self.d_B_loss_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.learning_rate_gan: curr_lr,
                            self.fake_pool_B: fake_B_temp1,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    # Optimizing the S_B network
                    _, summary_str = sess.run(
                        [self.s_B_trainer, self.s_feature_loss_merge_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.gt_a:
                                inputs['gts_i'],
                            self.learning_rate_seg: curr_lr_seg,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    # Optimizing the G_B network
                    _, fake_A_temp, summary_str = sess.run(
                        [self.g_B_trainer,
                         self.fake_images_a,
                         self.g_B_loss_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.learning_rate_gan: curr_lr,
                            self.gt_a: inputs['gts_i'],
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    fake_A_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_A_temp, self.fake_images_A)

                    # Optimizing the D_A network
                    _, summary_str = sess.run(
                        [self.d_A_trainer, self.d_A_loss_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.learning_rate_gan: curr_lr,
                            self.fake_pool_A: fake_A_temp1,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    # Optimizing the D_P network
                    _, summary_str = sess.run(
                        [self.d_P_trainer, self.d_P_loss_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.learning_rate_gan: curr_lr,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    # Optimizing the U_net network
                    # unet_loss_merge_summ
                    _, summary_str, train_loss, output_mask_fake_B_1_argmax, gt_A_argmax = sess.run(
                        [self.output_trainer, self.output_loss_merge_summ, self.softmax_loss_output, self.output_mask_fake_b_1_argmax, self.gt_a_argmax ],
                        feed_dict={
                            self.input_a: inputs['images_i'],
                            self.input_b: inputs['images_j'],
                            self.gt_a: inputs['gts_i'],
                            self.gt_a_2: inputs['gts_i_2'],
                            self.gt_a_3: inputs['gts_i_3'],
                            self.learning_rate_seg: curr_lr_seg,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)
                    train_acc = len(np.where(output_mask_fake_B_1_argmax == gt_A_argmax)[0]) / (model.IMG_WIDTH * model.IMG_HEIGHT * self._batch_size)
                    Train_loss += train_loss
                    Train_acc += train_acc

                    # Optimizing the d_U1 network
                    _, summary_str = sess.run(
                        [self.d_U1_trainer, self.d_loss_output1_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.learning_rate_gan: curr_lr,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    # Optimizing the d_U2 network
                    _, summary_str = sess.run(
                        [self.d_U2_trainer, self.d_loss_output2_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.learning_rate_gan: curr_lr,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    # Optimizing the d_U3 network
                    _, summary_str = sess.run(
                        [self.d_U3_trainer, self.d_loss_output3_summ],
                        feed_dict={
                            self.input_a:
                                inputs['images_i'],
                            self.input_b:
                                inputs['images_j'],
                            self.learning_rate_gan: curr_lr,
                            self.keep_rate: self._keep_rate_value,
                            self.rate_mafcn: self._rate_mafcn_value,
                            self.is_training: self._is_training_value,
                            self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_inter + i)

                    
                    summary_str_gan, summary_str_seg, summary_str_lossp , summary_str_t_l,  summary_str_t_a = \
                        sess.run([self.lr_gan_summ, self.lr_seg_summ, self.lsgan_loss_p_weight_summ, self.train_loss_summ, self.train_acc_summ],
                                 feed_dict={
                                    self.learning_rate_gan: curr_lr,
                                    self.learning_rate_seg: curr_lr_seg,
                                    self.lsgan_loss_p_weight: lsgan_loss_p_weight_value,
                                    self.train_output_loss: train_loss,
                                    self.train_output_acc: train_acc,
                             })

                    writer.add_summary(summary_str_gan, epoch * max_inter + i)
                    writer.add_summary(summary_str_seg, epoch * max_inter + i)
                    writer.add_summary(summary_str_lossp, epoch * max_inter + i)
                    writer.add_summary(summary_str_t_l, epoch * max_inter + i)
                    writer.add_summary(summary_str_t_a, epoch * max_inter + i)

                    writer.flush()
                    self.num_fake_inputs += 1


                print('train loss:{}'.format(Train_loss / max_inter))
                print('train acc:{}'.format(Train_acc / max_inter))

                # val
                for i in range(max_val_inter):
                    images_i_val, images_j_val, gts_i_val, gts_j_val = sess.run(self.inputs_val)
                    inputs_val = {
                        'images_i_val': images_i_val,
                        'images_j_val': images_j_val,
                        'gts_i_val': gts_i_val,
                        'gts_j_val': gts_j_val,
                    }

                    val_loss, output_mask_B_1_argmax, gt_B_argmax = sess.run(
                        [self.softmax_loss_output_val, self.output_mask_b_1_argmax, self.gt_b_argmax],
                        feed_dict={
                            self.input_a: inputs_val['images_i_val'],
                            self.input_b: inputs_val['images_j_val'],
                            self.gt_b: inputs_val['gts_j_val'],
                            self.is_training: False,
                            self.keep_rate: 1,
                            self.rate_mafcn: 0,
                        }
                    )
                    val_acc = len(np.where(output_mask_B_1_argmax == gt_B_argmax)[0]) / (model.IMG_WIDTH * model.IMG_HEIGHT * self._batch_size)
                    Val_loss += val_loss
                    Val_acc += val_acc

                    summary_str_v_l, summary_str_v_a = sess.run([self.val_loss_summ, self.val_acc_summ],
                                 feed_dict={
                                     self.val_output_loss: val_loss,
                                     self.val_output_acc: val_acc,
                             })
                    writer.add_summary(summary_str_v_l, epoch * max_val_inter + i)
                    writer.add_summary(summary_str_v_a, epoch * max_val_inter + i)

                    # 验证集指标

                print('val loss:{}'.format(Val_loss / max_val_inter))
                print('val acc:{}'.format(Val_acc / max_val_inter))

                summary_str_t_ll, summary_str_t_aa, summary_str_v_ll, summary_str_v_aa =\
                    sess.run([self.train_loss_summ_epoch, self.train_acc_summ_epoch, self.val_loss_summ_epoch, self.val_acc_summ_epoch],
                                                            feed_dict={
                                                                self.train_output_loss_epoch: Train_loss / max_inter,
                                                                self.train_output_acc_epoch: Train_acc / max_inter,
                                                                self.val_output_loss_epoch: Val_loss / max_val_inter,
                                                                self.val_output_acc_epoch: Val_acc / max_val_inter,
                                                            })
                writer.add_summary(summary_str_t_ll, epoch)
                writer.add_summary(summary_str_t_aa, epoch)
                writer.add_summary(summary_str_v_ll, epoch)
                writer.add_summary(summary_str_v_aa, epoch)

                saver.save(sess, os.path.join(
                    self._output_dir, "sifa_val_loss_" + str(Val_loss / max_val_inter)[: 6] + '_val_acc_' + str(Val_acc / max_val_inter)[: 6]), global_step=cnt)

                sess.run(tf.assign(self.global_step, epoch + 1))

            coord.request_stop()
            coord.join(threads)
            writer.add_graph(sess.graph)

    def test(self):
        """Restore the newest checkpoint and run ``save_images1`` on validation data.

        Loads the source/target validation inputs (unshuffled) through
        ``data_loader.load_data``, builds the graph with ``model_setup``,
        restores the latest checkpoint found in ``self._checkpoint_dir`` and
        calls ``self.save_images1`` with the session and the restored
        ``global_step`` value.

        Raises:
            ValueError: if ``self._checkpoint_dir`` contains no checkpoint.
        """
        print('Test the results')
        self.inputs = data_loader.load_data(self._source_val_pth, self._target_val_pth,
                                            self._batch_size, do_shuffle=False)
        self.model_setup()
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        gpu_options = tf.GPUOptions(allow_growth=True)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            sess.run(init)

            chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
            print(chkpt_fname)
            # latest_checkpoint returns None when no checkpoint exists; fail
            # with a message naming the directory instead of letting
            # saver.restore raise an opaque error on a None path.
            if chkpt_fname is None:
                raise ValueError(
                    'No checkpoint found in {}'.format(self._checkpoint_dir))
            saver.restore(sess, chkpt_fname)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                self.save_images1(sess, sess.run(self.global_step))
            finally:
                # Always stop the input queue-runner threads, even if saving
                # fails, so the session can be closed cleanly.
                coord.request_stop()
                coord.join(threads)


def main(config_filename='./config_param.json'):
    """Load the JSON config and either train or test a SIFA model.

    Args:
        config_filename: path to the JSON configuration file. Defaults to
            ``./config_param.json``, the path used by the script entry point.

    The mode is chosen from ``config["is_training_value"]``: a value greater
    than 0 runs ``SIFA.train``; anything else runs ``SIFA.test``.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which breaks on non-ASCII content (e.g. paths) elsewhere.
    with open(config_filename, encoding='utf-8') as config_file:
        config = json.load(config_file)

    sifa_model = SIFA(config)
    is_training = config["is_training_value"]
    print(is_training)
    if is_training > 0:
        sifa_model.train()
    else:
        sifa_model.test()


if __name__ == '__main__':
    # Script entry point: the config file is read from the working directory.
    main('./config_param.json')
