import cv2
import numpy as np
import matplotlib.pyplot as plt
# Print NumPy floats in fixed-point rather than scientific notation; this only
# affects how arrays are displayed, not the computed values.
np.set_printoptions(suppress=True)



def plot_with_cv(image):
    """Show *image* in an (untitled) OpenCV window until a key is pressed.

    All OpenCV windows are closed once the key press is received.
    """
    window_title = ""
    cv2.imshow(window_title, image)
    # a delay of 0 blocks until a key is pressed
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def plot_with_matplot(image):
    """Display *image* with matplotlib, with axis ticks hidden.

    A two-dimensional array is drawn with the 'gray' colormap; any other
    array is drawn with matplotlib's default colormap handling.
    """
    colormap = 'gray' if image.ndim == 2 else None
    plt.imshow(image, cmap=colormap)

    # remove tick marks/labels on both axes
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([])
    plt.show()


def gray_histogram(image, n_bins, from_to=None):
    """Count the values of *image* in *n_bins* equal-width bins.

    Args:
        image: numeric array of pixel values.
        n_bins: number of bins.
        from_to: optional ``(min, max)`` range to bin over; when omitted the
            image's own minimum and maximum are used.

    Returns:
        Integer array of length *n_bins*. Every bin is half-open
        ``[lower, upper)`` except the last, which is closed ``[lower, upper]``
        so the range maximum is counted. Values outside the range are ignored.
    """
    if from_to is not None:
        min_x, max_x = from_to[0], from_to[1]
    else:
        min_x, max_x = np.min(image), np.max(image)

    bin_width = (max_x - min_x) / n_bins
    counts = np.zeros(n_bins)

    for idx in range(n_bins):
        lo = idx * bin_width + min_x
        hi = (idx + 1) * bin_width + min_x
        at_or_above_lo = image >= lo
        # only the final bin is closed on the right
        if idx == n_bins - 1:
            below_hi = image <= hi
        else:
            below_hi = image < hi
        counts[idx] = np.sum(at_or_above_lo & below_hi)

    return np.array(counts, dtype=int)


def cumulative_histogram(histogram):
    """Return the running total of the bin counts in *histogram*.

    The input is flattened before summing, as ``np.cumsum`` does when no axis
    is given.
    """
    flat_counts = np.ravel(histogram)
    return np.cumsum(flat_counts)


def thresholding_gray(image, threshold, value1, value2):
    """Binarize a gray image around *threshold*.

    Args:
        image: array of pixel values (not modified).
        threshold: pixels strictly below it become *value1*; pixels greater
            than or equal to it become *value2*.
        value1: value for below-threshold pixels.
        value2: value for at-or-above-threshold pixels.

    Returns:
        A new array with the same dtype as *image*.
    """
    new_image = np.copy(image)

    # Build both masks from the untouched pixels. Computing the second mask
    # after the first assignment would also catch pixels just set to value1
    # whenever value1 >= threshold (e.g. an inverted 255/0 threshold).
    below = new_image < threshold
    at_or_above = new_image >= threshold

    new_image[below] = value1
    new_image[at_or_above] = value2
    return new_image


def thresholding_color(image, r_thresh, g_thresh, b_thresh, r_value, g_value,
                       b_value):
    """Replace below-threshold values independently in the first three channels.

    Channel 0 uses (r_thresh, r_value), channel 1 (g_thresh, g_value) and
    channel 2 (b_thresh, b_value). Values at or above a channel's threshold
    are left unchanged. The input image is not modified.

    NOTE(review): images loaded with cv2.imread are in BGR order, so for them
    channel 0 is blue, not red -- confirm the caller's channel order.
    """
    result = np.copy(image)
    per_channel = ((r_thresh, r_value), (g_thresh, g_value), (b_thresh, b_value))

    for channel_idx, (thresh, fill) in enumerate(per_channel):
        # basic slicing yields a view, so writing through it updates `result`
        channel = result[:, :, channel_idx]
        channel[channel < thresh] = fill

    return result


