classdef HEstimatorTlbx
    % HESTIMATORTLBX Collection of static methods for measuring the response
    % (impulse response and transfer function) of an LTI system from an
    % excitation signal and one or more measured output signals.
    %
    % Spectra are estimated by averaging windowed FFT frames (see xspectrum).
    % Cross-spectrum convention of xspectrum: xspectrum(a, b) averages
    % A.*conj(B), so Pyx = xspectrum(y, x) averages Y.*conj(X).
    %
    % Uses hanning and xcorr (Signal Processing Toolbox). Plots also calls
    % LogDB, which is not defined in this class.
    
    properties
        
    end
    
    methods(Static)
        
        %% H Estimator
        function [h, H, COH] = HEstimator(InputSignal, OutputSignal, N_fft, N_overlap, EstimatorTypeStr, AlignmentBoolean)
            % HESTIMATOR Estimate impulse response, transfer function and
            % coherence between InputSignal and each output channel.
            %
            % [h, H, COH] = HEstimator(InputSignal, OutputSignal, N_fft, ...
            %                          N_overlap, EstimatorTypeStr, AlignmentBoolean)
            %
            % INPUT:
            %    - InputSignal: excitation signal (vector)
            %    - OutputSignal: measured response; a vector, or a matrix of
            %      NumbOfChannels x Duration. A matrix with more rows than
            %      columns is assumed to be Duration x NumbOfChannels and is
            %      transposed.
            %    - N_fft: window / FFT length in samples
            %    - N_overlap: overlap between frames in samples
            %      (default round(N_fft/2), i.e. 50%)
            %    - EstimatorTypeStr: 'H1' (default), 'H2', or 'H3' / 'Hv'
            %      (case-insensitive); anything else raises an error
            %    - AlignmentBoolean: when true (default) each output channel is
            %      circularly aligned to the input before estimation, and the
            %      estimated delay is re-applied to h afterwards
            % OUTPUT:
            %    - h: NumbOfChannels x N_fft impulse responses
            %    - H: NumbOfChannels x (floor(N_fft/2)+1) transfer functions,
            %      bins from DC to fs/2
            %    - COH: NumbOfChannels x (floor(N_fft/2)+1) magnitude-squared
            %      coherence
            
            if nargin < 6, AlignmentBoolean = 1; end % align the signals by default
            if nargin < 5, EstimatorTypeStr = 'H1'; end % H1 is the default estimator
            if nargin < 4, N_overlap = round(N_fft/2); end % default overlap: 50% of the window
            
            % Resolve the estimator once, before any processing, so that an
            % unknown name fails fast instead of midway through the loop.
            if strcmpi(EstimatorTypeStr, 'H1')
                EstimatorType = 1;
            elseif strcmpi(EstimatorTypeStr, 'H2')
                EstimatorType = 2;
            elseif strcmpi(EstimatorTypeStr, 'H3') || strcmpi(EstimatorTypeStr, 'Hv')
                EstimatorType = 3;
            else
                error('HEstimatorTlbx:UnknownEstimator', ...
                    'Unknown estimator type "%s". Use ''H1'', ''H2'' or ''H3''.', ...
                    EstimatorTypeStr);
            end
            
            %% Check OutputSignal dimensions
            [NumbOfChannels, DurationTime] = size(OutputSignal);
            
            % The number of channels is assumed to be much smaller than the
            % signal duration. If it is not, the matrix is most likely
            % Duration x NumbOfChannels (as returned e.g. by multichannel
            % recording toolboxes), so it is transposed.
            if NumbOfChannels > DurationTime
                OutputSignal = transpose(OutputSignal);
                NumbOfChannels = size(OutputSignal, 1);
            end
            
            % Parameters shared by all channels
            win = hanning(N_fft);
            Pxx = HEstimatorTlbx.xspectrum(InputSignal, InputSignal, win, N_overlap);
            
            lengthfreqvect = floor(N_fft/2) + 1; % bins from DC to fs/2 (also valid for odd N_fft)
            h = zeros(NumbOfChannels, N_fft);
            H = zeros(NumbOfChannels, lengthfreqvect);
            COH = zeros(NumbOfChannels, lengthfreqvect);
            
            for ch_idx = 1:NumbOfChannels
                
                curr_Out_signal = OutputSignal(ch_idx, :);
                
                if AlignmentBoolean
                    % Input/output alignment (circular shift of the output)
                    [OutputSignalProcessed, delay] = HEstimatorTlbx.AlignTwoSequences(InputSignal, curr_Out_signal);
                else
                    OutputSignalProcessed = curr_Out_signal;
                end
                
                % Pyx averages Y.*conj(X); conj(Pyx) is therefore the X.*conj(Y) spectrum
                Pyx = HEstimatorTlbx.xspectrum(OutputSignalProcessed, InputSignal, win, N_overlap);
                Pyy = HEstimatorTlbx.xspectrum(OutputSignalProcessed, OutputSignalProcessed, win, N_overlap);
                
                switch EstimatorType
                    case 1
                        temp_H = HEstimatorTlbx.H1(Pyx, Pxx);
                    case 2
                        % H2 = Syy / (X.*conj(Y) spectrum). Dividing by Pyx
                        % itself would return conj(H).
                        temp_H = Pyy./conj(Pyx);
                    otherwise
                        % H3 (Hv) estimator:
                        % (Syy - Sxx + sqrt((Sxx - Syy)^2 + 4|Syx|^2)) / (2 conj(Syx))
                        temp_H = (Pyy - Pxx + sqrt((Pxx - Pyy).^2 + 4*abs(Pyx).^2))./(2*conj(Pyx));
                end
                
                % Magnitude-squared coherence (same formula as
                % coherencefunction), reusing the spectra computed above.
                COH_temp = abs(Pyx).^2./(Pxx.*Pyy);
                
                % 'symmetric' builds a real IR from the DC..fs/2 half spectrum
                temp_h = ifft(temp_H, N_fft, 'symmetric');
                
                if AlignmentBoolean
                    % Re-apply the removed delay (circular, so no samples are lost)
                    temp_h = circshift(temp_h, delay);
                end
                
                % Recompute H after the delay compensation, keeping DC..fs/2
                temp_H = fft(temp_h, N_fft);
                temp_H = temp_H(1:lengthfreqvect);
                
                h(ch_idx, :) = temp_h;
                H(ch_idx, :) = temp_H;
                COH(ch_idx, :) = COH_temp;
            end% for ch_idx
        end%HEstimator
        
        function Plots(h, H, COH, fs, FigureNameTextStr)
            % PLOTS Plot HEstimator results in a 2x2 layout: impulse response,
            % magnitude (dB), unwrapped phase and coherence.
            %
            % Plots(h, H, COH, fs, FigureNameTextStr)
            % INPUT:
            %    - h, H, COH: outputs of HEstimator (one row per channel)
            %    - fs: sampling frequency, Hz
            %    - FigureNameTextStr: optional; when given, a new figure named
            %      [FigureNameTextStr ' WhiteNoise'] is opened, otherwise the
            %      current figure is used
            %
            % NOTE: LogDB is not defined in this class; presumably a
            % magnitude-to-dB helper on the MATLAB path (verify).
            if nargin >= 5
                figure('Name', [FigureNameTextStr ' WhiteNoise']);
            end
            
            N_fft = length(h);
            
            timevect = (0:(N_fft - 1))/fs;
            freqvect = linspace(0, fs/2, floor(N_fft/2) + 1);
            
            % Fixed amplitude range; responses larger than +/-8e-3 are clipped.
            subplot(2,2,1); plot(timevect, h); title('Estimated IR');
            xlabel('Time, s'); ylabel('Amplitude'); grid on;
            xlim([timevect(1), timevect(end)]); ylim([-8*10^-3, 8*10^-3]);
            
            subplot(2,2,2); semilogx(freqvect, LogDB(abs(H))); title('Estimated Magnitude FRF');
            xlabel('Freq, Hz'); ylabel('dB'); grid on; xlim([50, fs/2]);
            
            % Unwrap along frequency. For a multichannel H (one row per channel)
            % plain unwrap would operate down the columns, i.e. across channels.
            PhaseH = angle(H);
            if isvector(PhaseH)
                PhaseH = unwrap(PhaseH);
            else
                PhaseH = unwrap(PhaseH, [], 2);
            end
            subplot(2,2,3); semilogx(freqvect, PhaseH); title('Phase');
            xlabel('Freq, Hz'); grid on; xlim([50, fs/2]);
            
            subplot(2,2,4); semilogx(freqvect, COH); title('Coherence');
            xlabel('Freq, Hz'); grid on; xlim([50, fs/2]);
        end
        
        %%
        
        function COH = coherencefunction(InputSignal, OutputSignalProcessed, win, N_overlap)
            % COHERENCEFUNCTION Magnitude-squared coherence between
            % InputSignal (x) and OutputSignalProcessed (y):
            %     COH = |Pyx|^2 ./ (Pxx .* Pyy)
            % with all spectra estimated by xspectrum using window win and
            % N_overlap samples of frame overlap.
            
            Pxx = HEstimatorTlbx.xspectrum(InputSignal, InputSignal, win, N_overlap);
            Pyx = HEstimatorTlbx.xspectrum(OutputSignalProcessed, InputSignal, win, N_overlap);
            Pyy = HEstimatorTlbx.xspectrum(OutputSignalProcessed, OutputSignalProcessed, win, N_overlap);
            COH = (abs(Pyx).^2)./(Pxx.*Pyy);
            
        end
        
        %%
        
        function [Sxy,f] = csd(x, y, w, Fs)
            % CSD One-sided cross spectral density of vectors x and y.
            %
            % Sxy is the csd of signal x and signal y sampled at rate Fs, based
            % on the FFTs of non-overlapping segments that are first multiplied
            % by the window w (whose length also defines the FFT size).
            % Per segment it accumulates 2*conj(X).*Y/(N*Fs) and finally
            % divides by (number of segments * sum(w.^2)/N).
            % Note: the conj(X).*Y convention is the opposite of xspectrum's.
            %
            % [Sxy,f] = csd(x,y,w,Fs)
            % INPUT:
            %    - x, y: input vectors (only the first min(length) samples
            %      are used; shorter than w -> zero-padded to one segment)
            %    - w: vector containing the window
            %    - Fs: sampling frequency
            % OUTPUT:
            %    - Sxy: CSD of signal x and signal y, floor(N/2)+1 bins (column)
            %    - f: corresponding frequencies (row vector)
            
            x = x(:); % make sure x is a column vector
            y = y(:); % make sure y is a column vector
            w = w(:); % make sure w is a column vector
            
            N = length(w);
            nFreq = floor(N/2) + 1;
            f = (0:nFreq - 1)*Fs/N;
            
            L = min(length(x), length(y));
            if L < N
                % Signals shorter than one segment: zero-pad so a single
                % segment can be formed (indexing past the end would fail).
                x = [x(1:L); zeros(N - L, 1)];
                y = [y(1:L); zeros(N - L, 1)];
                L = N;
            end
            
            m = floor(L/N); % number of non-overlapping segments to average
            
            Sxy = zeros(nFreq, 1);
            for idx = 1:m
                seg = (idx - 1)*N + (1:N);
                X = fft(w.*x(seg));
                Y = fft(w.*y(seg));
                Sxy = Sxy + 2*(conj(X(1:nFreq)).*Y(1:nFreq))/(N*Fs);
            end
            
            wp = sum(w.^2)/N; % window power normalisation
            Sxy = Sxy/(m*wp);
            
        end
        
        
        %%
        
        function H = H1(Pyx, Pxx)
            % H1 H1 transfer-function estimator: H = Pyx ./ Pxx
            % (output/input cross spectrum over input auto spectrum).
            
            H = Pyx./Pxx;
            
        end
        
        
        
        %%
        
        function Pxy = xspectrum(x, y, w, N_overlap)
            % XSPECTRUM Averaged cross spectrum of vectors x and y.
            %
            % Pxy = xspectrum(x, y, w, N_overlap)
            % The signals are split into frames of length(w) samples with
            % N_overlap samples of overlap (frame count as in MATLAB's cpsd),
            % each frame is multiplied by w, and X.*conj(Y) is averaged over
            % the frames. Each frame is normalised by sum(w.^2); there is no
            % sampling-frequency scaling and no one-sided doubling.
            % INPUT:
            %    - x, y: input vectors (only the first min(length) samples are
            %      used; if shorter than w they are zero-padded to one frame)
            %    - w: vector containing the window (its length is the FFT size)
            %    - N_overlap: overlap in samples, 0 <= N_overlap < length(w)
            %      (default 0)
            % OUTPUT:
            %    - Pxy: column vector of floor(length(w)/2)+1 bins (DC..Nyquist)
            
            if nargin < 4
                N_overlap = 0;
            end
            
            x = x(:); % make sure x is a column vector
            y = y(:); % make sure y is a column vector
            w = w(:); % make sure w is a column vector
            
            N = length(w);
            if N_overlap < 0 || N_overlap >= N
                error('HEstimatorTlbx:InvalidOverlap', ...
                    'N_overlap must satisfy 0 <= N_overlap < length(w).');
            end
            
            L = min(length(x), length(y));
            if L < N
                % Too short for a single frame: zero-pad to one frame rather
                % than averaging over zero frames (which returned NaN).
                x = [x(1:L); zeros(N - L, 1)];
                y = [y(1:L); zeros(N - L, 1)];
                L = N;
            end
            
            D = N - N_overlap; % hop size between frame starts
            K = floor((L - N_overlap)/D); % number of averages (as in cpsd)
            
            lengthfreqvect = floor(N/2) + 1;
            U = sum(w.^2)/N; % window power
            
            Pxy = zeros(lengthfreqvect, 1);
            for frame_idx = 0:K-1
                frame = frame_idx*D + (1:N); % 1-based sample indices of this frame
                
                X_idx = fft(w.*x(frame));
                Y_idx = fft(w.*y(frame));
                
                % Keep DC..Nyquist, accumulate the normalised cross power
                Pxy = Pxy + X_idx(1:lengthfreqvect).*conj(Y_idx(1:lengthfreqvect))/(N*U);
            end
            
            Pxy = Pxy/K;
            
        end
        
        %%
        function [DelayedOutputSignal, delay_in_samples] = ...
                AlignTwoSequences(InputSignal, OutputSignal)
            % ALIGNTWOSEQUENCES Align OutputSignal to InputSignal.
            % The delay of OutputSignal relative to InputSignal is estimated by
            % EstimateDelayInSamplesBtwTwoSequences, and OutputSignal (returned
            % as a column vector) is circularly shifted by -delay_in_samples.
            
            InputSignal = InputSignal(:);
            OutputSignal = OutputSignal(:);
            delay_in_samples = HEstimatorTlbx.EstimateDelayInSamplesBtwTwoSequences(InputSignal, OutputSignal);
            % Circular shift in the time domain
            DelayedOutputSignal = circshift(OutputSignal, -delay_in_samples);
        end
        
        % [1] CLASSIFICATION AND EVALUATION OF DISCRETE SUBSAMPLE TIME DELAY ESTIMATION ALGORITHMS
        %%
        
        
        
        function [delay_in_samples,  CrossCorrelation] = EstimateDelayInSamplesBtwTwoSequences(InputSignal, OutputSignal)
            % ESTIMATEDELAYINSAMPLESBTWTWOSEQUENCES Integer delay of
            % OutputSignal with respect to the reference InputSignal.
            % The delay is the lag of the largest-magnitude value of the
            % cross-correlation xcorr(OutputSignal, InputSignal), which is
            % also returned. The magnitude is used so that a
            % polarity-inverting system (negative main peak) is still aligned.
            
            [CrossCorrelation, lags] = xcorr(OutputSignal, InputSignal);
            
            [~, d] = max(abs(CrossCorrelation));
            delay_in_samples = lags(d);
            
        end
        
    end%Static Methods
end