%
% Deblurring with an L2 data-fidelity term and Huber regularization of the image gradient
%
clearvars

%% Load data
% Read the test image and keep the patch at rows/cols 101..300. im2double
% converts the (integer) pixel values to double precision.
f_image = imread('ein_ei.jpg');
f_image = im2double(f_image(101:300,101:300,:));

% image height (nx), width (ny) and number of color channels (nc)
nx = size(f_image,1);
ny = size(f_image,2);
nc = size(f_image,3);

figure(1);
imshow(f_image);
title('ground truth');
%% Construct blur matrix
% Builds a sparse matrix blurMat such that blurMat*u(:) blurs each of the nc
% channels of an nx-by-ny-by-nc image u with `kernel` and keeps, per channel,
% the central nx-by-ny ('same'-sized) part of the convolution result.

%kernel = fspecial('gaussian',11,2);
kernel = fspecial('motion',11);
[kx,ky] = size(kernel);

% convmtx2 returns the matrix of the *full* 2-D convolution; its output has
% size size(kernel)+[nx ny]-1 and is vectorized in column-major order.
blurMat = convmtx2(kernel,nx,ny);

% Keep only the rows of the central nx-by-ny window of the full output.
% The window starts at floor(size(kernel)/2)+1 in each dimension, which is
% valid for both odd and even kernel sizes. Logical indexing keeps the
% selected rows in increasing (column-major) order, matching u(:).
sameWindow = false(kx+nx-1, ky+ny-1);
sameWindow(floor(kx/2) + (1:nx), floor(ky/2) + (1:ny)) = true;
blurMat = blurMat(sameWindow(:),:);

% Blur every channel independently (block-diagonal operator). Use nc rather
% than a fixed channel count so grayscale images work as well.
blurMat = kron(speye(nc),blurMat);


%% Add noise
% Observed data f: blurred ground truth plus zero-mean Gaussian noise with
% standard deviation 0.01, one sample per pixel and channel.
f = blurMat*f_image(:);
f = f + 0.01*randn(nx*ny*nc,1);

figure(2);
imshow(reshape(f,nx,ny,nc));
title('noisy image');
drawnow

%% Set Huber parameters

% eps_val: threshold separating the quadratic (|t| <= eps_val) and the
%          linear branch of the Huber function
% alpha:   weight of the Huber regularizer relative to the L2 data term
eps_val = 1e-2;
alpha = 1/4;

%% Build energy
% data object: E_l2(u) = 0.5*||blurMat*u - f||^2 with gradient
% blurMat'*(blurMat*u - f). Both handles share the residual helper.
residual = @(u) blurMat*u - f;
l2 = @(u) 0.5*sum(residual(u).^2);
l2Grad = @(u) blurMat'*residual(u);

numUnknowns = nx*ny*nc;   % length of the vectorized image u
E_l2 = energyClass(l2,l2Grad,numUnknowns);

% regularization functions: summed elementwise Huber penalty and its gradient
%   H(t) = 0.5*t^2                      if |t| <= eps_val
%   H(t) = eps_val*(|t| - 0.5*eps_val)  otherwise
% The gradient is t in the quadratic branch and eps_val*sign(t) in the linear one.
inQuadRegion = @(u) abs(u) <= eps_val;
Huber = @(u) sum(0.5*u.^2.*inQuadRegion(u) + eps_val*(abs(u)-0.5*eps_val).*(~inQuadRegion(u)));
HuberGrad = @(u) u.*inQuadRegion(u) + eps_val*sign(u).*(~inQuadRegion(u));

% Gradient matrix:
% Forward differences with a zero last row (no difference across the image
% border), assembled for an nx-by-ny image vectorized column-major as u(:)
% (column length nx). Correct for square and non-square images.
fwdDiff = @(n) spdiags([[-ones(n - 1, 1); 0], ones(n, 1)], [0, 1], n, n);
% dy: differences along dimension 1, i.e. within each column of length nx
dy = kron(speye(ny), fwdDiff(nx));
% dx: differences along dimension 2, i.e. between neighbouring columns (stride nx)
dx = kron(fwdDiff(ny), speye(nx));
% Per-channel block-diagonal operators, stacked as [all dx blocks; all dy blocks]
D = cat(1, kron(speye(nc), dx), kron(speye(nc), dy));

% Regularization object: Huber penalty of the finite-difference gradient D*u,
% with gradient D'*HuberGrad(D*u) by the chain rule.
regEnergy = @(u) Huber(D*u);
regGrad = @(u) D'*HuberGrad(D*u);
E_huber = energyClass(regEnergy,regGrad);
% NOTE(review): unlike E_l2, no problem size is passed to energyClass here;
% confirm that the constructor's third argument is optional.

% final variational model, combined through energyClass's overloaded + and *
E_final = E_l2 + alpha*E_huber;


%% Run gradient descent
% Solver options for gradient descent with backtracking.
backend = struct( ...
    'alpha',     0.25, ...  % backtracking strictness, variable in ]0,0.5[
    'beta',      0.5,  ...  % backtracking reduction factor, in ]0,1[
    'maxiters',  250,  ...  % maximum number of iterations
    'startStep', 1,    ...  % starting tau for each iteration
    'stopCrit',  1e-4, ...  % stopping criterion on gradient norm
    'callback',  25);       % frequency of console prints

% Minimize E_final starting from the observed data f and print the elapsed time.
solveTimer = tic;
[u_out,logfile] = E_final.solve('gradDescent',backend,f);
toc(solveTimer)

%% Show result
% Turn the solution vector back into an nx-by-ny-by-nc image and display it.
u_image = reshape(u_out,[nx,ny,nc]);

figure(3);
imshow(u_image);
title('algorithm result');