import tensorflow as tf
import layers
import json
from tensorflow import keras

# Hyper-parameters shared by all network builders in this module. They are
# read once, at import time, from a JSON file resolved relative to the
# current working directory (not relative to this file).
with open('./config_param.json') as config_file:
    config = json.load(config_file)

# Batch size baked into the static output shapes of the transposed convolutions.
BATCH_SIZE = int(config['batch_size'])
# Not used in this part of the file; presumably the size of the fake-image
# history pool consumed by the training loop -- TODO confirm.
POOL_SIZE = int(config['pool_size'])

# The height of each image.
IMG_HEIGHT = int(config['img_height'])
# The width of each image.
IMG_WIDTH = int(config['img_width'])

# Number of output channels of the segmentation / feature-alignment heads
# (used as the class count by ma_fcn and feature_alignment).
num_cls = int(config['num_cls'])

# Base filter counts: ngf for the generators/decoder, ndf for the discriminators.
ngf = 32
ndf = 64


def get_outputs(inputs, skip=False, is_training=True, keep_rate=0.75, rate_mafcn=0.2):
    """Build the full model graph and return the tensors used by the losses.

    Every sub-network is created under the ``"Model"`` variable scope with
    ``reuse=tf.AUTO_REUSE``, so repeated builder calls with the same scope
    name (``"d_A"``, ``"d_B"``, ``"g_A"``, ``"e_B"``, ``"de_B"``, ``"s_B"``,
    ``"d_P"``, ``"d_output1"``..``"d_output3"`` and ``ma_fcn``'s default
    ``"u_B/ma_fcn"``) share a single set of weights.

    Translation paths built here:
      * A -> B: generator ``g_A`` maps ``images_a`` to ``fake_images_b``.
      * B -> A: encoder ``e_B`` followed by decoder ``de_B`` maps
        ``images_b`` to ``fake_images_a``.
      * Cycles: ``cycle_images_b = g_A(fake_images_a)`` and
        ``cycle_images_a = de_B(e_B(fake_images_b))``.

    Args:
        inputs: dict with keys ``'images_a'``, ``'images_b'``,
            ``'fake_pool_a'`` and ``'fake_pool_b'``. Presumably NHWC image
            batches of shape (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3) --
            TODO confirm against the caller.
        skip: forwarded to the generator, encoder and decoder builders. When
            True, the generator and decoder add their image argument to the
            pre-tanh output; the encoder accepts but ignores it.
        is_training: training flag forwarded to the encoder,
            ``feature_alignment`` and ``ma_fcn`` (used e.g. by
            ``layers.batch_norm_unet`` inside ``conv2d_bn_relu``).
        keep_rate: forwarded to the encoder and to ``feature_alignment``
            (which accepts but does not use it).
        rate_mafcn: dropout rate passed to every ``ma_fcn`` call.

    Returns:
        dict of named tensors: discriminator outputs for real / fake / pooled
        / cycled images, the translated and cycled images, feature-alignment
        mask predictions with their ``d_P`` scores, and the three ``ma_fcn``
        outputs with their ``d_output1``..``d_output3`` scores.
    """
    images_a = inputs['images_a']
    images_b = inputs['images_b']
    fake_pool_a = inputs['fake_pool_a']
    fake_pool_b = inputs['fake_pool_b']

    with tf.variable_scope("Model", reuse=tf.AUTO_REUSE) as scope: # AUTO_REUSE: same scope name -> shared weights

        # Local aliases for the builders used repeatedly below.
        current_discriminator = discriminator
        current_encoder = build_encoder
        current_decoder = build_decoder


        # The domain-A discriminator has two output heads (main, auxiliary).
        # NOTE(review): prob_real_a_aux is computed but never returned.
        prob_real_a_is_real, prob_real_a_aux = discriminator_aux(images_a, "d_A")
        prob_real_b_is_real = current_discriminator(images_b, "d_B")

        # A -> B via the ResNet generator; B -> A via encoder + decoder.
        fake_images_b = build_generator_resnet_9blocks(images_a, images_a, name='g_A',skip=skip)
        latent_b = current_encoder(images_b, name='e_B', skip=skip, is_training=is_training, keep_rate=keep_rate)

        fake_images_a = current_decoder(latent_b, images_b, name='de_B', skip=skip)

        # Mask prediction from the B-domain latent features.
        pred_mask_b = feature_alignment(latent_b, name='s_B', keep_rate=keep_rate, is_training=is_training)


        prob_fake_a_is_real, prob_fake_a_aux_is_real = discriminator_aux(fake_images_a, "d_A")
        prob_fake_b_is_real = current_discriminator(fake_images_b, "d_B")

        # Cycle paths: A -> B -> A through e_B/de_B and B -> A -> B through g_A.
        latent_fake_b = current_encoder(fake_images_b, 'e_B', skip=skip, is_training=is_training, keep_rate=keep_rate)
        cycle_images_b = build_generator_resnet_9blocks(fake_images_a, fake_images_a, 'g_A', skip=skip)

        cycle_images_a = current_decoder(latent_fake_b, fake_images_b, 'de_B', skip=skip)

        pred_mask_fake_b = feature_alignment(latent_fake_b, 's_B', keep_rate=keep_rate, is_training=is_training)


        # Discriminator scores for images drawn from the fake-image pools.
        prob_fake_pool_a_is_real, prob_fake_pool_a_aux_is_real = discriminator_aux(fake_pool_a, "d_A")
        prob_fake_pool_b_is_real = current_discriminator(fake_pool_b, "d_B")

        # NOTE(review): prob_cycle_a_is_real is computed but never returned.
        prob_cycle_a_is_real, prob_cycle_a_aux_is_real = discriminator_aux(cycle_images_a, "d_A")

        # Mask-level discriminator shared (scope "d_P") by both mask predictions.
        prob_pred_mask_fake_b_is_real = current_discriminator(pred_mask_fake_b, name="d_P")
        prob_pred_mask_b_is_real = current_discriminator(pred_mask_b, 'd_P')

        # All ma_fcn calls share the default "u_B/ma_fcn" scope, hence one network.
        output_mask_fake_b_1, output_mask_fake_b_2, output_mask_fake_b_3 = ma_fcn(fake_images_b, rate=rate_mafcn,  is_training=is_training)
        output_mask_b_1, output_mask_b_2, output_mask_b_3 = ma_fcn(images_b, rate=rate_mafcn,  is_training=is_training)
        output_mask_cycle_images_b_1, _, _ = ma_fcn(cycle_images_b, rate=rate_mafcn,  is_training=is_training)

        output_mask_fake_a_1, _, _ = ma_fcn(fake_images_a, rate=rate_mafcn,  is_training=is_training)
        output_mask_a_1, _, _ = ma_fcn(images_a, rate=rate_mafcn, is_training=is_training)
        output_mask_cycle_images_a_1, _, _ = ma_fcn(cycle_images_a, rate=rate_mafcn,  is_training=is_training)


        # One discriminator per ma_fcn output scale, each shared between the
        # translated (fake B) and real-B masks.
        prob_output_mask_fake_b_1_is_real = current_discriminator(output_mask_fake_b_1, name="d_output1")
        prob_output_mask_b_1_is_real = current_discriminator(output_mask_b_1, 'd_output1')

        prob_output_mask_fake_b_2_is_real = current_discriminator(output_mask_fake_b_2, name="d_output2")
        prob_output_mask_b_2_is_real = current_discriminator(output_mask_b_2, 'd_output2')


        prob_output_mask_fake_b_3_is_real = current_discriminator(output_mask_fake_b_3, name="d_output3")
        prob_output_mask_b_3_is_real = current_discriminator(output_mask_b_3, 'd_output3')


    # NOTE(review): only the B-side alignment head ("s_B") is built, so the
    # 'pred_mask_a' / 'pred_mask_fake_a' keys alias the B tensors -- confirm
    # this is intentional.
    return {
        'prob_real_a_is_real': prob_real_a_is_real,
        'prob_real_b_is_real': prob_real_b_is_real,
        'prob_fake_a_is_real': prob_fake_a_is_real,
        'prob_fake_b_is_real': prob_fake_b_is_real,
        'prob_fake_pool_a_is_real': prob_fake_pool_a_is_real,
        'prob_fake_pool_b_is_real': prob_fake_pool_b_is_real,
        'cycle_images_a': cycle_images_a,
        'cycle_images_b': cycle_images_b,
        'fake_images_a': fake_images_a,
        'fake_images_b': fake_images_b,
        'pred_mask_a': pred_mask_b,
        'pred_mask_b': pred_mask_b,
        'pred_mask_fake_a': pred_mask_fake_b,
        'pred_mask_fake_b': pred_mask_fake_b,
        'prob_pred_mask_fake_b_is_real': prob_pred_mask_fake_b_is_real,
        'prob_pred_mask_b_is_real': prob_pred_mask_b_is_real,
        'prob_fake_a_aux_is_real': prob_fake_a_aux_is_real,
        'prob_fake_pool_a_aux_is_real': prob_fake_pool_a_aux_is_real,
        'prob_cycle_a_aux_is_real': prob_cycle_a_aux_is_real,
        'output_mask_fake_b_1': output_mask_fake_b_1,
        'output_mask_fake_b_2': output_mask_fake_b_2,
        'output_mask_fake_b_3': output_mask_fake_b_3,
        'output_mask_b_1': output_mask_b_1,
        'output_mask_b_2': output_mask_b_2,
        'output_mask_b_3': output_mask_b_3,
        'prob_output_mask_fake_b_1_is_real': prob_output_mask_fake_b_1_is_real,
        'prob_output_mask_fake_b_2_is_real': prob_output_mask_fake_b_2_is_real,
        'prob_output_mask_fake_b_3_is_real': prob_output_mask_fake_b_3_is_real,
        'prob_output_mask_b_1_is_real': prob_output_mask_b_1_is_real,
        'prob_output_mask_b_2_is_real': prob_output_mask_b_2_is_real,
        'prob_output_mask_b_3_is_real': prob_output_mask_b_3_is_real,
        'output_mask_cycle_images_b_1': output_mask_cycle_images_b_1,
        'output_mask_fake_a_1': output_mask_fake_a_1,
        'output_mask_a_1': output_mask_a_1,
        'output_mask_cycle_images_a_1': output_mask_cycle_images_a_1,
    }


