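%
%  Sample solution for exercise 2: super-resolution of a downsampled color
%  image by minimizing a data term plus different regularizers.
%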
clearvars
% Make the helper functions in subfunctions/ (e.g. energy, huber_reg, nl_reg,
% TV_reg_smooth, xDirectionDownsamplingMatrix) available on the path.
addpath(genpath('subfunctions'));

%%
% Read the test image, convert it to double precision and resize it to a
% fixed 240x240 size; [m,n,k] = rows, columns, channels.
img = im2double(imread('peppers.png'));
img = imresize(img,[240,240]);
[m,n,k] = size(img);
% axis('image') keeps square pixels; imagesc alone stretches the image to
% fill the (non-square) axes.
figure(1), imagesc(img), axis('image'), title('High-Res Image');

%% Get downsampling matrix
% Build the sparse linear operator DownSamp that maps the vectorized
% high-res image img(:) (length m*n*k) to the vectorized low-res image
% (length (m/factor)*(n/factor)*k). The same 2-D operator is applied to every
% channel through a block-diagonal Kronecker product.
% NOTE(review): the variable name 'factor' shadows MATLAB's built-in factor().
factor = 2; % Downsampling factor

% The per-axis operators need integer low-res sizes.
assert(mod(m,factor) == 0 && mod(n,factor) == 0, ...
    'Image size [%d %d] must be divisible by the downsampling factor %d.', m, n, factor);

% NOTE(review): presumably xDirectionDownsamplingMatrix(a,b) returns a
% b-by-a matrix (so Sx' and Sy are low-res-by-high-res) -- confirm in subfunctions.
Sx = xDirectionDownsamplingMatrix(n/factor,n);
Sy = (xDirectionDownsamplingMatrix(m/factor,m))';
% One copy of the 2-D operator per channel (k channels, not a hard-coded 3).
DownSamp = kron(speye(k),kron(Sx',Sy));


%% Generate data
% Simulate the observed low-res data: apply the downsampling operator to the
% vectorized high-res image and add a tiny amount of noise.
f = DownSamp*img(:);
% Zero-mean Gaussian noise. rand() would be uniform on [0,1) and shift every
% sample upward by noise_level/2 on average.
noise_level = 1e-6;
f = f + noise_level*randn(size(f));
% Show the low-res data upsampled by pixel replication so it appears at the
% same size as the high-res image. A separate figure keeps the high-res
% reference in figure 1 from being overwritten before it is ever displayed.
figure(2), imagesc(imresize(reshape(f,[m/factor,n/factor,k]),factor,'nearest')), axis('image'), title('Low-Res input Data');
drawnow
%% Write energy function:
alpha = 0.15; % Base regularization weight; each experiment scales it.

% Data term: value 0.5*||DownSamp*u - f||^2 and its gradient
% DownSamp'*(DownSamp*u - f), wrapped in the project's energy class.
data_term = energy(@(u) 0.5*sum((DownSamp*u-f).^2),@(u) DownSamp'*(DownSamp*u-f));

% Regularizers
% see subfunctions for the meaning of the numeric parameters
reg_huber_normal = huber_reg([m,n,k],0.01,'normal');
reg_huber_double = huber_reg([m,n,k],0.01,'double opponent');

% Both nonlocal regularizers receive the same single-channel image built from
% the back-projected data DownSamp'*f (presumably used to compute the
% nonlocal weights -- see nl_reg). Compute it once instead of twice, and only
% call rgb2gray for 3-channel data, since rgb2gray rejects other channel counts.
nl_guide = reshape(DownSamp'*f,m,n,k);
if k == 3
    nl_guide = rgb2gray(nl_guide);
end
reg_NL_1         = nl_reg(nl_guide,3,7,0.1,'D-W');
reg_NL_2         = nl_reg(nl_guide,3,7,0.1,'normed');

reg_TV           = TV_reg_smooth([m,n,k],0.01);


% Run each regularizer in turn: add it (with a hand-tuned weight) to the data
% term, solve starting from DownSamp'*f (presumably the initial guess), show
% the result with its PSNR against img, and wait for a key press.
%% Huber 1
Eu = data_term + 2*alpha*reg_huber_normal;
u_out = Eu.solve(DownSamp'*f);
% Label the title with the regularizer so successive results in figure 1 can
% be told apart; axis('image') keeps square pixels.
figure(1), imagesc(reshape(u_out,[m,n,k])), axis('image'), title(['Huber (normal) - PSNR: ',num2str(psnr(u_out,img(:)))]);
pause();
%% Huber 2
% Huber regularizer in 'double opponent' mode, weight 2*alpha.
Eu = data_term + 2*alpha*reg_huber_double;
u_out = Eu.solve(DownSamp'*f);
% Label the title with the regularizer so successive results in figure 1 can
% be told apart; axis('image') keeps square pixels.
figure(1), imagesc(reshape(u_out,[m,n,k])), axis('image'), title(['Huber (double opponent) - PSNR: ',num2str(psnr(u_out,img(:)))]);
pause();
%% Nonlocal 1
% Nonlocal regularizer in 'D-W' mode, weight 0.005*alpha.
Eu = data_term + 0.005*alpha*reg_NL_1;
u_out = Eu.solve(DownSamp'*f);
% Label the title with the regularizer so successive results in figure 1 can
% be told apart; axis('image') keeps square pixels.
figure(1), imagesc(reshape(u_out,[m,n,k])), axis('image'), title(['Nonlocal (D-W) - PSNR: ',num2str(psnr(u_out,img(:)))]);
pause();
%% Nonlocal 2
% Nonlocal regularizer in 'normed' mode, weight 0.02*alpha.
Eu = data_term + alpha*0.02*reg_NL_2;
u_out = Eu.solve(DownSamp'*f);
% Label the title with the regularizer so successive results in figure 1 can
% be told apart; axis('image') keeps square pixels.
figure(1), imagesc(reshape(u_out,[m,n,k])), axis('image'), title(['Nonlocal (normed) - PSNR: ',num2str(psnr(u_out,img(:)))]);
pause();
%% Smoothed Total Variation
% Smoothed TV regularizer, weight 0.1*alpha. Last experiment, so no pause.
Eu = data_term + alpha*0.1*reg_TV;
u_out = Eu.solve(DownSamp'*f);
% Label the title with the regularizer so successive results in figure 1 can
% be told apart; axis('image') keeps square pixels.
figure(1), imagesc(reshape(u_out,[m,n,k])), axis('image'), title(['Smoothed TV - PSNR: ',num2str(psnr(u_out,img(:)))]);