def plot_histogram(histogram):
    """Show *histogram* as a bar chart with its cumulative sum overlaid in red.

    The cumulative curve is rescaled so that its final value equals the
    tallest bar, letting both share one y-axis.

    Args:
        histogram: 1-D array of non-negative bin counts (e.g. from
            gray_histogram).
    """
    histogram = np.asarray(histogram)
    hist_max = np.max(histogram)
    cum_hist = np.cumsum(histogram)
    cum_max = np.max(cum_hist)

    # Normalise by the total count (cum_max). Dividing by (cum_max - cum_min)
    # without also subtracting cum_min would push the curve above hist_max
    # whenever the first bin is non-empty, and divide by zero when every count
    # is in the first bin.
    if cum_max > 0:
        scaled_cum_hist = cum_hist * (hist_max / cum_max)
    else:
        scaled_cum_hist = np.zeros(cum_hist.shape, dtype=float)

    # plot histogram bars and the scaled cumulative curve
    positions = np.arange(histogram.shape[0])
    _, ax = plt.subplots()
    ax.bar(positions, histogram)
    ax.plot(positions, scaled_cum_hist, color='r', linewidth=2.0)

    plt.show()


def blend_over(img1, img2, color, alpha=0.6):
    """Alpha-blend two images, display the result with matplotlib and return it.

    Args:
        img1: first image, weighted by *alpha*.
        img2: second image, weighted by ``1 - alpha``; must have img1's shape
            (or be broadcastable to it).
        color: if True the images are treated as 3-channel and the channel
            axis is reversed before display (OpenCV loads BGR, matplotlib
            expects RGB); otherwise the blend is shown with a gray colormap.
        alpha: weight of *img1*. The default 0.6 gives the original
            0.6/0.4 mix.

    Returns:
        The blended image as uint8, rounded to the nearest integer.
    """
    blended_img = alpha * np.asarray(img1, dtype=float) \
        + (1.0 - alpha) * np.asarray(img2, dtype=float)
    # clip before casting so out-of-range values saturate instead of wrapping
    blended_img = np.clip(np.round(blended_img), 0, 255).astype(np.uint8)

    if color:
        plt.imshow(blended_img[:, :, ::-1], cmap=None)
    else:
        plt.imshow(blended_img, cmap='gray')

    plt.show()
    return blended_img


def pixel_value_normalisation_gray(image, output_max):
    """Linearly stretch *image* so its minimum maps to 0 and its maximum to *output_max*.

    Args:
        image: non-empty numeric array of gray values.
        output_max: value the brightest pixel is mapped to.

    Returns:
        uint8 array of the same shape. Fractional results are truncated
        toward zero, and results are clipped to [0, 255] because of the uint8
        output type. A constant image (max == min) has no range to stretch
        and yields all zeros.
    """
    pixels = np.asarray(image, dtype=np.float64)
    input_min = pixels.min()
    span = pixels.max() - input_min

    if span == 0:
        # avoid 0/0 -> NaN, whose cast to uint8 is undefined
        return np.zeros(pixels.shape, dtype=np.uint8)

    # multiply before dividing so the brightest pixel maps exactly to
    # output_max instead of truncating to output_max - 1
    stretched = (pixels - input_min) * output_max / span
    return np.clip(stretched, 0, 255).astype(np.uint8)


def histogram_equalization_gray(image):
    """Equalize the histogram of an 8-bit gray image.

    Implements h(v) = round((cdf(v) - cdf_min) / (N - cdf_min) * 255), where
    N is the number of pixels and cdf_min is the cumulative count at the
    darkest intensity present
    (https://en.wikipedia.org/wiki/Histogram_equalization).

    Args:
        image: array of non-negative integer intensities (expected uint8).

    Returns:
        uint8 array of the same shape. The darkest present intensity maps to
        0 and the brightest to 255. A constant or empty image has no range to
        spread and yields all zeros.
    """
    image = np.asarray(image)
    if image.size == 0:
        return np.zeros(image.shape, dtype=np.uint8)

    # Per-intensity counts in one pass. For uint8 input this matches
    # gray_histogram(image, n_bins=256, from_to=(0, 256)).
    hist = np.bincount(image.ravel(), minlength=256)
    cdf = np.cumsum(hist)
    n_pixel = image.size

    # the first non-zero cdf value is the cdf at the image's minimum intensity
    cdf_min = cdf[image.min()]
    denominator = n_pixel - cdf_min
    if denominator == 0:
        # constant image: the formula is 0/0
        return np.zeros(image.shape, dtype=np.uint8)

    # multiply before dividing so integer-valued results stay exact
    eq = (cdf[image] - cdf_min) * 255 / denominator
    return np.round(eq).astype(np.uint8)


