# coding=utf-8
import tensorflow as tf
import json
import numpy as np
import cv2
import os
import random

# Read the run configuration once, at import time; the values below are
# consumed by the decoding code further down in this module.
with open('./config_param.json') as config_file:
    config = json.load(config_file)

# Number of label classes; used as the one-hot depth when decoding labels.
num_cls = int(config['num_cls'])
# Height and width of each image.
IMG_HEIGHT, IMG_WIDTH = int(config['img_height']), int(config['img_width'])

def _decode_samples(image_list, shuffle=False):
    """Build the graph ops that read and decode one sample from TFRecord files.

    Every serialized example must contain the six scalar int64 size fields
    plus two raw byte strings: ``data_vol`` (the image) and ``label_vol``
    (the label map). Both byte strings are decoded as float32 and reshaped
    using the module-level ``IMG_HEIGHT`` / ``IMG_WIDTH`` configuration; the
    size fields stored in the record are parsed but not used for reshaping.

    Args:
        image_list: list of TFRecord file paths fed to the filename queue.
        shuffle: whether the filename queue shuffles the order of the files.

    Returns:
        A ``(data_vol, batch_y)`` tuple of tensors for a single sample:
        ``data_vol`` is float32 of shape ``[IMG_HEIGHT, IMG_WIDTH, 3]`` and
        ``batch_y`` is the one-hot label map of shape
        ``[IMG_HEIGHT, IMG_WIDTH, num_cls]``.
    """
    decomp_feature = {
        # image size, dimensions of 3 consecutive slices
        # (dimension shape records)
        'dsize_dim0': tf.FixedLenFeature([], tf.int64),
        'dsize_dim1': tf.FixedLenFeature([], tf.int64),
        'dsize_dim2': tf.FixedLenFeature([], tf.int64),
        # label size, dimension of the middle slice
        'lsize_dim0': tf.FixedLenFeature([], tf.int64),
        'lsize_dim1': tf.FixedLenFeature([], tf.int64),
        'lsize_dim2': tf.FixedLenFeature([], tf.int64),
        # raw bytes of the image slices
        'data_vol': tf.FixedLenFeature([], tf.string),
        # raw bytes of the label slice
        'label_vol': tf.FixedLenFeature([], tf.string)}

    # Shape stored on disk vs. shape handed to the model. They are currently
    # identical, so the tf.slice calls below keep the full tensor.
    img_raw_size = [IMG_HEIGHT, IMG_WIDTH, 3]
    volume_size = [IMG_HEIGHT, IMG_WIDTH, 3]
    label_raw_size = [IMG_HEIGHT, IMG_WIDTH, 1]
    label_size = [IMG_HEIGHT, IMG_WIDTH, 1]

    # Filename queue -> record reader -> parsed feature dict.
    data_queue = tf.train.string_input_producer(image_list, shuffle=shuffle)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(data_queue)
    parser = tf.parse_single_example(serialized_example, features=decomp_feature)

    data_vol = tf.decode_raw(parser['data_vol'], tf.float32)
    data_vol = tf.reshape(data_vol, img_raw_size)
    data_vol = tf.slice(data_vol, [0, 0, 0], volume_size)

    label_vol = tf.decode_raw(parser['label_vol'], tf.float32)
    label_vol = tf.reshape(label_vol, label_raw_size)
    label_vol = tf.slice(label_vol, [0, 0, 0], label_size)

    # Drop only the trailing channel axis. An axis-less squeeze would also
    # collapse a spatial dimension of size 1 and change the label's rank.
    label_ids = tf.cast(tf.squeeze(label_vol, axis=[2]), tf.uint8)
    batch_y = tf.one_hot(label_ids, num_cls)
    return data_vol, batch_y

def _read_path_list(list_pth):
    """Return the non-empty lines of the text file *list_pth*, one path per line."""
    with open(list_pth, 'r') as fp:
        # Strip only the line terminator so a final line without a trailing
        # newline keeps its last character; blank lines are skipped.
        stripped = (line.rstrip('\r\n') for line in fp)
        return [path for path in stripped if path]


def _load_samples(source_pth, target_pth,shuffle=False):
    """Create single-sample decoding ops for the source and target datasets.

    Args:
        source_pth: path to a text file listing the source-domain TFRecord
            files, one path per line.
        target_pth: path to a text file listing the target-domain TFRecord
            files, one path per line.
        shuffle: forwarded to ``_decode_samples`` for both datasets.

    Returns:
        ``(data_vola, data_volb, label_vola, label_volb)``: the image tensors
        of source and target, followed by their label tensors, as produced by
        ``_decode_samples``.
    """
    imagea_list = _read_path_list(source_pth)
    print(imagea_list[:10])

    imageb_list = _read_path_list(target_pth)
    print(imageb_list[:10])

    data_vola, label_vola = _decode_samples(imagea_list, shuffle=shuffle)
    data_volb, label_volb = _decode_samples(imageb_list, shuffle=shuffle)

    return data_vola, data_volb, label_vola, label_volb


def load_data(source_pth, target_pth, batch_size, do_shuffle=True):
    """Build batched input tensors for paired source/target samples.

    Args:
        source_pth: path to the text file listing source TFRecord files.
        target_pth: path to the text file listing target TFRecord files.
        batch_size: number of samples per batch.
        do_shuffle: when truthy, file names are shuffled and samples are
            drawn through ``tf.train.shuffle_batch``; otherwise samples are
            batched with ``tf.train.batch``.

    Returns:
        ``(images_i, images_j, gt_i, gt_j)``: batched source images, target
        images, source labels and target labels. Each tensor has a leading
        batch dimension of ``batch_size``. The tensors are fed by TensorFlow
        input queues, so queue runners must be started before evaluating them.
    """
    image_i, image_j, gt_i, gt_j = _load_samples(source_pth, target_pth, shuffle=do_shuffle)

    # Batch
    print("****")
    sample = [image_i, image_j, gt_i, gt_j]
    # Plain truthiness keeps this branch consistent with the file-name
    # shuffling decision, which receives the very same flag.
    if do_shuffle:
        images_i, images_j, gt_i, gt_j = tf.train.shuffle_batch(
            sample, batch_size=batch_size, capacity=500, min_after_dequeue=100)
    else:
        # NOTE(review): with num_threads > 1 the enqueue order is not
        # deterministic, so batches may not follow file order -- confirm
        # this is acceptable for the non-shuffled use case.
        images_i, images_j, gt_i, gt_j = tf.train.batch(
            sample, batch_size=batch_size, num_threads=32, capacity=500)
    print(np.shape(images_i), np.shape(images_j), np.shape(gt_i), np.shape(gt_j))
    return images_i, images_j, gt_i, gt_j