def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT", norm_type=None, is_training=True, keep_rate=0.75):
    """Residual block: two 3x3 ``layers.general_conv2d`` layers plus an identity shortcut.

    Each convolution is preceded by a 1-pixel spatial pad (pad mode
    ``padding``) and runs with "VALID" padding at stride 1. The second
    convolution is built with ``do_relu=False``; ReLU is applied after adding
    ``inputres`` back in. ``norm_type``, ``is_training`` and ``keep_rate`` are
    forwarded unchanged to both convolutions.
    """
    spatial_border = [[0, 0], [1, 1], [1, 1], [0, 0]]
    conv_options = dict(norm_type=norm_type, is_training=is_training, keep_rate=keep_rate)
    with tf.variable_scope(name):
        branch = layers.general_conv2d(tf.pad(inputres, spatial_border, padding), dim, 3, 3, 1, 1, 0.01, "VALID", "c1", **conv_options)
        branch = layers.general_conv2d(tf.pad(branch, spatial_border, padding), dim, 3, 3, 1, 1, 0.01, "VALID", "c2", do_relu=False, **conv_options)
        return tf.nn.relu(branch + inputres)

def build_resnet_block_ins(inputres, dim, name="resnet", padding="REFLECT"):
    """Residual block used by the ResNet generator (``norm_type='Ins'`` convs).

    Two 3x3 ``layers.general_conv2d_ga`` convolutions, each after a 1-pixel
    spatial pad in mode ``padding`` and run "VALID" at stride 1; the second
    has ``do_relu=False``. The result is ``relu(branch + inputres)``.
    """
    border = [[0, 0], [1, 1], [1, 1], [0, 0]]
    with tf.variable_scope(name):
        residual = tf.pad(inputres, border, padding)
        residual = layers.general_conv2d_ga(residual, dim, 3, 3, 1, 1, 0.02, "VALID", "c1", norm_type='Ins')
        residual = tf.pad(residual, border, padding)
        residual = layers.general_conv2d_ga(residual, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False, norm_type='Ins')
        summed = residual + inputres
        return tf.nn.relu(summed)



