import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
from keras.models import *
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose, Concatenate, Activation
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.preprocessing.image import array_to_img, img_to_array, load_img
from keras.optimizers import Adam
from keras.initializers import Constant
from keras import backend as K
import keras as k
import glob, cv2, imageio
# import data
from Conv import GroupedConv2D, MixNetConvInitializer
import warnings
warnings.filterwarnings('ignore')


def Iou(y_true, y_pred, threshold=0.5):
    """Binary intersection-over-union metric for Keras.

    ``y_pred`` is binarised with ``y_pred > threshold``. ``y_true`` is only
    cast to ``K.floatx()`` (not thresholded), so it is expected to already
    hold 0/1 values.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted probability tensor, same shape as ``y_true``.
        threshold: cut-off above which a prediction counts as positive.
            Defaults to 0.5, the value previously hard-coded.

    Returns:
        Scalar tensor |A AND B| / |A OR B|, with the sums taken over every
        element of the batch, or 1.0 when the union is empty (both masks
        contain no positives).
    """
    # Cast the target once and use the same float tensor for both the
    # intersection and the union; previously only the intersection used the
    # cast value, so a non-float target mixed dtypes in the union sum.
    y_true = K.cast(y_true, K.floatx())
    y_pred = K.cast(K.greater(y_pred, threshold), K.floatx())

    # |intersection| (logical AND) of the two binary masks.
    intersection = K.sum(y_true * y_pred)
    # |union| (logical OR) via inclusion-exclusion.
    union = K.sum(y_true) + K.sum(y_pred) - intersection

    # Empty prediction and empty target count as a perfect match; this also
    # avoids reporting the 0/0 division.
    return K.switch(K.equal(union, 0.0), 1.0, intersection / union)