def invert_gray(image):
    """Return the negative of a gray image (255 - pixel) as a uint8 array."""
    negative = 255 - image
    return np.array(negative, dtype=np.uint8)


#hist = get_gray_histogram_fast(img, 50)
#print("Shape of image: ", img.shape)
#print("Expected: ", img.shape[0] * img.shape[1])
#print(np.sum(hist))

#img = thresholding(img, 127, 0)
#plotNormal(path)


#ind = np.arange(50)  # the x locations for the groups
#width = 0.35       # the width of the bars
#fig, ax = plt.subplots()
#rects1 = ax.bar(ind, hist, width)

#plt.show()

#plt.hist(img.ravel(), bins=50)
#plt.show()

def main(color, path=None, path_gt=None):
    """Load a PH2 dermoscopic image and its lesion mask, then run the demo steps.

    Args:
        color: if True both images are loaded as colour (OpenCV BGR),
            otherwise as single-channel gray.
        path: image file; defaults to IMD002 in the author's local PH2 copy.
        path_gt: lesion (ground-truth) mask; defaults to the matching IMD002
            lesion file.

    Raises:
        FileNotFoundError: if either file cannot be read by cv2.imread.
    """
    if path is None:
        path = "/home/lukas/MedIP/MedIPProjects/PH2Dataset/PH2 Dataset images/" \
               "IMD002/IMD002_Dermoscopic_Image/IMD002.bmp"
    if path_gt is None:
        path_gt = "/home/lukas/MedIP/MedIPProjects/PH2Dataset/PH2 Dataset images/" \
                  "IMD002/IMD002_lesion/IMD002_lesion.bmp"

    # 2nd parameter of imread: 1=color, 0=gray, -1=unchanged
    read_flag = 1 if color else 0
    img = cv2.imread(path, read_flag)
    img2 = cv2.imread(path_gt, read_flag)

    # cv2.imread returns None instead of raising on failure; stop here with
    # the offending path rather than failing obscurely in a later call
    for file_path, loaded in ((path, img), (path_gt, img2)):
        if loaded is None:
            raise FileNotFoundError(f"cv2.imread could not read {file_path!r}")

    #img = img[:, :, ::-1]
    #plot_with_matplot(img)
    plot_with_cv(img)

    # ------------------------------------
    # Gray stuff
    # ------------------------------------
    # plot_with_matplot(img)

    #img_thresh = thresholding_gray(img, 127, 0, 255)
    #plot_with_matplot(img_thresh)

    #img_inverted = invert_gray(img)
    #plot_with_matplot(img_inverted)

    #img_norm = pixel_value_normalisation_gray(img, 2)
    #plot_with_matplot(img_norm)

    img_equal = histogram_equalization_gray(img)
    equal_hist = gray_histogram(img_equal, n_bins=256, from_to=(0, 256))
    plot_with_matplot(img_equal)
    plot_histogram(equal_hist)

    #hist = gray_histogram(img, n_bins=256, from_to=(0, 256))
    #nphist = np.histogram(img, bins=256, range=(0, 256))[0]
    #print(np.array_equal(hist, nphist))

    #img_color_thres = thresholding_color(img, 127, 127, 127, 0, 0, 0)
    #plot_with_matplot(img_color_thres)

    # forward the caller's colour mode so colour images are shown as RGB
    blend_over(img, img2, color=color)




if __name__ == "__main__":
    # Run the demo on the grayscale (imread flag 0) versions of the images.
    main(color=False)