def build_resnet_block_ds(inputres, dim_in, dim_out, name="resnet", padding="REFLECT", norm_type=None, is_training=True, keep_rate=0.75):
    """Residual block that widens the channel count from ``dim_in`` to ``dim_out``.

    The residual branch is two 3x3 ``layers.general_conv2d`` layers producing
    ``dim_out`` channels, each after a 1-pixel spatial pad (mode ``padding``)
    and run "VALID" at stride 1; the second has ``do_relu=False``. The
    shortcut is widened to ``dim_out`` channels by padding the channel axis of
    ``inputres`` (same pad mode) so it can be added to the branch.

    The channel padding is split ``gap // 2`` before and ``gap - gap // 2``
    after (``gap = dim_out - dim_in``). The previous symmetric ``gap // 2``
    split produced one channel too few whenever ``gap`` was odd, making the
    residual addition fail; for even ``gap`` the result is unchanged.

    NOTE: with ``padding="REFLECT"`` tf.pad requires each channel pad to be
    smaller than ``dim_in``.

    Args:
        inputres: NHWC input tensor with ``dim_in`` channels.
        dim_in: channel count of ``inputres``.
        dim_out: channel count of the block output.
        name: variable scope name.
        padding: tf.pad mode for both spatial and channel padding.
        norm_type, is_training, keep_rate: forwarded to ``layers.general_conv2d``.

    Returns:
        ``relu(branch + padded_shortcut)`` with ``dim_out`` channels.
    """
    with tf.variable_scope(name):
        out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(out_res, dim_out, 3, 3, 1, 1, 0.01, "VALID", "c1", norm_type=norm_type, is_training=is_training, keep_rate=keep_rate)
        out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(out_res, dim_out, 3, 3, 1, 1, 0.01, "VALID", "c2", do_relu=False, norm_type=norm_type, is_training=is_training, keep_rate=keep_rate)

        # Split the channel gap so the shortcut ends up with exactly dim_out
        # channels even when dim_out - dim_in is odd.
        channel_gap = dim_out - dim_in
        pad_before = channel_gap // 2
        pad_after = channel_gap - pad_before
        inputres = tf.pad(inputres, [[0, 0], [0, 0], [0, 0], [pad_before, pad_after]], padding)

        return tf.nn.relu(out_res + inputres)


