function [ grid,pos,disp ] = f_gridMethodImageProcessing_ID(refImagePath,refImageFile,...
    grid,gridMethodOpts,imageNoise)
% Author: Lloyd Fletcher
% PhotoDyn Group, University of Southampton
% Date: 10/8/2017
%
% Processes a sequence of grid images to obtain displacement fields
% This code uses the grid processing tool box that can be found at:
% www.thegridmethod.net, developed by Grediac
%
% The function is interactive: it opens dialogs for the repeating part of
% the image file names, the region of interest and the grid parameters.
%
% Inputs:
%   refImagePath   - folder of the images; it is concatenated directly with
%                    the file names, so it must end with a file separator
%   refImageFile   - file name of the reference (first) image
%   grid           - struct whose pxPerPeriod and pitch fields are shown as
%                    the defaults of the grid parameter dialog
%   gridMethodOpts - processing options; fields read here: windowFlag,
%                    windowWidth, temporalUnwrap, debug, padPhaseMaps,
%                    gmPadMethod ('constant' or 'linear'), dispCalcMethod,
%                    dispROIMethod
%   imageNoise     - (optional) struct forwarded to
%                    func_addNoiseToImagesStruct; noise is only added when
%                    imageNoise.addNoise is true. Default: no added noise.
%
% Outputs:
%   grid - pxPerPeriod and pitch as entered in the dialog, plus
%          mPerPx = pitch/pxPerPeriod
%   pos  - ROI corners in pixels (specLocBottomLeft, specLocTopright) and
%          pixel-centre coordinates x, y, xGrid, yGrid, xStep, yStep in the
%          units of grid.pitch
%   disp - x and y displacement fields (units of grid.pitch, cropped to the
%          ROI) and rot (not scaled or cropped), indexed (row, column, frame)

% Noise is only added when the caller explicitly requests it
if nargin < 5 || isempty(imageNoise)
    imageNoise = struct('addNoise',false);
elseif ~isfield(imageNoise,'addNoise')
    imageNoise.addNoise = false;
end

%--------------------------------------------------------------------------
% 1) Load the Reference Image and Find All Images in Folder
% Get the repeated part of the string assuming underscore termination
[~,~,imageExt] = fileparts(refImageFile);
usLoc = find(refImageFile == '_');
if ~isempty(usLoc)
    startStr = refImageFile(1:usLoc(end));
else
    startStr = 'DefGrid_';
end

% Ask for the filename starter
resp = inputdlg({'Enter the repeating part of the image file string:'},...
             'Input repeating file string', 1, {startStr} );
local_checkDialogResponse(resp,'Input repeating file string');
fstart = resp{1};

% Get the names of all image files with the same ext in the directory and
% sort them in numerical order (reference image first)
refImageFileStruct = dir([refImagePath,fstart,'*',imageExt]);
refImageFileCell = local_sortImageFiles({refImageFileStruct.name},...
    refImageFile,fstart,imageExt);

% Load the reference image
refImage = double(imread([refImagePath,refImageFile]));

if imageNoise.addNoise
    refImage = func_addNoiseToImagesStruct(refImage,imageNoise);
end

%--------------------------------------------------------------------------
% 2) Mask Images
% 'y' = select the ROI interactively, 'c1'/'c2' = hard-coded ROIs,
% anything else = whole image. ROI corners are [column,row] pixel indices.
maskImages = inputdlg('Select region of interest: y/n', 'ROI selection', 1, {'y'});
local_checkDialogResponse(maskImages,'ROI selection');

switch maskImages{1}
    case 'y'
        hf = figure;
        imshow(refImage, []);
        title('Select window to keep after masking')
        % Rect = [xmin ymin width height]; keep only whole pixels inside it
        [~,~,~,Rect] = imcrop(hf);
        specLocBottomLeft = ceil([Rect(1),Rect(2)]);
        specLocTopRight = floor([Rect(1)+Rect(3),Rect(2)+Rect(4)]);
    case 'c1'
        % Specimen is 1.8mm from the bottom left hand corner
        specLocBottomLeft = [12,1];
        specLocTopRight = [398,250];
    case 'c2'
        % Specimen is 2mm from the bottom left hand corner
        specLocBottomLeft = [12,1];
        specLocTopRight = [400,250];
    otherwise
        specLocBottomLeft = [1,1];
        specLocTopRight = [size(refImage,2),size(refImage,1)];
end
close all
range.x = (specLocBottomLeft(1):specLocTopRight(1));
range.y = (specLocBottomLeft(2):specLocTopRight(2));

