import numpy as np
from numpy import linalg as LA
import matplotlib.pylab as plt
from PIL import Image

def hut(x):
    """Return the 3x3 skew-symmetric ("hat") matrix of the vector *x*.

    Only ``x[0]``, ``x[1]`` and ``x[2]`` are read. The matrix is built so that
    ``hut(x) @ y`` equals the cross product ``x × y`` for a 3-vector ``y``.
    """
    x1, x2, x3 = x[0], x[1], x[2]
    rows = [
        [0, -x3, x2],
        [x3, 0, -x1],
        [-x2, x1, 0],
    ]
    return np.array(rows)
    
def konstruiere_bi(a, x):
    """Return the 3x9 DLT block for one point correspondence.

    The block is ``[a[0]*hut(x) | a[1]*hut(x) | a[2]*hut(x)]`` (side by side),
    so multiplying it by a 9-vector h gives
    ``hut(x) @ (a[0]*h[0:3] + a[1]*h[3:6] + a[2]*h[6:9])``.
    """
    skew = hut(x)
    return np.concatenate([a[k] * skew for k in range(3)], axis=1)

def konstruiere_b(A, X):
    """Stack the DLT blocks of all point correspondences into one matrix.

    Parameters
    ----------
    A, X : array_like, shape (2, n)
        Corresponding points, one per column: column k of ``A`` is paired
        with column k of ``X``. Each point is made homogeneous by appending
        a 1. All n columns are used; more than 4 over-determine the system.

    Returns
    -------
    ndarray, shape (3 * n, 9)
        Rows ``3k .. 3k+2`` are ``konstruiere_bi`` of correspondence k.

    Raises
    ------
    ValueError
        If ``A`` and ``X`` have different column counts or fewer than 4
        columns (a homography has 8 degrees of freedom, 2 per point).
    """
    A = np.asarray(A)
    X = np.asarray(X)
    n = A.shape[1]
    if X.shape[1] != n:
        raise ValueError(
            f"A and X must have the same number of columns; got {n} and {X.shape[1]}"
        )
    if n < 4:
        raise ValueError(f"need at least 4 point correspondences, got {n}")
    blocks = [
        konstruiere_bi(np.append(A[:, k], [1]), np.append(X[:, k], [1]))
        for k in range(n)
    ]
    return np.vstack(blocks)
    

def kleinster_ev(B):
    """Return a unit eigenvector of the symmetric matrix *B* for its smallest eigenvalue.

    *B* is expected to be symmetric (e.g. ``M.T @ M``); only its lower
    triangle is read. The eigenvector's sign is arbitrary.
    """
    # eigh is NumPy's solver for symmetric matrices. It always returns real
    # eigenvalues sorted in ascending order and real eigenvectors. The general
    # ``eig`` may instead return complex results, which would leak complex
    # numbers into later computations such as ``np.floor``.
    _, vecs = LA.eigh(B)
    return vecs[:, 0]
    
def interpolateBilin(img, coord):
    """Sample *img* at a fractional position using bilinear interpolation.

    Parameters
    ----------
    img : ndarray
        Image indexed as ``img[row, col]``.
    coord : sequence of float
        ``coord[0]`` is the row and ``coord[1]`` the column position. Any
        further entries (e.g. a trailing homogeneous 1) are ignored.

    Returns
    -------
    The interpolated value. Returns 0.0 when the 2x2 neighbourhood around
    *coord* is not entirely inside the image, or when *coord* is not finite.
    """
    r, c = coord[0], coord[1]
    # A homogeneous division by ~0 upstream produces inf/nan; int() would raise.
    if not (np.isfinite(r) and np.isfinite(c)):
        return 0.0
    # Builtin int replaces np.int, which was removed in NumPy 1.24.
    i = int(np.floor(r))
    j = int(np.floor(c))
    # img[i+1, j+1] is read below, so the top-left corner must leave room for
    # one more row and column; otherwise treat the sample as outside (0).
    if i < 0 or j < 0 or i > img.shape[0] - 2 or j > img.shape[1] - 2:
        return 0.0

    w_row_lo = i + 1 - r  # weight of row i
    w_row_hi = r - i      # weight of row i + 1
    w_col_lo = j + 1 - c  # weight of column j
    w_col_hi = c - j      # weight of column j + 1

    # Interpolate along rows in both columns, then across the two columns.
    left = w_row_hi * img[i + 1, j] + w_row_lo * img[i, j]
    right = w_row_hi * img[i + 1, j + 1] + w_row_lo * img[i, j + 1]
    return w_col_hi * right + w_col_lo * left
    
def main(img_dir="giraffe.jpg"):
    """Estimate a homography from four point pairs and warp an image with it.

    Shows the greyscale input image and the warped result with matplotlib.

    Parameters
    ----------
    img_dir : str
        Path of the image to load.
    """
    # Corresponding points, one per column. They are interpreted in the same
    # (row, col) order as the pixel indices the homography is applied to below.
    A = np.array([[100, 100, 300, 300], [100, 300, 100, 300]], dtype=float)
    X = np.array([[110, 100, 290, 300], [100, 300, 100, 300]], dtype=float)

    # Solve B h = 0 in the least-squares sense. h is the eigenvector of B^T B
    # with the smallest eigenvalue. reshape(3, 3).T gives the matrix M with
    # hut(x) @ M @ a = 0 for every pair, i.e. M maps A points onto X points.
    # Its inverse H is used below for inverse warping: each output pixel z is
    # filled from input position H z.
    B = konstruiere_b(A, X)
    BtB = B.T @ B
    H = LA.inv(kleinster_ev(BtB).reshape(3, 3).T)

    # Load the image as greyscale in [0, 1]; the context manager closes the file.
    with Image.open(img_dir) as img_pil:
        img_grey_pil = img_pil.convert('L')
    img = np.array(img_grey_pil) / 255.0

    plt.figure(figsize=(5, 5))
    # Without an explicit cmap a 2-D array is rendered with a false-colour map.
    plt.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
    plt.show()

    # Inverse warping: look up every output pixel in the source image.
    # Pixels that map outside the image (or to infinity) stay 0.
    shiftedImg = np.zeros_like(img)
    for idx in np.ndindex(img.shape):
        x_temp = H @ np.array([idx[0], idx[1], 1.0])
        if x_temp[2] == 0:
            continue
        x_temp = x_temp / x_temp[2]
        shiftedImg[idx] = interpolateBilin(img, x_temp)

    plt.figure(figsize=(5, 5))
    plt.imshow(shiftedImg, cmap='gray', vmin=0.0, vmax=1.0)
    plt.show()
    
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()