def build_drn_block(inputdrn, dim, name="drn", padding="REFLECT", norm_type=None, is_training=True, keep_rate=0.75):
    """Dilated residual block: two ``layers.dilate_conv2d`` layers plus identity shortcut.

    Each convolution (3x3, ``dim`` -> ``dim`` channels, with the argument
    ``2`` -- presumably the dilation rate) is preceded by a 2-pixel spatial
    pad in mode ``padding`` and run "VALID"; the second has ``do_relu=False``.
    ReLU is applied after adding ``inputdrn``.
    """
    halo = [[0, 0], [2, 2], [2, 2], [0, 0]]
    norm_kwargs = {'norm_type': norm_type, 'is_training': is_training, 'keep_rate': keep_rate}
    with tf.variable_scope(name):
        x = tf.pad(inputdrn, halo, padding)
        x = layers.dilate_conv2d(x, dim, dim, 3, 3, 2, 0.01, "VALID", "c1", **norm_kwargs)
        x = tf.pad(x, halo, padding)
        x = layers.dilate_conv2d(x, dim, dim, 3, 3, 2, 0.01, "VALID", "c2", do_relu=False, **norm_kwargs)
        return tf.nn.relu(x + inputdrn)



def build_drn_block_ds(inputdrn, dim_in, dim_out, name='drn_ds', padding="REFLECT", norm_type=None, is_training=True, keep_rate=0.75):
    """Dilated residual block that widens channels from ``dim_in`` to ``dim_out``.

    The branch is two ``layers.dilate_conv2d`` layers (3x3, with the argument
    ``2`` -- presumably the dilation rate), each after a 2-pixel spatial pad
    in mode ``padding`` and run "VALID"; the second has ``do_relu=False``.
    The shortcut's channel axis is padded (same pad mode) up to ``dim_out``.

    The channel padding is split ``gap // 2`` before and ``gap - gap // 2``
    after (``gap = dim_out - dim_in``), so an odd ``gap`` still yields exactly
    ``dim_out`` channels; the old symmetric ``gap // 2`` split came up one
    channel short and broke the residual addition. Even gaps are unchanged.

    Args:
        inputdrn: NHWC input tensor with ``dim_in`` channels.
        dim_in: channel count of ``inputdrn``.
        dim_out: channel count of the block output.
        name: variable scope name.
        padding: tf.pad mode for spatial and channel padding.
        norm_type, is_training, keep_rate: forwarded to ``layers.dilate_conv2d``.

    Returns:
        ``relu(branch + padded_shortcut)`` with ``dim_out`` channels.
    """
    with tf.variable_scope(name):
        out_drn = tf.pad(inputdrn, [[0,0], [2,2], [2,2], [0,0]], padding)
        out_drn = layers.dilate_conv2d(out_drn, dim_in, dim_out, 3, 3, 2, 0.01, 'VALID', "c1", norm_type=norm_type, is_training=is_training, keep_rate=keep_rate)
        out_drn = tf.pad(out_drn, [[0,0], [2,2], [2,2], [0,0]], padding)
        out_drn = layers.dilate_conv2d(out_drn, dim_out, dim_out, 3, 3, 2, 0.01, 'VALID', "c2", do_relu=False, norm_type=norm_type, is_training=is_training, keep_rate=keep_rate)

        # Asymmetric split keeps the shortcut at exactly dim_out channels.
        channel_gap = dim_out - dim_in
        pad_before = channel_gap // 2
        pad_after = channel_gap - pad_before
        inputdrn = tf.pad(inputdrn, [[0, 0], [0, 0], [0, 0], [pad_before, pad_after]], padding)

        return tf.nn.relu(out_drn + inputdrn)


