### K-Means Visualization ###

import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage.morphology as morph
from PIL import Image
from numpy import linalg as LA
from scipy.ndimage import gaussian_filter
import pdb

    
def _nearest_centroid(points, centroids):
    """Return, for every row of `points`, the index of its closest centroid.

    `points` has shape (nrPoints, nrDims) and `centroids` has shape
    (nrClasses, nrDims). Closeness is the squared Euclidean distance; ties go
    to the lowest class index (np.argmin behavior).
    """
    # Broadcasting yields (nrPoints, nrClasses, nrDims) without explicitly
    # repeating each centroid once per point.
    offsets = points[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    sqDist = np.sum(offsets ** 2, axis=2)
    return np.argmin(sqDist, axis=1)


def _plot_assignment(points, centroids, labels):
    """Scatter-plot the first two coordinates of points and centroids.

    Points are colored by their class label; every centroid gets a color
    value of its own so it stands out from the points.
    """
    nrClasses = centroids.shape[0]
    plotPoints = np.concatenate((points, centroids))
    # Exactly one extra color value per centroid, whatever nrClasses is
    # (a fixed list of three values only fits nrClasses == 3).
    colors = np.concatenate((labels, nrClasses + np.arange(nrClasses)))
    plt.scatter(plotPoints[:, 0], plotPoints[:, 1], c=colors)
    plt.show()


def kMeans(points, nrClasses, maxIter, plot=True):
    """Run k-means (Lloyd's algorithm), optionally plotting every iteration.

    Parameters
    ----------
    points : array-like, shape (nrPoints, nrDims)
        Data to cluster. Any number of dimensions is accepted; plotting uses
        the first two columns, so it needs nrDims >= 2.
    nrClasses : int
        Number of clusters, at least 1.
    maxIter : int
        Number of assignment/update iterations to run.
    plot : bool, optional
        If True (default), show a scatter plot of the current assignment and
        centroids in every iteration, before the centroids are updated.

    Returns
    -------
    centroids : ndarray, shape (nrClasses, nrDims)
        Final centroids, one row per class.
    labels : ndarray, shape (nrPoints,)
        Index of the nearest final centroid for every point.

    Raises
    ------
    ValueError
        If `points` is not 2-D or `nrClasses` is smaller than 1.

    Notes
    -----
    Centroids are initialised uniformly at random in [0, 1) per coordinate
    using the global NumPy random state; call np.random.seed beforehand for
    reproducible runs. A class that receives no points keeps its previous
    centroid instead of being moved to the origin.
    """
    points = np.asarray(points, dtype=float)
    if points.ndim != 2:
        raise ValueError("points must be a 2-D array of shape (nrPoints, nrDims)")
    if nrClasses < 1:
        raise ValueError("nrClasses must be at least 1")

    # Initial random centroids, stored one row per class
    centroids = np.random.rand(points.shape[1], nrClasses).T.copy()

    for _ in range(maxIter):

        ## Step 1: For every point, find the index of the nearest centroid
        labels = _nearest_centroid(points, centroids)

        ## Plot result
        if plot:
            _plot_assignment(points, centroids, labels)

        ## Step 2: Update centroids
        for actClass in range(nrClasses):
            members = points[labels == actClass]
            # An empty class has no mean; leave its centroid where it is
            if len(members) > 0:
                centroids[actClass] = members.mean(axis=0)

    # Labels are recomputed so they match the returned (updated) centroids
    return centroids, _nearest_centroid(points, centroids)
            


            

    
def main():
    """Draw three random 2-D point clouds, show them, and run the k-means demo."""

    # Three uniform blobs with 50 samples each: one at the origin corner,
    # one shifted by 0.5 on both axes, one shifted by 0.5 on the x axis only
    samplesPerBlob = 50
    blobOrigin = 0.3 * np.random.rand(samplesPerBlob, 2)
    blobDiagonal = 0.5 + 0.2 * np.random.rand(samplesPerBlob, 2)
    blobRight = [0.5, 0] + 0.2 * np.random.rand(samplesPerBlob, 2)
    points = np.vstack((blobOrigin, blobDiagonal, blobRight))

    # Show the raw, unlabeled data first
    xCoords, yCoords = points.T
    plt.scatter(xCoords, yCoords)
    plt.show()

    # Cluster into three classes, ten iterations
    kMeans(points, nrClasses=3, maxIter=10)


    
# Run the demo only when this file is executed as a script, not when imported
if __name__ == "__main__": 
    main()