% Assign the position of the specimen to pos struct to return
% (the field name 'specLocTopright' is kept as-is for existing callers)
pos.specLocBottomLeft = specLocBottomLeft;
pos.specLocTopright = specLocTopRight;

%--------------------------------------------------------------------------
% 3) Input Grid Parameters
fprintf('Input grid parameters.\n')
gridData = inputdlg({'Number of pixels per period:','Grid pitch: (m)'}, ...
             'Grid analysis parameters', 1, {num2str(grid.pxPerPeriod), num2str(grid.pitch)} );
local_checkDialogResponse(gridData,'Grid analysis parameters');
grid.pxPerPeriod = str2double(gridData{1});
grid.pitch = str2double(gridData{2});
% str2double returns NaN for unparsable text, which fails these checks too
if ~(grid.pxPerPeriod > 0) || ~(grid.pitch > 0)
    error('f_gridMethodImageProcessing_ID:invalidGridParameters',...
        'Pixels per period and grid pitch must be positive numbers.');
end
grid.mPerPx = grid.pitch/grid.pxPerPeriod;

%--------------------------------------------------------------------------
% 4) Build the Analysis Window
analysisWindow = build_window(gridMethodOpts.windowFlag,...
    gridMethodOpts.windowWidth*grid.pxPerPeriod); % See documentation
% 0 = Gaussian window, default
% 1 = Bi-triangular window, localised phenom, low noise

%--------------------------------------------------------------------------
% 5) Process Images
% Pre-alloc vars for speed
[sy,sx] = size(refImage);
numFrames = numel(refImageFileCell);
phi.x = zeros([sy,sx,numFrames]);
phi.y = zeros([sy,sx,numFrames]);

% Spatial Unwrapping
fprintf('Spatial unwrapping.\n')
for i = 1:numFrames
    % Load the current image file
    currImage = double(imread([refImagePath,refImageFileCell{i}]));

    % Add noise to the image if required
    if imageNoise.addNoise
        currImage = func_addNoiseToImagesStruct(currImage,imageNoise);
    end

    % Calculation of the phase and spatial phase unwrapping; the phase
    % modulus outputs of LSA are not used in this function
    [phi.x(:,:,i),phi.y(:,:,i),~,~] = ...
        LSA(currImage, analysisWindow, grid.pxPerPeriod);
    phi.x(:,:,i) = unwrap2D(single(phi.x(:,:,i)));
    phi.y(:,:,i) = unwrap2D(single(phi.y(:,:,i)));
end

% Temporal Unwrapping
% CAUTION: this only works for a sequence of images with small increments in
% the phase, uses average phase over the FOV to calc rigid body motion
if gridMethodOpts.temporalUnwrap
    % Keep the phase before temporal unwrapping for comparison
    phi.x_nuw = phi.x;
    phi.y_nuw = phi.y;

    fprintf('Temporal unwrappping.\n')
    threshold = pi;
    phase = func_temporalUnwrap(phi,threshold,'field',range);

    % Add the whole number of 2*pi jumps found for each frame
    for i = 1:size(phi.x,3)
        phi.x(:,:,i) = phi.x(:,:,i) + 2*pi*phase.x(i);
        phi.y(:,:,i) = phi.y(:,:,i) + 2*pi*phase.y(i);
    end
end

if gridMethodOpts.debug
    figure;
    hold on
    plot(squeeze(mean(mean(phi.x(range.y,range.x,:)))),'-+b')
    if isfield(phi,'x_nuw')
        % The non-temporally-unwrapped phase only exists when temporal
        % unwrapping was performed above
        plot(squeeze(mean(mean(phi.x_nuw(range.y,range.x,:)))),'-xr')
        legend('Temp Unwrapped','No Unwrap')
    else
        legend('No Unwrap')
    end
    xlabel('Frame')
    ylabel('Mean Phase Over ROI')
    hold off
end

% Calculate Displacements and Strains
fprintf('Calculating displacement and strain components.\n')
disp.x = zeros(size(phi.x));
disp.y = zeros(size(phi.x));
disp.rot = zeros(size(phi.x));

% Window over which the specimen exists: from the first column (and from
% one pitch above the ROI) to one pitch beyond the ROI, clipped to the image
dispRangeX1 = 1:min(max(range.x)+grid.pxPerPeriod, size(phi.x,2));
dispRangeY1 = max(min(range.y)-grid.pxPerPeriod, 1):...
    min(max(range.y)+grid.pxPerPeriod, size(phi.x,1));
dispRangeX = dispRangeX1;
dispRangeY = dispRangeY1;

