# -*-coding:utf-8-*-
# pylint: disable=E1101,R,C
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
import gzip
import pickle
import numpy as np
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import math
import cv2
import os
import shutil
from compact_bilinear_pooling import CountSketch, CompactBilinearPooling
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from cbam import ChannelGate,SpatialGate

# Training hyperparameters. NUM_EPOCHS and LEARNING_RATE are not referenced in
# this script; BATCH_SIZE is the batch size of the test (gallery) loader.
NUM_EPOCHS = 20
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
# Expected number of gallery (test) images. The commented-out values are
# presumably for other dataset configurations.
# count = 2800
count = 1005
# count = 2100
# Descriptor length produced by ResNet.forward (its compact bilinear pooling
# output_size is 8192).
outputN = 8192


# Checkpoint file whose matching entries main() loads into the model.
pretrained_model='/net_params.pkl'


def loadtestdata(path="/images/test", batch_size=None):
    """Build the gallery (test) dataset and an unshuffled loader over it.

    Args:
        path: root of an ``ImageFolder``-style directory tree. Defaults to the
            previously hard-coded ``/images/test``.
        batch_size: loader batch size; ``None`` uses the module-level
            ``BATCH_SIZE``.

    Returns:
        ``(testset, testloader)``. The loader does not shuffle, so feature rows
        produced from it line up with ``testset.imgs``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Resize((224, 224)) already yields exactly 224x224, so the former
    # CenterCrop(224) that followed it was a no-op and has been dropped.
    testset = torchvision.datasets.ImageFolder(
        path,
        transform=transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]),
    )

    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=0)
    return testset, testloader


def loadquerydata(path="/images/query"):
    """Build the query dataset and a loader that yields one image per batch.

    Args:
        path: root of an ``ImageFolder``-style directory tree. Defaults to the
            previously hard-coded ``/images/query``.

    Returns:
        ``(queryset, queryloader)``. The loader is unshuffled with
        ``batch_size=1``, so batch ``i`` corresponds to ``queryset.imgs[i]``.
    """
    # Resize((224, 224)) already yields exactly 224x224, so the former
    # CenterCrop(224) that followed it was a no-op and has been dropped.
    queryset = torchvision.datasets.ImageFolder(
        path,
        transform=transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]),
    )

    queryloader = torch.utils.data.DataLoader(queryset, batch_size=1,
                                              shuffle=False, num_workers=0)
    # Return the loader built above (this previously returned the undefined
    # name ``trainloader`` and raised NameError).
    return queryset, queryloader

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d``; padding 1 keeps H/W unchanged at stride 1."""
    kernel = 3
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d`` with no padding."""
    layer = nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                      kernel_size=1, stride=stride, bias=False)
    return layer


class BasicBlock(nn.Module):
    """Two-convolution residual block: 3x3 -> BN -> ReLU -> 3x3 -> BN, plus shortcut.

    ``expansion`` is 1, so the block emits ``planes`` channels. When
    ``downsample`` is given it is applied to the block input to form the
    shortcut; otherwise the input itself is added back.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Attribute names double as state_dict keys (checkpoints are loaded by
        # name), so they must not be renamed.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample, self.stride = downsample, stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with a 4x channel expansion.

    The 3x3 convolution carries ``stride``; the block outputs
    ``planes * expansion`` channels. ``downsample`` (when not None) maps the
    input onto the shortcut path.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * self.expansion
        # Attribute names double as state_dict keys, so they are kept as-is.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, widened)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        # First two stages share the conv -> BN -> ReLU pattern.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet feature trunk with CBAM-gated compact bilinear pooling on top.

    ``forward`` returns one L2-normalised, signed-square-rooted descriptor of
    length ``output_size`` (8192) per input image; the ``fc`` classifier layers
    are not applied.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: output width of ``fc2`` (not used by ``forward``).
        zero_init_residual: zero the last BN weight of every residual block.
    """

    def __init__(self, block, layers, num_classes=38, zero_init_residual=False):
        super(ResNet, self).__init__()

        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)

        # Freeze everything built so far (stem + layer1-3). Parameters created
        # below keep the default requires_grad=True.
        for p in self.parameters():
            p.requires_grad = False

        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.input_size = 512      # channels fed to each input of mcb
        self.f_output = 4096       # fc1 output width
        self.output_size = 8192    # compact bilinear pooling output dimension
        self.gate_channels = 512
        self.ChannelGate = ChannelGate(self.gate_channels)
        self.SpatialGate = SpatialGate()
        # No hard-coded .cuda(): mcb is a registered submodule, so
        # model.to(device) moves exactly the tensors .cuda() would, and
        # construction no longer fails on CPU-only hosts.
        self.mcb = CompactBilinearPooling(self.input_size, self.input_size, self.output_size)

        # avgpool, fc1 and fc2 are not used by forward().
        # NOTE(review): presumably kept so checkpoint keys still match; confirm
        # before removing.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(self.output_size * block.expansion, self.f_output)
        self.fc2 = nn.Linear(self.f_output, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Optionally zero the final BN weight of each block so every residual
        # branch outputs zero at initialisation.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; the first may change stride/width.

        A 1x1 conv + BN shortcut is added to the first block whenever its
        stride or channel count differs from the incoming tensor. Updates
        ``self.inplanes`` to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def signed_sqrt(self, x):
        """Element-wise ``sign(x) * sqrt(|x|)``, expressed with two ReLUs.

        NOTE(review): the gradient of sqrt is infinite at 0, so this can yield
        NaN gradients if the model is ever trained through it.
        """
        y = torch.sqrt(F.relu(x)) - torch.sqrt(F.relu(-x))
        return y

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))

        # Two attention-refined views of the stage-4 map, permuted to
        # channels-last for the compact bilinear pooling module.
        # NOTE(review): assumes both cbam gates return a tensor shaped like x
        # (self.input_size channels) — confirm against the cbam module.
        x_channel = self.ChannelGate(x).permute(0, 2, 3, 1)
        x_spatial = self.SpatialGate(x).permute(0, 2, 3, 1)

        bilinear = self.mcb(x_channel, x_spatial).permute(0, 3, 1, 2)
        # Global spatial max pool. Identical to the former max_pool2d(.., 7, 7)
        # for 224x224 inputs (7x7 stage-4 map) and still global for other
        # input sizes, so the descriptor length stays output_size.
        poolB = F.adaptive_max_pool2d(bilinear, 1)
        poolB = poolB.view(poolB.shape[0], -1)
        return F.normalize(self.signed_sqrt(poolB), p=2, dim=1)

def resnet34(pretrained=False, **kwargs):
    """Construct the ResNet-34-layout model (``BasicBlock``, ``[3, 4, 6, 3]``).

    Args:
        pretrained: accepted for torchvision-style API compatibility but
            ignored. No ImageNet weights are loaded (main() loads its own
            checkpoint afterwards). Passing True emits a ``UserWarning``
            instead of silently pretending to load them.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``,
            ``zero_init_residual``).

    Returns:
        A freshly initialised ``ResNet``.
    """
    if pretrained:
        import warnings
        warnings.warn("resnet34(pretrained=True) does not load ImageNet weights; "
                      "load a checkpoint explicitly instead.", stacklevel=2)
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def pca(dataMat, k=32):
    """Project data onto its top-``k`` principal components.

    Args:
        dataMat: 2-D array-like, rows are samples and columns are features.
        k: number of leading components to keep (default 32, the value that
            used to be hard-coded). Values above the feature count keep all.

    Returns:
        ``(lowDDataMat, reconMat)``: the mean-centred data projected onto the
        kept components (n_samples x k), and its reconstruction back in the
        original feature space. Both are ``np.matrix`` objects, as before.
    """
    data = np.asarray(dataMat, dtype=np.float64)
    meanVals = data.mean(axis=0)
    meanRemoved = data - meanVals
    # atleast_2d: np.cov returns a 0-d array when there is a single feature.
    covMat = np.atleast_2d(np.cov(meanRemoved, rowvar=False))
    # The covariance is symmetric, so eigh gives real eigenvalues/vectors;
    # eig could return complex values and relied on np.mat (removed in NumPy 2).
    eigVals, eigVects = np.linalg.eigh(covMat)
    topIdx = np.argsort(eigVals)[::-1][:k]   # largest eigenvalues first
    redEigVects = eigVects[:, topIdx]
    lowDDataMat = meanRemoved @ redEigVects
    reconMat = lowDDataMat @ redEigVects.T + meanVals
    return np.asmatrix(lowDDataMat), np.asmatrix(reconMat)
   

def maxminnorm(array):
    """Min-max scale every column of a 2-D array into [0, 1].

    Args:
        array: 2-D array-like (rows are samples, columns are features).

    Returns:
        ``(t, maxcols, mincols)``: the scaled float64 array plus the per-column
        maxima and minima, so other vectors can be scaled the same way.
        Constant columns (max == min) map to 0 instead of 0/0 = NaN. This
        matches how main() scales query features for such columns.
    """
    arr = np.asarray(array)
    maxcols = arr.max(axis=0)
    mincols = arr.min(axis=0)
    span = (maxcols - mincols).astype(np.float64)
    constant = span == 0
    # Substitute 1 as the divisor for constant columns to avoid a 0/0 warning;
    # np.where then overwrites those columns with 0.
    safe_span = np.where(constant, 1.0, span)
    t = np.where(constant, 0.0, (arr - mincols) / safe_span)
    return t, maxcols, mincols


	
def calcMean(x, y):
    """Return ``(mean(x), mean(y))``.

    Every element is averaged (``np.mean`` without an axis), so 2-D inputs
    such as (1, n) row matrices reduce to a single scalar. The previous body
    called an undefined ``mean`` and raised NameError.
    """
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    return x_mean, y_mean

    
def calcPearson(x, y):
    """Pearson correlation coefficient between two feature vectors.

    Args:
        x, y: equal-length vectors, given either 1-D or 2-D; for 2-D input
            (e.g. a (1, n) row matrix) only the first row is used.

    Returns:
        The correlation as a Python float in [-1, 1]. If either vector has
        zero variance the coefficient is undefined; 0.0 is returned instead of
        raising ZeroDivisionError. This keeps rankings built from these scores
        free of NaN.
    """
    xs = np.atleast_2d(np.asarray(x, dtype=np.float64))[0]
    ys = np.atleast_2d(np.asarray(y, dtype=np.float64))[0]

    x_dev = xs - xs.mean()
    y_dev = ys - ys.mean()
    denom = math.sqrt(float(np.dot(x_dev, x_dev)) * float(np.dot(y_dev, y_dev)))
    if denom == 0.0:
        return 0.0
    return float(np.dot(x_dev, y_dev)) / denom


def _label_from_path(path):
    """Return the class label encoded in an image file name.

    File names are expected to look like ``<label>_<NN>.<ext>``; the former
    code stripped a fixed ``"_01.jpg"``-length suffix. Only the base name is
    used, so the label no longer depends on a fixed directory-prefix length
    (the old ``"images/test/"`` slice was wrong for ``/images/...`` paths).
    """
    return os.path.basename(path).rsplit("_", 1)[0]


def _extract_features(model, loader):
    """Run ``model`` over every batch of ``loader``; return stacked float64 rows.

    Rows follow the loader's iteration order.
    """
    chunks = []
    with torch.no_grad():
        for images, _ in loader:
            chunks.append(model(images.to(DEVICE)).cpu().numpy().astype(np.float64))
    if not chunks:
        raise ValueError("test loader produced no images")
    return np.concatenate(chunks, axis=0)


def _fit_pca(data, k):
    """Fit PCA on the rows of ``data`` and keep (at most) ``k`` components.

    Returns ``(projected, mean, components)`` with
    ``projected == (data - mean) @ components``.
    """
    mean = data.mean(axis=0)
    centred = data - mean
    # The right singular vectors of the centred data are the covariance
    # eigenvectors, already ordered by decreasing variance; SVD avoids
    # building the n_features x n_features covariance matrix.
    _, _, vt = np.linalg.svd(centred, full_matrices=False)
    components = vt[:k].T
    return centred @ components, mean, components


def _pearson_scores(query, gallery):
    """Pearson correlation of 1-D ``query`` with every row of 2-D ``gallery``.

    Same formula as ``calcPearson``, vectorised over the gallery. Zero-variance
    rows (or a zero-variance query) score 0.0 instead of dividing by zero.
    """
    q = query - query.mean()
    g = gallery - gallery.mean(axis=1, keepdims=True)
    denom = np.sqrt((g * g).sum(axis=1) * np.dot(q, q))
    scores = np.zeros(g.shape[0])
    ok = denom > 0
    scores[ok] = (g[ok] @ q) / denom[ok]
    return scores


def _write_class_map(f, label, map_sum, n_queries):
    """Write one per-class mAP@10 line (label, value, blank line); return the rounded value."""
    t = round(map_sum * 1.0 / n_queries, 4)
    f.write(str(label) + " " + str(t) + " \n\n")
    return t


def main():
    """Evaluate image retrieval: rank the test gallery for every query image.

    Steps:
    1. Extract descriptors for the gallery and min-max normalise them.
    2. Reduce them to 38 PCA dimensions.
    3. For each query, rank the gallery by Pearson correlation and accumulate
       precision@{5,10,50,100} and AP@10.

    Per-query lines, per-class mAP@10 and the overall mean of the per-class
    mAPs are written to ``result.txt``, and the summary is printed.

    The top-ranked gallery item is skipped for every query, on the
    assumption that the query image itself is part of the gallery.
    Queries are assumed to be grouped by class in loader order.
    """
    test_dataset, test_loader = loadtestdata()
    query_dataset, query_loader = loadquerydata()

    model = resnet34()
    # Load only checkpoint entries whose names exist in this model; the rest
    # keep their initial values. map_location lets a GPU-saved checkpoint load
    # on a CPU-only host; model.to(DEVICE) below moves everything afterwards.
    state_dict = torch.load(pretrained_model, map_location="cpu")
    model_dict = model.state_dict()
    model_dict.update({k: v for k, v in state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    print("", model)
    model.to(DEVICE)
    model.eval()

    # Gallery size comes from the dataset itself, not the hard-coded ``count``.
    gallery = _extract_features(model, test_loader)
    gallery, maxcols, mincols = maxminnorm(gallery)
    low_gallery, pca_mean, components = _fit_pca(np.asarray(gallery, dtype=np.float64), 38)

    # Queries are scaled with the gallery's column ranges; constant columns -> 0.
    span = np.asarray(maxcols - mincols, dtype=np.float64)
    constant_cols = span == 0
    safe_span = np.where(constant_cols, 1.0, span)

    num_queries = len(query_dataset)
    print(num_queries)

    precision_sums = {5: 0.0, 10: 0.0, 50: 0.0, 100: 0.0}
    current_class = None
    c_class = 0
    c_map = 0.0
    s_class = 0
    s_map = 0.0

    with open("result.txt", "w") as f:
        for i, (image_query, _) in enumerate(query_loader):
            with torch.no_grad():
                # Use the query batch itself (a gallery batch was used before).
                output_query = model(image_query.to(DEVICE)).cpu().numpy()
            q_norm = np.where(constant_cols, 0.0,
                              (output_query[0].astype(np.float64) - mincols) / safe_span)
            low_query = (q_norm - pca_mean) @ components

            query_path = str(query_dataset.imgs[i][0])
            label_1 = _label_from_path(query_path)

            # Flush the running per-class mAP whenever the class label changes.
            if current_class is not None and label_1 != current_class:
                t = _write_class_map(f, current_class, c_map, c_class)
                s_class += 1
                s_map += t
                print(s_class)
                c_class = 0
                c_map = 0.0
            current_class = label_1
            c_class += 1

            file_name = os.path.basename(query_path)
            print(file_name)
            f.write(file_name + " ")

            scores = _pearson_scores(low_query, low_gallery)
            # Descending by score; drop rank 0 (assumed to be the query itself).
            ranked = np.argsort(scores)[::-1][1:101]
            hits = [_label_from_path(str(test_dataset.imgs[j][0])) == label_1
                    for j in ranked]

            for k in precision_sums:
                precision_sums[k] += sum(hits[:k]) / float(k)

            # AP@10: average of precision@pos over the relevant positions in the top 10.
            top10 = hits[:10]
            k_10 = 0
            ap = 0.0
            for pos, hit in enumerate(top10, start=1):
                if hit:
                    k_10 += 1
                    ap += k_10 / float(pos)

            precision_10 = k_10 / float(len(top10)) if top10 else 0.0
            f.write(str(precision_10) + " ")
            if k_10 == 0:
                f.write(str(0) + " \n")
            else:
                query_map = ap / k_10
                c_map += query_map
                f.write(str(round(query_map, 4)) + " \n")

        # Flush the final class, which has no following label change.
        if current_class is not None:
            t = _write_class_map(f, current_class, c_map, c_class)
            s_class += 1
            s_map += t
            print(s_class)

        mk = {k: (round(v / num_queries, 4) if num_queries else 0.0)
              for k, v in precision_sums.items()}
        total_map = round(s_map / s_class, 4) if s_class else 0.0
        f.write("total_map" + " " + str(total_map) + " \n")

    print("total_map   ", total_map)
    print("mk_5   ", mk[5])
    print("mk_10   ", mk[10])
    print("mk_50   ", mk[50])
    print("mk_100   ", mk[100])


if __name__ == '__main__':
    # Run the retrieval evaluation only when this file is executed as a script.
    main()
