classdef energyClass %< handle
    % General energy class; the energy and its gradient are given as
    % function handles of the unknown u.
    %
    % properties:
    %   energyHandle - function handle, u -> energy value
    %   gradHandle   - function handle, u -> gradient of the energy at u
    %   probSize     - number of unknowns; 0 means "unknown / not specified"
    %
    % methods:
    %   energyClass  - constructor
    %   evalE        - evaluate energy
    %   evalGrad     - evaluate gradient
    %   plus         - add two energies (E1 + E2)
    %   mtimes       - multiply with a scalar (a*E or E*a)
    %   solve        - minimize with gradient descent and backtracking
    
    properties
        energyHandle
        gradHandle
        probSize
        
    end
    
    methods
        %% Construct energy
        function obj = energyClass(energyHandle,gradHandle,probSize)
            % obj = energyClass(energyHandle,gradHandle)
            % obj = energyClass(energyHandle,gradHandle,probSize)
            %
            % energyHandle, gradHandle must both be function handles,
            % otherwise the error 'invalid input' is raised.
            % probSize is optional; if omitted or empty it is set to 0.
            % Called without arguments an empty energy (no handles,
            % probSize 0) is returned, which MATLAB needs e.g. when it
            % builds object arrays.
            
            obj.probSize = 0;
            if nargin == 0
                return
            end
            
            if nargin < 2
                error('invalid input')
            end
            
            if nargin >= 3 && ~isempty(probSize)
                obj.probSize = probSize;
            end
            
            % Anything that is not a pair of function handles is rejected;
            % returning an object with empty handles would only fail later
            % in evalE/evalGrad with a misleading message.
            if isa(energyHandle,'function_handle') && isa(gradHandle,'function_handle')
                obj.energyHandle = energyHandle;
                obj.gradHandle = gradHandle;
            else
                error('invalid input')
            end
            
        end
        
        %% evaluate gradient at u
        function  gradU = evalGrad(obj,u)
            % returns gradHandle(u)
            gradU = obj.gradHandle(u);
        end
        
        %% evaluate energy at u
        function energy_val = evalE(obj,u)
            % returns energyHandle(u)
            energy_val =  obj.energyHandle(u);
            
        end
        
        %% Add two energies
        function obj_out = plus(obj1,obj2)
            % obj_out = obj1 + obj2
            % The resulting energy/gradient handles are the sums of the
            % operands' handles. If both operands have a known problem
            % size (> 0) the sizes must agree; otherwise the known one
            % (if any) is used.
            
            if ~isa(obj1,'energyClass') || ~isa(obj2,'energyClass')
                error('only energies can be added');
            end
            
            if obj1.probSize > 0 && obj2.probSize > 0
                if  obj1.probSize ~= obj2.probSize
                    error('invalid data sizes');
                else
                    probSize_out = obj1.probSize;
                end
            else
                probSize_out = max(obj1.probSize,obj2.probSize);
            end
            energyHandle_out = @(u) obj1.energyHandle(u) + obj2.energyHandle(u);
            gradHandle_out = @(u) obj1.gradHandle(u) + obj2.gradHandle(u);
            
            obj_out = energyClass(energyHandle_out,gradHandle_out,probSize_out);
            
        end
        %% multiply with scalar
        function obj2 = mtimes(scalar,obj)
            % obj2 = scalar*obj   or   obj2 = obj*scalar
            % Scales energy and gradient by a numeric (or logical) scalar.
            
            % MATLAB dispatches obj*scalar to this method as well, with
            % the object as first argument: swap into canonical order.
            if isa(scalar,'energyClass')
                [scalar,obj] = deal(obj,scalar);
            end
            
            if ~(isnumeric(scalar) || islogical(scalar)) || ~isscalar(scalar)
                error('only multiplication with scalars');
            end
            energyHandle2 = @(u) scalar*obj.energyHandle(u);
            gradHandle2 = @(u) scalar*obj.gradHandle(u);
            obj2 = energyClass(energyHandle2,gradHandle2,obj.probSize);
        end
        
        %% Solve energy
        function [u_out,log] = solve(obj,algType,backend,u0)
            % [u_out,log] = solve(obj,algType,backend,u0)
            %
            % algType - name of the algorithm; only 'gradDescent' is
            %           implemented (default if omitted or empty)
            % backend - struct with optional fields (defaults in brackets)
            %             alpha     [0.2]   sufficient decrease parameter
            %             beta      [0.25]  step reduction factor
            %             maxiters  [10000] maximal number of iterations
            %             startStep [0.2]   initial step size per iteration
            %             stopCrit  [1e-4]  stop if norm of gradient is below
            %             callback  [25]    print residual every ... iterations
            % u0      - starting iterate (default: zeros(probSize,1))
            %
            % u_out   - last iterate
            % log     - struct with fields tau (accepted step sizes),
            %           count (number of tested step sizes per iteration)
            %           and numIts (number of iterations performed)
            
            if nargin < 2 || isempty(algType)
                algType = 'gradDescent';
            end
            
            % start at zero if no starting iterate is given
            if nargin < 4
                u = zeros(obj.probSize,1);
            else
                u = u0;
                % the size can only be checked if it is known
                % (probSize == 0 means "not specified")
                if obj.probSize > 0 && numel(u) ~= obj.probSize
                    warning('invalid starting vector disregarded');
                    u = zeros(obj.probSize,1);
                end
            end
            
            % use standard backend values for everything that is not given
            if nargin < 3 || isempty(backend)
                backend = struct();
            end
            defaults = struct('alpha',0.2,'beta',0.25,'maxiters',10000, ...
                'startStep',0.2,'stopCrit',1e-4,'callback',25);
            names = fieldnames(defaults);
            for k = 1:numel(names)
                if ~isfield(backend,names{k})
                    backend.(names{k}) = defaults.(names{k});
                end
            end
            
            
            if strcmp(algType,'gradDescent')
                
                % init; energy and gradient at the current iterate are kept
                % and reused, so each point is evaluated only once
                alph = backend.alpha;
                Eu = obj.evalE(u);
                GradU = obj.evalGrad(u);
                nGrad = norm(GradU(:));
                disp(['Grad descent initialized, starting residual is ',num2str(nGrad)]);
                
                % log is well defined even if no iteration is performed
                log = struct('tau',zeros(1,0),'count',zeros(1,0),'numIts',0);
                numIts = 0;
                
                % iterate
                for i = 1:backend.maxiters
                    numIts = i;
                    
                    normGrad  = sum(GradU(:).^2); % this is norm^2
                    tau = backend.startStep;
                    
                    % backtrack until the sufficient decrease condition
                    % E(u - tau*grad) <= E(u) - alpha*tau*|grad|^2 holds
                    count  =1;
                    uNew = u - tau*GradU;
                    ENew = obj.evalE(uNew);
                    while ENew > (Eu - alph*tau*normGrad)
                        tau = backend.beta*tau;
                        count = count+1;
                        uNew = u - tau*GradU;
                        ENew = obj.evalE(uNew);
                    end
                    % log tau value and counter
                    log.tau(i) = tau;
                    log.count(i) = count;
                    
                    % update iterate together with its energy and gradient
                    u = uNew;
                    Eu = ENew;
                    GradU = obj.evalGrad(u);
                    nGrad = norm(GradU(:));
                    
                    % callback
                    if mod(i,backend.callback) == 0
                        disp(['iteration ',num2str(i),' - residual is ',num2str(nGrad)]);
                    end
                    
                    % check convergence
                    if nGrad < backend.stopCrit
                        break % break for-loop
                    end
                    
                end
                % return last iterate
                u_out = u;
                
                % write number of iterations into logfile
                log.numIts = numIts;
                
                disp(['gradient descent finished after ',num2str(numIts), ' iterations']);
            else
                error('algorithm not implemented');
            end
            
            
        end
        
    end
    
end