% Tracked column indices of the impact edge (indIncX) and free edge
% (indIncFX) per frame, and the mean x displacement measured at each edge
numFrames = size(disp.x,3);
indIncX = zeros(1,numFrames);
indIncFX = zeros(1,numFrames);
meanXIE = zeros(1,numFrames);
meanXFE = zeros(1,numFrames);

for i = 1:numFrames
    fprintf('Computing displacement for frame %d\n',i);
    if gridMethodOpts.padPhaseMaps
        % ------------ interpolate back to undeformed reference frame -----
        if i == 1
            UXO = zeros(size(phi.x,1),size(phi.x,2));
            UYO = UXO;
            % For the first image the phase is padded from the image edges
            rangeFE = 1;
            rangeIE = size(phi.x,2);
        else
            % Previous frame's displacement (currently not used in grid
            % method code - errors can propagate into later images)
            UXO = squeeze(disp.x(dispRangeY,dispRangeX,i-1));
            UYO = squeeze(disp.y(dispRangeY,dispRangeX,i-1));
            % Impact edge tracked in the previous frame; rangeFE carries
            % over from the end of the previous iteration
            rangeIE = indIncX(i-1);
        end

        % Replace the x phase outside the tracked edges
        if strcmp(gridMethodOpts.gmPadMethod,'constant')
            phi.x(:,1:rangeFE,i) = repmat(phi.x(:,rangeFE+1,i),1,rangeFE);
            phi.x(:,rangeIE:end,i) = repmat(phi.x(:,rangeIE-grid.pxPerPeriod,i),...
                1,size(phi.x,2)-rangeIE+1);
        elseif strcmp(gridMethodOpts.gmPadMethod,'linear')
            % Pad the rows that are passed to calculate_U_EPS below
            phi.x(:,:,i) = local_padPhaseLinear(phi.x(:,:,i),dispRangeY,...
                rangeFE,rangeIE,grid.pxPerPeriod);
        end

        [disp.x(dispRangeY,dispRangeX,i), disp.y(dispRangeY,dispRangeX,i),~,~,~, disp.rot(dispRangeY,dispRangeX,i)]...
            = calculate_U_EPS(grid.pxPerPeriod,squeeze(phi.x(dispRangeY,dispRangeX,1)),squeeze(phi.y(dispRangeY,dispRangeX,1)),...
            squeeze(phi.x(dispRangeY,dispRangeX,i)),squeeze(phi.y(dispRangeY,dispRangeX,i)),...
            UXO,UYO,gridMethodOpts.dispCalcMethod,100,i);

        % Track free edge and impact edge for phase padding
        if i == 1
            % No movement of impact edge or free edge
            meanXIE(i) = 0; indIncX(i) = max(dispRangeX);
            meanXFE(i) = 0; indIncFX(i) = min(range.x);
        else
            % Motion of the impact edge, measured one pitch inside it
            meanXIE(i) = mean(disp.x(:,indIncX(i-1)-grid.pxPerPeriod-1,i),'omitnan');
            meanXFE(i) = mean(disp.x(:,min(range.x),i),'omitnan');
            % Shift the edges by whole pixels (floor rounds down) once the
            % mean displacement exceeds one pixel
            indIncX(i) = max(dispRangeX) + floor(meanXIE(i));
            indIncFX(i) = min(range.x) + floor(meanXFE(i));
            % Catch in case poor data suggests the edge moves out of the
            % field of view (shouldn't be needed)
            if indIncX(i) > size(disp.x,2)
                indIncX(i) = size(disp.x,2);
            end
        end

        % Deformed region of interest: free edge (indIncFX) to impact
        % edge (indIncX)
        dispRangeXM = indIncFX(i):indIncX(i);

        % Zero NaN displacements (previously only disp.y was cleaned, a
        % misspelt variable meant disp.x was not)
        disp.x(isnan(disp.x)) = 0;
        disp.y(isnan(disp.y)) = 0;

        % Free-edge column used for padding the next frame
        rangeFE = min(dispRangeXM);
        if rangeFE < 1
            rangeFE = 1;
        end
    else
        % ------------------------ fixed ROI approach ---------------------
        [disp.x(dispRangeY1,dispRangeX1,i), disp.y(dispRangeY1,dispRangeX1,i),~,~,~, disp.rot(dispRangeY1,dispRangeX1,i)]...
            = calculate_U_EPS(grid.pxPerPeriod,squeeze(phi.x(dispRangeY1,dispRangeX1,1)),squeeze(phi.y(dispRangeY1,dispRangeX1,1)),...
            squeeze(phi.x(dispRangeY1,dispRangeX1,i)),squeeze(phi.y(dispRangeY1,dispRangeX1,i)),gridMethodOpts.dispCalcMethod,100);
    end
end

%--------------------------------------------------------------------------
% 6) Convert Displacement and Crop to ROI
% Convert the displacement from pixels to the units of grid.pitch (metres,
% as requested by the grid parameter dialog)
disp.x = disp.x*grid.mPerPx;
disp.y = disp.y*grid.mPerPx;

