classdef energy
    %ENERGY Differentiable energy functional together with its gradient.
    %   E = ENERGY(funcHandle, gradHandle) wraps a function handle that
    %   evaluates the energy at a vector u and a function handle that
    %   evaluates its gradient at u. Energies can be combined with the
    %   overloaded operators below and minimized by gradient descent.
    %
    %   Properties:
    %     funcHandle - handle u -> energy value (expected to be a real
    %                  scalar, since SOLVE compares energy values with <=)
    %     gradHandle - handle u -> gradient at u (presumably the same size
    %                  as u, since SOLVE computes u - tau*gradient)
    %
    %   Methods:
    %     plus     - E1 + E2 (sum of energies), E + c or c + E (constant offset)
    %     mtimes   - c*E (scaling by a scalar) and E*A (composition u -> E(A*u))
    %     times    - same as mtimes
    %     value    - evaluate the energy at u
    %     gradient - evaluate the gradient at u
    %     solve    - gradient descent with Armijo backtracking
    
    properties
        funcHandle  % handle u -> energy value
        gradHandle  % handle u -> gradient of the energy at u
    end
    %%
    methods
        function obj = energy(funcHandle,gradHandle)
            %ENERGY Construct an instance of this class
            %from function handles to function and its gradient.
            %   Called without arguments it returns an object with empty
            %   handles (MATLAB needs a no-argument constructor call e.g.
            %   when it creates default objects for object arrays).
            if nargin == 0
                return
            end
            obj.funcHandle = funcHandle;
            obj.gradHandle = gradHandle;
        end
        
        %% Overload basic operations:
        function obj1 = plus(obj1,obj2)
            %PLUS Add two energies, or an energy and a scalar constant.
            %   E1 + E2 has value E1(u)+E2(u) and gradient grad1(u)+grad2(u).
            %   E + c and c + E (numeric scalar c) shift the value by c and
            %   leave the gradient unchanged.
            
            % Addition commutes: normalize "c + E" to "E + c".
            if isnumeric(obj1) && isscalar(obj1) && isa(obj2,'energy')
                [obj1,obj2] = deal(obj2,obj1);
            end
            if ~isa(obj1,'energy')
                error('energy:plus:unsupportedOperands', ...
                      'plus is only defined for energy + energy or energy + numeric scalar.');
            end
            
            % Capture the handles themselves (not whole objects) in the
            % new anonymous functions.
            f1 = obj1.funcHandle;
            g1 = obj1.gradHandle;
            if isa(obj2,'energy')
                f2 = obj2.funcHandle;
                g2 = obj2.gradHandle;
                obj1.funcHandle = @(u) f1(u)+f2(u);
                obj1.gradHandle = @(u) g1(u)+g2(u);
            elseif isnumeric(obj2) && isscalar(obj2)
                c = obj2;
                obj1.funcHandle = @(u) f1(u)+c;
                % gradient of a constant is zero: gradHandle is unchanged
            else
                error('energy:plus:unsupportedOperands', ...
                      'plus is only defined for energy + energy or energy + numeric scalar.');
            end
        end
        
        function obj = mtimes(input1,input2)
            %MTIMES Scalar multiplication (left) or composition (right).
            %   c*E  (c scalar, E energy): value c*E(u), gradient c*grad(u).
            %   E*A  (E energy, A matrix, scalar or linear operator that
            %   supports * and '): value E(A*u), gradient A'*grad(A*u)
            %   (chain rule; A' is the conjugate transpose, so this is the
            %   usual chain rule for real A).
            %   Any other combination (e.g. E1*E2) raises an error.
            
            if isa(input1,'energy') && ~isa(input2,'energy') && ismatrix(input2)
                % composition with a matrix / scalar acting on the variable
                f = input1.funcHandle;
                g = input1.gradHandle;
                A = input2;
                input1.funcHandle = @(u) f(A*u);
                input1.gradHandle = @(u) A'*g(A*u);
                obj = input1;
                
            elseif isa(input2,'energy') && ~isa(input1,'energy') && isscalar(input1)
                % scalar multiplication of the energy
                f = input2.funcHandle;
                g = input2.gradHandle;
                c = input1;
                input2.funcHandle = @(u) c*f(u);
                input2.gradHandle = @(u) c*g(u);
                obj = input2;
                
            else
                error('energy:mtimes:unsupportedOperands', ...
                      ['mtimes is only defined for scalar*energy and ', ...
                       'energy*matrix (linear operator composition).']);
            end
        end
        
        function obj = times(obj,numArray)
            %TIMES Element-wise operator .* behaves exactly like MTIMES:
            %   c.*E scales the energy, E.*A composes it with A (matrix
            %   product A*u, not an element-wise product).
            obj = mtimes(obj,numArray);
        end
        
        %% Evaluate 
        function val = value(obj,u)
            % Value of the energy at vector u
            val = obj.funcHandle(u);
        end
        function grad = gradient(obj,u)
            % Value of the gradient at vector u
            grad = obj.gradHandle(u);
        end
        
        %% Solver
        
        function u_star = solve(obj,u0,alpha,beta,tau0,maxIts)
            %SOLVE Minimize the energy by gradient descent from u0.
            %   u_star = SOLVE(obj,u0) runs gradient descent with Armijo
            %   backtracking. Optional inputs (omit or pass [] for default):
            %     alpha  - sufficient-decrease constant (default 0.4)
            %     beta   - backtracking factor for the step size (default 0.5)
            %     tau0   - initial step size (default 0.5)
            %     maxIts - maximum number of iterations (default 2500)
            %   The step size is carried over between iterations and is only
            %   ever decreased by backtracking, never reset to tau0.
            %   Stops when the SQUARED gradient norm is below 1e-5 or the
            %   SQUARED length of the accepted step is below 1e-7 (in that
            %   case the tiny step is not applied). Raises an error after
            %   more than 250 backtracks within one iteration.
            if nargin < 3 || isempty(alpha)
                alpha = 0.4;  % alpha defines the step acceptance
            end
            if nargin < 4 || isempty(beta)
                beta = 0.5;   % beta defines the backtracking speed
            end
            if nargin < 5 || isempty(tau0)
                tau0 = 0.5;
            end
            if nargin < 6 || isempty(maxIts)
                maxIts = 2500;
            end
            
            % Initialize
            tau_k = tau0;
            u = u0;
            terminated = false;  % set once a stopping criterion triggers
            
            % Iterate
            for ii = 1:maxIts
                Eu_k = obj.value(u);            % energy of current iterate
                grad_u = obj.gradient(u);       % gradient of current iterate
                normGrad_u = sum(grad_u(:).^2); % squared norm of the gradient
                
                % Test for convergence by (squared) gradient norm
                if normGrad_u < 1e-5
                    disp(['Gradient descent terminated with gradient ',...
                        num2str(normGrad_u),' at iteration ',num2str(ii)]);
                    terminated = true;
                    break
                end
                
                % Armijo backtracking. Written as ~(E <= bound) instead of
                % E > bound so that a trial point with NaN energy is
                % rejected (NaN > bound is false and would be accepted).
                u_test = u - tau_k*grad_u;
                count_bts = 0;           % count backtracks
                while ~(obj.value(u_test) <= Eu_k-alpha*tau_k*normGrad_u)
                    tau_k = beta*tau_k;
                    u_test = u - tau_k*grad_u;
                    count_bts = count_bts +1;
                    if count_bts > 250
                        error('Too many backtracks');
                    end
                end
                
                if count_bts > 25
                    disp(['Warning: ',num2str(count_bts), ' backtracks in ',...
                          ' iteration ',num2str(ii)]);
                end
                
                % Test for significant improvement (squared step length)
                u_improved = sum(abs(u(:)-u_test(:)).^2);
                if u_improved < 1e-7                % This is actually quite a lot
                    disp(['Gradient descent terminated, distance of ',...
                           'improvement in last iteration was only ', ...
                           num2str(u_improved),' at iteration ',num2str(ii)]);
                    terminated = true;
                    break
                end
                
                % Continue iteration after valid step has been found
                u = u_test;
                
                % Notify user of running algorithm
                % (Gradient descent might take some time on large images 
                %  on small computers, as we are not using the gpu at all)
                if mod(ii,25) == 0
                    disp(['Running iteration ',num2str(ii), ', improvement ',...
                          ' is  ',num2str(u_improved)]);
                end
                
            end
            % Only report failure when no stopping criterion fired; the
            % previous "ii == maxIts" test misfired when convergence
            % happened exactly on the last iteration.
            if ~terminated
                disp('Algorithm did not converge within maximum iterations');
            end
            % Set output
            u_star = u;
        end
    
    end
end