def build_generator_resnet_9blocks(inputgen, inputimg, name="generator", skip=False):
    """ResNet image-to-image generator with nine residual blocks.

    Layout:
      * a 3-pixel pad (mode "CONSTANT") and a 7x7 conv to ``ngf`` channels;
      * two stride-2 3x3 convs (``ngf*2``, ``ngf*4``);
      * nine ``build_resnet_block_ins`` blocks;
      * two stride-2 transposed convs back up (``ngf*2``, ``ngf``);
      * a 7x7 conv to 3 channels followed by tanh.

    When ``skip`` is exactly ``True``, ``inputimg`` is added before the tanh.

    The transposed convolutions need static output shapes. These used to be
    hard-coded to 128 and 256, which is valid only for 256x256 images even
    though ``IMG_HEIGHT``/``IMG_WIDTH`` are read from the config. They are
    now derived from the configured size. A stride-2 "SAME" conv maps size
    ``s`` to ``ceil(s / 2)``, so the intermediate size is ``ceil(IMG / 2)``
    and the final size is the full image size. Both values are identical to
    the old ones for 256x256.

    Args:
        inputgen: NHWC tensor fed through the network; assumed to be
            IMG_HEIGHT x IMG_WIDTH -- TODO confirm against the caller.
        inputimg: tensor added to the pre-tanh output when ``skip`` is True.
        name: variable scope name.
        skip: enable the additive skip connection.

    Returns:
        3-channel tensor with values in [-1, 1] (tanh output).
    """
    with tf.variable_scope(name):
        f = 7
        ks = 3
        padding = "CONSTANT"
        # ceil(size / 2): spatial size after one stride-2 "SAME" convolution.
        half_height = -(-IMG_HEIGHT // 2)
        half_width = -(-IMG_WIDTH // 2)

        pad_input = tf.pad(inputgen, [[0, 0], [ks, ks], [ks, ks], [0, 0]], padding)
        o_c1 = layers.general_conv2d_ga(pad_input, ngf, f, f, 1, 1, 0.02, name="c1", norm_type='Ins')
        o_c2 = layers.general_conv2d_ga(o_c1, ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2", norm_type='Ins')
        o_c3 = layers.general_conv2d_ga(o_c2, ngf * 4, ks, ks, 2, 2, 0.02, "SAME", "c3", norm_type='Ins')

        # Residual blocks "r1".."r9", built in order.
        o_r = o_c3
        for idx in range(1, 10):
            o_r = build_resnet_block_ins(o_r, ngf * 4, "r%d" % idx, padding)

        o_c4 = layers.general_deconv2d(o_r, [BATCH_SIZE, half_height, half_width, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c4", norm_type='Ins')
        o_c5 = layers.general_deconv2d(o_c4, [BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, ngf], ngf, ks, ks, 2, 2, 0.02, "SAME", "c5", norm_type='Ins')
        o_c6 = layers.general_conv2d_ga(o_c5, 3, f, f, 1, 1, 0.02, "SAME", "c6", do_norm=False, do_relu=False)

        if skip is True:
            out_gen = tf.nn.tanh(inputimg + o_c6, "t1")
        else:
            out_gen = tf.nn.tanh(o_c6, "t1")

        return out_gen

def build_generator_resnet_9blocks1(inputgen, inputimg, name="generator"):
    """ResNet generator identical to ``build_generator_resnet_9blocks`` without the skip.

    This function used to duplicate the whole layer stack of
    ``build_generator_resnet_9blocks`` and always take the no-skip branch.
    Delegating with ``skip=False`` keeps the two from drifting apart. The
    resulting graph is the same, because the layers are created inside the
    same ``name`` variable scope with the same layer names and in the same
    order.

    Args:
        inputgen: NHWC tensor fed through the generator.
        inputimg: accepted for signature compatibility; unused because the
            skip connection is disabled.
        name: variable scope name.

    Returns:
        3-channel tensor with values in [-1, 1] (tanh output).
    """
    return build_generator_resnet_9blocks(inputgen, inputimg, name=name, skip=False)


def build_encoder(inputen, name='encoder', skip=False, is_training=True, keep_rate=0.75):
    """Encode an image batch into a 512-channel feature map downsampled 8x.

    The network is built in this order:
      * a 7x7 stem conv (16 channels);
      * residual blocks widening 16 -> 32 -> 64 -> 128 -> 256 -> 512
        channels, with a 2x2 / stride-2 "SAME" max-pool after r1, r2 and r4;
      * two dilated residual blocks (d1, d2);
      * two 3x3 convs (c2, c3).

    Every layer uses ``norm_type='Batch'`` with the given ``is_training`` and
    ``keep_rate``. ``skip`` is accepted for signature symmetry with the
    other builders, but it is not used.
    """
    with tf.variable_scope(name):
        width = 16
        kernel = 3
        pad_mode = "CONSTANT"
        bn = dict(norm_type='Batch', is_training=is_training, keep_rate=keep_rate)

        def halve(tensor):
            # 2x2 max-pool, stride 2, "SAME" padding.
            return tf.nn.max_pool(tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

        net = layers.general_conv2d(inputen, width, 7, 7, 1, 1, 0.01, 'SAME', name="c1", **bn)
        net = halve(build_resnet_block(net, width, "r1", pad_mode, **bn))

        net = halve(build_resnet_block_ds(net, width, width * 2, "r2", pad_mode, **bn))

        net = build_resnet_block_ds(net, width * 2, width * 4, 'r3', pad_mode, **bn)
        net = halve(build_resnet_block(net, width * 4, 'r4', pad_mode, **bn))

        net = build_resnet_block_ds(net, width * 4, width * 8, 'r5', pad_mode, **bn)
        net = build_resnet_block(net, width * 8, 'r6', pad_mode, **bn)

        net = build_resnet_block_ds(net, width * 8, width * 16, 'r7', pad_mode, **bn)
        for block_name in ('r8', 'r9', 'r10'):
            net = build_resnet_block(net, width * 16, block_name, pad_mode, **bn)

        net = build_resnet_block_ds(net, width * 16, width * 32, 'r11', pad_mode, **bn)
        net = build_resnet_block(net, width * 32, 'r12', pad_mode, **bn)

        for block_name in ('d1', 'd2'):
            net = build_drn_block(net, width * 32, block_name, pad_mode, **bn)

        for conv_name in ('c2', 'c3'):
            net = layers.general_conv2d(net, width * 32, kernel, kernel, 1, 1, 0.01, 'SAME', conv_name, **bn)

        return net


def build_decoder(inputde, inputimg, name='decoder', skip=False):
    """Decode an encoder feature map back to a 3-channel image in [-1, 1].

    The network is built in this order:
      * a 3x3 conv to ``ngf*4`` channels;
      * four ``build_resnet_block`` blocks (``norm_type='Ins'``);
      * three stride-2 transposed convs;
      * a 7x7 conv to 3 channels followed by tanh.

    When ``skip`` is exactly ``True``, ``inputimg`` is added before the tanh.

    The transposed-conv output shapes used to be hard-coded to 64/128/256,
    which is valid only for 256x256 images. They are now derived from
    ``IMG_HEIGHT``/``IMG_WIDTH`` as ``ceil(IMG / 4)``, ``ceil(IMG / 2)`` and
    ``IMG``. These are the inverse of the encoder's three stride-2 "SAME"
    max-pools, which map ``s`` to ``ceil(s / 2)``. For 256x256 the values are
    identical to the old ones.

    Args:
        inputde: encoder feature map (assumed ceil(IMG/8) spatially --
            TODO confirm against the caller).
        inputimg: tensor added to the pre-tanh output when ``skip`` is True.
        name: variable scope name.
        skip: enable the additive skip connection.

    Returns:
        3-channel tensor with values in [-1, 1] (tanh output).
    """
    with tf.variable_scope(name):
        f = 7
        ks = 3
        padding = "CONSTANT"
        # Ceil-division sizes matching the encoder's "SAME" stride-2 pools.
        quarter_height, quarter_width = -(-IMG_HEIGHT // 4), -(-IMG_WIDTH // 4)
        half_height, half_width = -(-IMG_HEIGHT // 2), -(-IMG_WIDTH // 2)

        o_c1 = layers.general_conv2d(inputde, ngf * 4, ks, ks, 1, 1, 0.02, "SAME", "c1", norm_type='Ins')
        # Residual blocks "r1".."r4", built in order.
        o_r = o_c1
        for idx in range(1, 5):
            o_r = build_resnet_block(o_r, ngf * 4, "r%d" % idx, padding, norm_type='Ins')
        o_c3 = layers.general_deconv2d(o_r, [BATCH_SIZE, quarter_height, quarter_width, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c3", norm_type='Ins')
        o_c4 = layers.general_deconv2d(o_c3, [BATCH_SIZE, half_height, half_width, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c4", norm_type='Ins')
        o_c5 = layers.general_deconv2d(o_c4, [BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, ngf], ngf, ks, ks, 2, 2, 0.02, "SAME", "c5", norm_type='Ins')
        o_c6 = layers.general_conv2d(o_c5, 3, f, f, 1, 1, 0.02, "SAME", "c6", do_norm=False, do_relu=False)

        if skip is True:
            out_gen = tf.nn.tanh(inputimg + o_c6, "t1")
        else:
            out_gen = tf.nn.tanh(o_c6, "t1")

        return out_gen

def feature_alignment(inputse, name='f_alignment', keep_rate=0.75, is_training=False):
    """Upsample a feature map 8x into a ``num_cls``-channel prediction map.

    Three stages run with 128, 64 and 32 channels. Each stage does:
      * 2x ``keras.layers.UpSampling2D``;
      * a 2x2 ``conv2d_relu``;
      * two 3x3 ``conv2d_bn_relu`` layers.

    A final 1x1 "VALID" conv with no activation produces ``num_cls``
    channels, presumably unnormalised class scores. ``keep_rate`` is accepted
    but not used.
    """
    # (channels, (upsample-conv scope, first bn-conv scope, second bn-conv scope))
    stages = (
        (128, ('c7_1', 'c7_2', 'c7_3')),
        (64, ('c8_1', 'c8_2', 'c8_3')),
        (32, ('c9_1', 'c9_2', 'c9_3')),
    )
    with tf.variable_scope(name):
        net = inputse
        for channels, (up_scope, bn_scope_a, bn_scope_b) in stages:
            upsampled = keras.layers.UpSampling2D(size=(2, 2))(net)
            net = conv2d_relu(upsampled, channels, 2, name=up_scope)
            net = conv2d_bn_relu(net, channels, 3, name=bn_scope_a, is_training=is_training)
            net = conv2d_bn_relu(net, channels, 3, name=bn_scope_b, is_training=is_training)
        return conv2d_relu(net, num_cls, 1, name='c9_4', padding='VALID', is_relu=False)


def discriminator(inputdisc, name="discriminator"):
    """Fully convolutional discriminator returning a 1-channel spatial score map.

    Five 4x4 ``layers.general_conv2d`` layers, each preceded by a 2-pixel
    constant pad and run "VALID":
      * c1-c3: stride 2, ``ndf``, ``ndf*2`` and ``ndf*4`` channels;
      * c4: stride 1, ``ndf*8`` channels;
      * c5: stride 1, 1 channel.

    c1-c4 use ``relufactor=0.2`` and ``norm_type='Ins'``; c1 has
    ``do_norm=False``. The last layer has neither normalisation nor
    activation.
    """
    kernel = 4
    pad = 2
    border = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
    # (output channels, stride, layer scope, extra keyword arguments)
    plan = (
        (ndf, 2, "c1", dict(do_norm=False, relufactor=0.2, norm_type='Ins')),
        (ndf * 2, 2, "c2", dict(relufactor=0.2, norm_type='Ins')),
        (ndf * 4, 2, "c3", dict(relufactor=0.2, norm_type='Ins')),
        (ndf * 8, 1, "c4", dict(relufactor=0.2, norm_type='Ins')),
        (1, 1, "c5", dict(do_norm=False, do_relu=False)),
    )
    with tf.variable_scope(name):
        net = inputdisc
        for channels, stride, layer_scope, extra in plan:
            padded = tf.pad(net, border, "CONSTANT")
            net = layers.general_conv2d(padded, channels, kernel, kernel, stride, stride, 0.02, "VALID", layer_scope, **extra)
        return net


def discriminator_aux(inputdisc, name="discriminator"):
    """Two-head variant of ``discriminator``.

    The layer stack is the same as ``discriminator`` (five 4x4 convs, each
    after a 2-pixel constant pad), except that the final layer has 2
    channels. Each channel is returned as a separate single-channel map.

    Returns:
        (main_scores, aux_scores): channel 0 and channel 1 of the last layer,
        each with a trailing singleton channel axis restored.
    """
    kernel = 4
    pad = 2
    border = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
    # (output channels, stride, layer scope, extra keyword arguments)
    plan = (
        (ndf, 2, "c1", dict(do_norm=False, relufactor=0.2, norm_type='Ins')),
        (ndf * 2, 2, "c2", dict(relufactor=0.2, norm_type='Ins')),
        (ndf * 4, 2, "c3", dict(relufactor=0.2, norm_type='Ins')),
        (ndf * 8, 1, "c4", dict(relufactor=0.2, norm_type='Ins')),
        (2, 1, "c5", dict(do_norm=False, do_relu=False)),
    )
    with tf.variable_scope(name):
        scores = inputdisc
        for channels, stride, layer_scope, extra in plan:
            padded = tf.pad(scores, border, "CONSTANT")
            scores = layers.general_conv2d(padded, channels, kernel, kernel, stride, stride, 0.02, "VALID", layer_scope, **extra)

        main_head = tf.expand_dims(scores[..., 0], axis=3)
        aux_head = tf.expand_dims(scores[..., 1], axis=3)
        return main_head, aux_head


def ma_fcn(inputdisc, name="u_B/ma_fcn", rate=0.2, is_training=False):
    """U-Net-style segmentation network with three softmax outputs.

    Encoder: five conv stages (32, 64, 128, 256, 512 channels), separated by
    2x2 / stride-2 max-pools, with dropout after each stage.

    Decoder: 2x upsampling, a 2x2 conv, concatenation with the matching
    encoder stage, and two 3x3 convs. Softmax heads are taken at three
    decoder depths:
      * ``output1`` after the ``conv1`` skip (input resolution);
      * ``output2`` after the ``conv2`` skip (1/2 resolution);
      * ``output3`` after the ``conv3`` skip (1/4 resolution).

    This assumes the spatial size is divisible by 16, so that the pools and
    upsamplings line up -- TODO confirm.

    Dropout is now called with ``training=is_training``. Without it, Keras
    Dropout uses the global Keras learning phase, which is not tied to
    ``is_training`` and defaults to inference in TF1 graph mode, so dropout
    was effectively never applied while training.

    Args:
        inputdisc: NHWC input image batch.
        name: variable scope name.
        rate: fraction of units dropped by each Dropout layer.
        is_training: Python bool or boolean tensor; forwarded to the
            batch-norm convs and used as the Dropout ``training`` flag.

    Returns:
        (output1, output2, output3), each a softmax over ``num_cls`` channels.
    """
    with tf.variable_scope(name):
        f = 3
        class_num = num_cls

        conv1 = conv2d_relu(inputdisc, 32, f, name='c1_1', is_training=is_training)
        conv1 = conv2d_bn_relu(conv1, 32, f, name='c1_2', is_training=is_training)
        conv1 = keras.layers.Dropout(rate)(conv1, training=is_training)
        pool1 = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(conv1)

        conv2 = conv2d_relu(pool1, 64, f, name='c2_1', is_training=is_training)
        conv2 = conv2d_bn_relu(conv2, 64, f, name='c2_2', is_training=is_training)
        conv2 = keras.layers.Dropout(rate)(conv2, training=is_training)
        pool2 = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(conv2)

        conv3 = conv2d_relu(pool2, 128, f, name='c3_1', is_training=is_training)
        conv3 = conv2d_bn_relu(conv3, 128, f, name='c3_2', is_training=is_training)
        conv3 = conv2d_bn_relu(conv3, 128, f, name='c3_3', is_training=is_training)
        conv3 = keras.layers.Dropout(rate)(conv3, training=is_training)
        pool3 = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(conv3)

        conv4 = conv2d_relu(pool3, 256, f, name='c4_1', is_training=is_training)
        conv4 = conv2d_bn_relu(conv4, 256, f, name='c4_2', is_training=is_training)
        conv4 = conv2d_bn_relu(conv4, 256, f, name='c4_3', is_training=is_training)
        conv4 = keras.layers.Dropout(rate)(conv4, training=is_training)
        pool4 = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(conv4)

        conv5 = conv2d_relu(pool4, 512, f, name='c5_1', is_training=is_training)
        conv5 = conv2d_bn_relu(conv5, 512, f, name='c5_2', is_training=is_training)
        conv5 = conv2d_bn_relu(conv5, 512, f, name='c5_3', is_training=is_training)
        conv5 = keras.layers.Dropout(rate)(conv5, training=is_training)

        # Decoder: upsample, fuse with the matching encoder stage, refine.
        up6 = conv2d_relu(keras.layers.UpSampling2D(size=(2, 2))(conv5), 256, 2, name='c6_1')
        merge6 = keras.layers.concatenate(([conv4, up6]), axis=-1)
        conv6 = conv2d_relu(merge6, 256, f, name='c6_2', is_training=is_training)
        conv6 = conv2d_relu(conv6, 256, f, name='c6_3', is_training=is_training)
        conv6 = keras.layers.Dropout(rate)(conv6, training=is_training)

        up7 = conv2d_relu((keras.layers.UpSampling2D(size=(2, 2))(conv6)), 128, 2, name='c7_1')
        merge7 = keras.layers.concatenate(([conv3, up7]), axis=-1)
        conv7 = conv2d_relu(merge7, 128, f, name='c7_2', is_training=is_training)
        conv7 = conv2d_relu(conv7, 128, f, name='c7_3', is_training=is_training)
        conv7 = keras.layers.Dropout(rate)(conv7, training=is_training)

        # Coarsest head (1/4 of input resolution).
        out3_1 = conv2d_relu(conv7, 64, 1, name='out3_1')
        out3_2 = conv2d_relu(out3_1, class_num, 1, name='out3_2', padding='VALID', is_relu=False)
        output3 = keras.activations.softmax(out3_2, axis=-1)

        up8 = conv2d_relu((keras.layers.UpSampling2D(size=(2, 2))(conv7)), 64, 2, name='c8_1')
        merge8 = keras.layers.concatenate(([conv2, up8]), axis=-1)
        conv8 = conv2d_relu(merge8, 64, f, name='c8_2', is_training=is_training)
        conv8 = conv2d_relu(conv8, 64, f, name='c8_3', is_training=is_training)

        # Middle head (1/2 of input resolution).
        out2_1 = conv2d_relu(conv8, 32, 1, name='out2_1')
        out2_2 = conv2d_relu(out2_1, class_num, 1, name='out2_2', padding='VALID', is_relu=False)
        output2 = keras.activations.softmax(out2_2, axis=-1)


        up9 = conv2d_relu((keras.layers.UpSampling2D(size=(2, 2))(conv8)), 32, 2, name='c9_1')
        merge9 = keras.layers.concatenate(([conv1, up9]), axis=-1)
        conv9 = conv2d_relu(merge9, 32, f, name='c9_2', is_training=is_training)
        conv9 = conv2d_relu(conv9, 32, f, name='c9_3', is_training=is_training)
        # Full-resolution head.
        out1_1 = conv2d_relu(conv9, class_num, 1, name='c9_4', padding='VALID', is_relu=False)
        output1 = keras.activations.softmax(out1_1, axis=-1)


        return output1, output2, output3


def conv2d_relu(input, out_ch, k, name='conv2d_relu', padding='SAME', is_relu=True, is_training=False):
    """Bias-free 2-D convolution with Xavier-initialised weights, optionally ReLU'd.

    Args:
        input: NHWC input tensor.
        out_ch: number of output channels.
        k: kernel size passed to ``tf.contrib.layers.conv2d``.
        name: variable scope name.
        padding: convolution padding ("SAME" or "VALID").
        is_relu: apply ``tf.nn.relu`` to the convolution output.
        is_training: accepted for call-site symmetry with ``conv2d_bn_relu``;
            has no effect here.
    """
    with tf.variable_scope(name):
        output = tf.contrib.layers.conv2d(
            input,
            out_ch,
            k,
            padding=padding,
            activation_fn=None,
            weights_initializer=tf.contrib.layers.xavier_initializer(),
            biases_initializer=None
        )
        if not is_relu:
            return output
        return tf.nn.relu(output, "relu")


def conv2d_bn_relu(input, out_ch, k, name='conv2d_bn_relu', is_training=False):
    """Bias-free "SAME" convolution, then ``layers.batch_norm_unet``, then ReLU.

    Args:
        input: NHWC input tensor.
        out_ch: number of output channels.
        k: kernel size passed to ``tf.contrib.layers.conv2d``.
        name: variable scope name.
        is_training: forwarded to ``layers.batch_norm_unet``.
    """
    with tf.variable_scope(name):
        linear = tf.contrib.layers.conv2d(
            input, out_ch, k,
            padding='SAME',
            activation_fn=None,
            weights_initializer=tf.contrib.layers.xavier_initializer(),
            biases_initializer=None
        )
        normalised = layers.batch_norm_unet(linear, is_training)
        activated = tf.nn.relu(normalised, "relu")

    return activated