% Crop the image to the ROI
if strcmp(gridMethodOpts.dispROIMethod,'updateROI')
    % Update ROI based on moving sample.
    % NOTE(review): indIncX holds absolute column indices of the impact
    % edge when padPhaseMaps is true, so this offset can exceed the image
    % width; when padPhaseMaps is false indIncX is all zeros - confirm intent.
    disp.x = disp.x(range.y,range.x+min(indIncX),:);
    disp.y = disp.y(range.y,range.x+min(indIncX),:);
else
    disp.x = disp.x(range.y,range.x,:);
    disp.y = disp.y(range.y,range.x,:);
end

% Create the position mesh grid at pixel centres
pos.x = grid.mPerPx/2:grid.mPerPx:(grid.mPerPx*size(disp.x,2));
pos.y = grid.mPerPx/2:grid.mPerPx:(grid.mPerPx*size(disp.x,1));
[pos.xGrid,pos.yGrid] = meshgrid(pos.x,pos.y);
pos.xStep = grid.mPerPx;
pos.yStep = grid.mPerPx;
end

function local_checkDialogResponse(resp,dlgTitle)
% Raises a clear error when an input dialog was cancelled (inputdlg then
% returns an empty cell array).
if isempty(resp)
    error('f_gridMethodImageProcessing_ID:dialogCancelled',...
        'Input dialog ''%s'' was cancelled.',dlgTitle);
end
end

function sortedFiles = local_sortImageFiles(fileNames,refImageFile,fstart,imageExt)
% Orders image file names numerically: the reference image first, then for
% k = 2,3,... the file named [fstart, k, imageExt] with k either unpadded or
% zero-padded to three digits. Unmatched indices at the end of the list are
% dropped; an unmatched index followed by a matched one raises an error
% because the frame sequence would contain a gap.
numFiles = numel(fileNames);
sortedFiles = cell(1,numFiles);
sortedFiles{1} = refImageFile;
for k = 2:numFiles
    candidates = {[fstart,num2str(k),imageExt], [fstart,sprintf('%03d',k),imageExt]};
    match = find(ismember(fileNames,candidates),1);
    if ~isempty(match)
        sortedFiles{k} = fileNames{match};
    end
end

isFound = ~cellfun(@isempty,sortedFiles);
sortedFiles = sortedFiles(1:find(isFound,1,'last'));
missing = find(cellfun(@isempty,sortedFiles),1);
if ~isempty(missing)
    error('f_gridMethodImageProcessing_ID:missingImage',...
        'No image file found for frame %d with file string ''%s''.',missing,fstart);
end
end

function phiFrame = local_padPhaseLinear(phiFrame,rows,rangeFE,rangeIE,pxPerPeriod)
% Replaces, for the given rows, columns 1:rangeFE and
% rangeIE-pxPerPeriod+1:end with straight lines. Each line is a
% least-squares fit over roughly one pitch of data just inside its edge,
% shifted so that it passes through the innermost fitted point next to the
% padded region.
nCols = size(phiFrame,2);
xToFit1 = rangeFE+1:rangeFE+1+pxPerPeriod;
xToFit2 = rangeIE-2*pxPerPeriod:rangeIE-pxPerPeriod;
xToExtrap1 = 1:rangeFE;
xToExtrap2 = rangeIE-pxPerPeriod+1:nCols;
A1 = [ones(numel(xToFit1),1), xToFit1(:)];
A2 = [ones(numel(xToFit2),1), xToFit2(:)];

for p = rows(:).'
    % Slope from the least-squares fit ...
    r1 = A1\phiFrame(p,xToFit1).';
    r2 = A2\phiFrame(p,xToFit2).';
    % ... intercept chosen so the line meets the data at the edge point
    b1 = phiFrame(p,xToFit1(1)) - r1(2)*xToFit1(1);
    b2 = phiFrame(p,xToFit2(end)) - r2(2)*xToFit2(end);

    % Evaluate the lines at the padded coordinates
    phiFrame(p,xToExtrap1) = r1(2)*xToExtrap1 + b1;
    phiFrame(p,xToExtrap2) = r2(2)*xToExtrap2 + b2;
end
end