class myUnet(object):
    """Builder for a U-Net variant with mixed-kernel grouped convolutions.

    The encoder and decoder interleave plain ``Conv2D`` layers with
    ``GroupedConv2D`` blocks from the local ``Conv`` module, each given
    several kernel sizes at once. Every decoder stage also emits a
    single-channel side logit; these are fused into a final prediction.
    """

    # Four 2x2 max-pools halve H and W four times and the decoder doubles them
    # back, so fixed spatial dims must be multiples of 2**4 for the skip
    # connections to line up.
    _DOWNSAMPLE_FACTOR = 16

    def __init__(self, img_rows=2048, img_cols=2048, img_channels=4):
        """Store the input geometry used by :meth:`funet_mix`.

        Args:
            img_rows: input height in pixels, or None for a variable height.
            img_cols: input width in pixels, or None for a variable width.
            img_channels: number of input channels. Defaults to 4, the value
                the network previously hard-coded.
        """
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.img_channels = img_channels

    @staticmethod
    def _mix_conv(filters, kernel_sizes, x):
        """Apply one ReLU ``GroupedConv2D`` with the given kernel sizes to ``x``."""
        return GroupedConv2D(filters, kernel_size=kernel_sizes, strides=[1, 1], activation='relu',
                             kernel_initializer=MixNetConvInitializer(), padding='same',
                             use_bias=True)(x)

    @staticmethod
    def _up_merge(filters, x, skip):
        """Upsample ``x`` 2x with a transposed conv and concatenate ``skip`` before it."""
        up = Conv2DTranspose(filters, kernel_size=(4, 4), strides=(2, 2), activation='relu',
                             padding='same')(x)
        return Concatenate(axis=3)([skip, up])

    @staticmethod
    def _side_logits(x, name=None):
        """Project features to a 1-channel logit map (3x3 32-ch ReLU conv, then bias-free 1x1)."""
        x = Conv2D(32, 3, activation='relu', padding='same')(x)
        return Conv2D(1, 1, activation=None, use_bias=False, padding='same', name=name)(x)

    def funet_mix(self):
        """Build and compile the network.

        Returns:
            A compiled ``Model`` mapping an ``(img_rows, img_cols, img_channels)``
            input to five sigmoid maps ``[o1, o2, o3, o4, ofuse]``. Each output
            is trained with binary cross-entropy, and ``ofuse`` also reports
            ``Iou``. ``o4`` and ``ofuse`` are at full input resolution;
            ``o1``, ``o2`` and ``o3`` are at 1/8, 1/4 and 1/2 resolution.

        Raises:
            ValueError: if a fixed ``img_rows`` / ``img_cols`` is not a
                multiple of 16, because the skip connections could not be
                concatenated.
        """
        for dim in (self.img_rows, self.img_cols):
            if dim is not None and dim % self._DOWNSAMPLE_FACTOR:
                raise ValueError(
                    'img_rows and img_cols must be multiples of %d, got %r x %r'
                    % (self._DOWNSAMPLE_FACTOR, self.img_rows, self.img_cols))

        mix = self._mix_conv
        inputs = Input((self.img_rows, self.img_cols, self.img_channels))

        # Encoder: four stages, each ending in a 2x2 max-pool.
        conv1 = Conv2D(48, 3, activation='relu', padding='same')(inputs)
        conv1 = mix(48, [3, 5, 7], conv1)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

        conv2 = Conv2D(96, 3, activation='relu', padding='same')(pool1)
        conv2 = mix(96, [3, 5, 7], conv2)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

        conv3 = Conv2D(192, 3, activation='relu', padding='same')(pool2)
        conv3 = mix(192, [3, 5, 7], conv3)
        conv3 = mix(192, [3, 5, 7], conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

        conv4 = Conv2D(384, 3, activation='relu', padding='same')(pool3)
        conv4 = mix(384, [3, 5, 7], conv4)
        conv4 = mix(384, [3, 5, 7], conv4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

        # Bottleneck at 1/16 of the input resolution.
        conv5 = mix(576, [3, 5], pool4)
        conv5 = mix(576, [3, 5], conv5)
        conv5 = mix(576, [3, 5], conv5)

        # Decoder: upsample, merge the matching encoder skip, refine, and tap
        # a single-channel side logit at every scale.
        merge6 = self._up_merge(384, conv5, conv4)
        conv6 = mix(384, [3, 5], merge6)
        conv6 = mix(192, [3, 5], conv6)
        b1 = self._side_logits(conv6)  # 1/8 resolution

        merge7 = self._up_merge(192, conv6, conv3)
        conv7 = mix(192, [3, 5], merge7)
        conv7 = mix(96, [3, 5], conv7)
        b2 = self._side_logits(conv7)  # 1/4 resolution

        merge8 = self._up_merge(96, conv7, conv2)
        conv8 = mix(96, [3, 5], merge8)
        conv8 = mix(48, [3, 5], conv8)
        b3 = self._side_logits(conv8)  # 1/2 resolution

        merge9 = self._up_merge(48, conv8, conv1)
        conv9 = mix(48, [3, 5], merge9)
        b4 = self._side_logits(conv9, name='b4')  # full resolution

        # Fusion: bring the coarse logits up to full resolution and combine all
        # four with a bias-free 1x1 conv initialised to 0.25 (a plain average
        # at the start of training).
        ob1 = UpSampling2D(size=(8, 8), data_format=None)(b1)
        ob2 = UpSampling2D(size=(4, 4), data_format=None)(b2)
        ob3 = UpSampling2D(size=(2, 2), data_format=None)(b3)
        fuse = Concatenate(axis=-1)([ob1, ob2, ob3, b4])
        fuse = Conv2D(1, (1, 1), padding='same', use_bias=False, activation=None,
                      kernel_initializer=Constant(value=0.25))(fuse)

        # NOTE(review): o1-o3 apply the sigmoid to the *coarse* logits b1-b3,
        # not to the upsampled ob1-ob3, so their targets must be downscaled to
        # 1/8, 1/4 and 1/2 resolution. Confirm this matches the training data.
        o1 = Activation('sigmoid', name='o1')(b1)
        o2 = Activation('sigmoid', name='o2')(b2)
        o3 = Activation('sigmoid', name='o3')(b3)
        o4 = Activation('sigmoid', name='o4')(b4)
        ofuse = Activation('sigmoid', name='ofuse')(fuse)

        model = Model(inputs=[inputs], outputs=[o1, o2, o3, o4, ofuse])

        model.compile(loss={'o1': "binary_crossentropy",
                            'o2': "binary_crossentropy",
                            'o3': "binary_crossentropy",
                            'o4': "binary_crossentropy",
                            'ofuse': "binary_crossentropy",
                            },
                      metrics={'ofuse': Iou},
                      optimizer='adam')
        return model

if __name__ == '__main__':
    # Running the module directly only instantiates the builder with a
    # 2048 x 2048 input; the model itself is built by calling funet_mix().
    myunet = myUnet(img_rows=2048, img_cols=2048)
