%%% Modelling the predicted centre of mass under an imperfect estimate of
%%% the relative density of an object's two parts.
%%% The density ratio is learned exponentially across trials.
%%% Each observer (participant) is fitted individually.
%%% Updated July 2024 to test the LOG density ratio (LDR).
%%% Version 2 (21st July): incorporates possible noise in the estimate of the
%%%  density ratio (or LDR). Adapted from FitExpDensityPPT_WJA.m

clear all;
close all;

%%% For every model type in ModelType and every experiment in Expt_Num:
%%%   1. load the grasp data (and part B data when FIT_BOTH is set),
%%%   2. fit each participant with up to nIterations fminsearch restarts,
%%%      saving best/iteration params and repsDone to results_File after every
%%%      restart so an interrupted run resumes where it stopped,
%%%   3. regenerate the fitted density ratio and grasp trajectories from the
%%%      best params, run t-tests on the first trials, and plot the results.
%%% NOTE(review): Define_Vars is assumed to supply Pos_edges, Ntrials28, ObsI and
%%% the FREE / FIXED / STORED flags used below -- confirm against that script.

% final model in paper is 3
for(ModelType=[5 6]) %%% could be any of many models tested here
    for(Expt_Num=[1 2 3]) % some combo of [1 2 2.5 3 3.5])
        %%% `clear all` above only runs once, so drop everything an earlier
        %%% (ModelType, Expt_Num) pass could leave behind; otherwise an experiment
        %%% with fewer subjects/trials silently inherits stale rows or columns in
        %%% the accumulators used for plotting and stats.
        clear DataB bestParams iterationParams ...
            Fit_D Fit_G normFit_G norm_data28 raw_data28;
        Define_Vars;

        FIT_BOTH = 0;     % 1 = fit part A and part B of a two-part experiment jointly
        nIterations = 10; % number of fminsearch restarts per subject

        switch Expt_Num
            case 1, E_label = 'Exp1';
            case 2, E_label = 'Exp2a';
            case 2.5, E_label = 'Exp2b';
            case 3, E_label = 'Exp3a';
            case 3.5, E_label = 'Exp3b';
            otherwise, error('unknown experiment number: %g', Expt_Num);
        end
        % load data (each data file provides the variable Data)
        load(strcat('../Data/Data_', E_label, '.mat'));
        DataA = Data;
        if(FIT_BOTH)
            if Expt_Num==2
                load('../Data/Data_Exp2b.mat'); DataB = Data;
            elseif Expt_Num==3
                load('../Data/Data_Exp3b.mat'); DataB = Data;
            else
                % DataB is required below whenever FIT_BOTH is set
                error('Experiment %g is not a two-parter; cannot FIT_BOTH', Expt_Num);
            end
        end

        %%%% load stored density end points from Part A
        if(ModelType==19)
            Model_AN = 1; % model to get last density from
        elseif(ModelType==20)
            Model_AN = 3;
        else
            Model_AN = 0;
        end
        if(Model_AN)
            DensityFile = strcat('EndD_Expt', num2str(Expt_Num-0.5), '_Model', num2str(Model_AN), '.mat');
            load(DensityFile, 'FinalDensity');
        else
            FinalDensity = [];
        end

        MODEL = model_setup(ModelType);
        results_File = get_resultsFileName(Expt_Num, FIT_BOTH, MODEL);

        % take subjects from part A explicitly: when FIT_BOTH is set, Data has
        % been overwritten by the part B file
        subList = unique(DataA(:,1));
        nSubs = length(subList);

        %%% load existing best params and error, for pooled data
        if(exist(results_File, 'file'))
            load(results_File, 'repsDone', 'bestParams', 'bestError', 'iterationParams', 'iterationError');
            fprintf('found results file\n');
        else
            fprintf('no results file\n');
            bestError = inf(1, nSubs);
            repsDone = zeros(1, nSubs);
            iterationError = Inf*ones(nSubs, nIterations);
        end

        %%% information about physical objects (9 objects, total length 32)
        MetalLength = 0:4:32;
        PlasticLength = 32 - MetalLength;
        %%% centre of each part, assuming object and metal part start at x=0;
        %%% plastic mass is taken as equal to its length
        MetalPos = MetalLength/2;
        PlasticPos = MetalLength + PlasticLength/2;
        PlasticMass = PlasticLength;

        %%%%% FIND BEST MODEL (loop-invariant optimiser setup)
        errorhand = @likelihood_PPT3;
        options = optimset('MaxFunEvals', 100*3); % default is 200 x n params

        for(ss=1:nSubs)
            sub_ID = subList(ss);
            if(repsDone(ss)<nIterations)
                %%% some reps remain, so prep ss data
                %%% (Data col 1 = subject, col 3 = object number, col 5 = grasp location)
                data_ss = DataA(DataA(:,1)==sub_ID, 1:5);

                %%% separate grasping data into trials for objects 1 or 9 vs the
                %%% rest, retaining only object number and grasp location
                data19_A = data_ss(data_ss(:, 3)==1 | data_ss(:, 3)==9, [3 5]);
                data28_A = data_ss(data_ss(:, 3)>1 & data_ss(:, 3)<9, [3 5]);
                %%% also save grasping location as a bin index (as extracted by
                %%% histcounts); preallocated per subject so a previous subject's
                %%% bins can never leak into this one
                nT_A = size(data28_A, 1);
                bins28_A = zeros(nT_A, 1);
                for(tt=1:nT_A)
                    [RespCount, ~] = histcounts(data28_A(tt, 2), Pos_edges);
                    bins28_A(tt) = find(RespCount>0);
                end
                data28_A = [data28_A bins28_A];

                %%%% values needed to convert estimated density to estimated CoM
                %%%% (i.e. expected grasp location), in this ppt's trial order
                ObInf.Metal_Len = MetalLength(data28_A(:, 1));
                ObInf.Metal_Pos = MetalPos(data28_A(:, 1));
                ObInf.Plas_Pos = PlasticPos(data28_A(:, 1));
                ObInf.Plas_Mass = PlasticMass(data28_A(:, 1));

                if(FIT_BOTH==1)
                    data_ss = DataB(DataB(:,1)==sub_ID, 1:5);
                    data19_B = data_ss(data_ss(:, 3)==1 | data_ss(:, 3)==9, [3 5]);
                    data28_B = data_ss(data_ss(:, 3)>1 & data_ss(:, 3)<9, [3 5]);
                    nT_B = size(data28_B, 1);
                    bins28_B = zeros(nT_B, 1);
                    for(tt=1:nT_B)
                        [RespCount, ~] = histcounts(data28_B(tt, 2), Pos_edges);
                        bins28_B(tt) = find(RespCount>0);
                    end
                    data28_B = [data28_B bins28_B];

                    %%% append the ObInf info for part B
                    ObInf.Metal_Len = [ObInf.Metal_Len MetalLength(data28_B(:, 1))];
                    ObInf.Metal_Pos = [ObInf.Metal_Pos MetalPos(data28_B(:, 1))];
                    ObInf.Plas_Pos = [ObInf.Plas_Pos PlasticPos(data28_B(:, 1))];
                    ObInf.Plas_Mass = [ObInf.Plas_Mass PlasticMass(data28_B(:, 1))];
                end

                for(jj=1:nIterations)
                    % repsDone counts completed restarts, so 1..repsDone are done
                    if(jj<=repsDone(ss))
                        fprintf('Subject: %d, Iteration: %d already done\n', ss, jj);
                        continue;
                    end
                    fprintf('Subject: %d, Iteration: %d doing now\n', ss, jj);

                    %%% build the initial guess; MODEL.<Param>(1) says whether the
                    %%% parameter is FREE, MODEL.<Param>(2) gives its slot in the vector
                    clear ParamGuess;
                    if(MODEL.DensStart(1) == FREE)
                        ParamI = MODEL.DensStart(2); % which parameter it is
                        if(MODEL.LSpace==1)
                            % regular space
                            if(Expt_Num==2.5)
                                ParamGuess(ParamI) = 3;
                            elseif(Expt_Num==3.5)
                                ParamGuess(ParamI) = 1/3;
                            else
                                ParamGuess(ParamI) = 1;
                            end
                        else
                            % log space
                            if(Expt_Num==2.5)
                                ParamGuess(ParamI) = 1.1;
                            elseif(Expt_Num==3.5)
                                ParamGuess(ParamI) = -1.1;
                            else
                                ParamGuess(ParamI) = 0;
                            end
                        end
                    end
                    if(MODEL.DensEnd(1) == FREE)
                        ParamI = MODEL.DensEnd(2); % which parameter it is
                        if(Expt_Num==1 || Expt_Num==2 || Expt_Num==3.5)
                            if(MODEL.LSpace==1) % regular
                                ParamGuess(ParamI) = 3;
                            else % log space
                                ParamGuess(ParamI) = 1.1;
                            end
                        else
                            if(MODEL.LSpace==1) % regular
                                ParamGuess(ParamI) = 1/3;
                            else
                                ParamGuess(ParamI) = -1.1;
                            end
                        end
                    end
                    % learning rate
                    if(MODEL.LRate(1) == FREE)
                        ParamI = MODEL.LRate(2); % which parameter it is
                        ParamGuess(ParamI) = 1;
                    end
                    if(MODEL.RatSigma(1) == FREE)
                        ParamI = MODEL.RatSigma(2); % which parameter it is
                        ParamGuess(ParamI) = 2;
                    end
                    % grasping noise: free, or fixed to the spread of the
                    % object 1/9 grasps
                    if(MODEL.MotSigma(1) == FREE)
                        ParamI = MODEL.MotSigma(2); % which parameter it is
                        ParamGuess(ParamI) = 3;
                    elseif(FIT_BOTH)
                        MODEL.MotSigma(2) = std([data19_A(:, 2); data19_B(:, 2)]);
                    else
                        MODEL.MotSigma(2) = std(data19_A(:, 2));
                    end
                    if(MODEL.OrB(1) == FREE)
                        ParamI = MODEL.OrB(2); % which parameter it is
                        ParamGuess(ParamI) = 0.3; % bias in cm
                    else
                        MODEL.OrB(2) = 0;
                    end
                    %%% for part B
                    if(FIT_BOTH)
                        if(MODEL.DensStart_B(1) == FREE)
                            ParamI = MODEL.DensStart_B(2); % which parameter it is
                            if(MODEL.LSpace==1)
                                % regular space
                                ParamGuess(ParamI) = 1;
                            else
                                % log space
                                ParamGuess(ParamI) = 0;
                            end
                        end
                        if(MODEL.DensEnd_B(1) == FREE)
                            ParamI = MODEL.DensEnd_B(2); % which parameter it is
                            if(Expt_Num==2.5 || Expt_Num==3)
                                if(MODEL.LSpace==1) % regular
                                    ParamGuess(ParamI) = 3;
                                else % log space
                                    ParamGuess(ParamI) = 1.1;
                                end
                            else
                                if(MODEL.LSpace==1) % regular
                                    ParamGuess(ParamI) = 1/3;
                                else
                                    ParamGuess(ParamI) = -1.1;
                                end
                            end
                        end
                        if(MODEL.LearnRate_B(1) == FREE)
                            ParamI = MODEL.LearnRate_B(2); % which parameter it is
                            ParamGuess(ParamI) = 0.3;
                        end
                        if(MODEL.MotSigma_B(1) == FREE)
                            ParamI = MODEL.MotSigma_B(2); % which parameter it is
                            ParamGuess(ParamI) = 3;
                        else
                            % already inside FIT_BOTH, so pool both parts
                            MODEL.MotSigma_B(2) = std([data19_A(:, 2); data19_B(:, 2)]);
                        end
                    end
                    ParamGuess1 = ParamGuess;
                    %%% restarts after the first scale EACH guess by its own random
                    %%% factor in [0, 2), so restarts explore the parameter space
                    %%% rather than a single ray; zero guesses are instead drawn
                    %%% uniformly from [-0.5, 0.5)
                    if(jj>1)
                        ParamGuess = 2.0*rand(size(ParamGuess1)).*ParamGuess1;
                        index0 = find(ParamGuess1==0);
                        if(~isempty(index0))
                            ParamGuess(index0) = -0.5 + rand(1, length(index0));
                        end
                    end
                    if(FIT_BOTH==0)
                        [ParamsTemp, ErrorTemp] = fminsearch(errorhand, ParamGuess, options, data28_A, data19_A, ObInf, MODEL, FIT_BOTH);
                    else
                        [ParamsTemp, ErrorTemp] = fminsearch(errorhand, ParamGuess, options, [data28_A; data28_B], [data19_A; data19_B], ObInf, MODEL, FIT_BOTH);
                    end
                    iterationParams(ss,jj,:) = ParamsTemp';
                    iterationError(ss,jj) = ErrorTemp;
                    if(ErrorTemp<bestError(ss))
                        fprintf('YAY improved fit: old: %f, new: %f\n', bestError(ss), ErrorTemp);
                        bestParams(ss,:) = ParamsTemp;%change to iterationParams, iterationError
                        bestError(ss) = ErrorTemp;
                    else
                        fprintf('DOH poor fit: old: %f, new: %f\n', bestError(ss), ErrorTemp);
                    end
                    repsDone(ss) = repsDone(ss)+1;
                    % checkpoint after every restart so the run can be resumed
                    if(exist('bestParams', 'var'))
                        save(results_File, 'bestParams', 'bestError', 'iterationParams', 'iterationError', 'repsDone');
                    end
                end % jj reps
            end % if reps remain
        end % subs

        % Plotting the Data
        if(Expt_Num==1 || Expt_Num==2 || Expt_Num==3.5)
            COND_A = 1; COND_B = -1;
        else
            COND_A = -1; COND_B = 1;
        end
        SUBS_PLOTS = 0; % set to 1 for one figure per subject

        for(ss=1:nSubs)
            sub_ID = subList(ss);
            [data19, data28] = getSubData(DataA, sub_ID);
            %%%% set parameter values
            ParSet = setParams(bestParams(ss,:), MODEL, data19, 1);
            if(MODEL.DensStart(1)==STORED) % from file
                if(MODEL.LSpace==1)
                    ParSet.DensStart = FinalDensity(ss);
                else
                    ParSet.DensStart = log(FinalDensity(ss));
                end
            end
            %%%% get fit from parameters
            [Est_D] = getExpo(Ntrials28, ParSet, MODEL);
            Fit_D(ss, 1:Ntrials28) = Est_D;
            Fit_G(ss,1:Ntrials28) = dens2grasp2(Fit_D(ss,1:Ntrials28),data28(:,1),ObsI, COND_A);
            normFit_G(ss,1:Ntrials28) = grasp2norm2(Fit_G(ss,1:Ntrials28)',data28(:,1), ObsI, COND_A);
            norm_data28(ss,1:Ntrials28) = grasp2norm2(data28(:,2),data28(:,1), ObsI, COND_A);
            raw_data28(ss,1:Ntrials28) = data28(:,2);

            % expo fit must start with x=0
            X = 1:Ntrials28; % just for plotting
            if(FIT_BOTH)
                [data19_B, data28_B] = getSubData(DataB, sub_ID);

                %%% for stats for paper %%%
                raw_data28(ss, Ntrials28+1:2*Ntrials28) = data28_B(:, 2);

                ParSet_B = setParams(bestParams(ss,:), MODEL, data19_B, 2);
                if(MODEL.DensStart_B(1)==FIXED)
                    % part B starts from where part A's fitted density ended
                    if(MODEL.LSpace==1) % normal
                        ParSet_B.DensStart=Est_D(Ntrials28);
                    else
                        ParSet_B.DensStart=log(Est_D(Ntrials28));
                    end
                end
                [Est_D] = getExpo(Ntrials28, ParSet_B, MODEL);
                Fit_D(ss, Ntrials28+1:2*Ntrials28) = Est_D;

                Fit_G(ss,Ntrials28+1:2*Ntrials28) = dens2grasp2(Fit_D(ss,Ntrials28+1:2*Ntrials28),data28_B(:,1),ObsI, COND_B);
                normFit_G(ss,Ntrials28+1:2*Ntrials28) = -grasp2norm2(Fit_G(ss,Ntrials28+1:2*Ntrials28)',data28_B(:,1), ObsI, COND_B);
                norm_data28(ss,Ntrials28+1:2*Ntrials28) = -grasp2norm2(data28_B(:,2),data28_B(:,1), ObsI, COND_B);
                X = 1:2*Ntrials28; % just for plotting
            end
            if(SUBS_PLOTS)
                figure('Name', strcat('sub', num2str(ss)));
                %%% plot in two halves with different colours
                X1 = 1:Ntrials28; X2 = Ntrials28+1:Ntrials28*2;

                plot(X1,normFit_G(ss,X1),'r-');hold on;
                plot(X1,norm_data28(ss,X1),'*');
                if(FIT_BOTH)
                    plot(X2,normFit_G(ss,X2),'g-');
                    plot(X2,norm_data28(ss,X2),'o');
                end
                xlabel('trial number');
                ylabel('Normalised Grasping Pos');
                yline(0, 'r--');
                yline(1, 'k--');
            end
        end

        %%% t-tests on the first raw grasps (data, not fit) against 16
        %%% (presumably the geometric midpoint of the 32-long object -- confirm units)
        [~,P] = ttest(raw_data28(:, 1), 16);
        fprintf('Experiment: %g, mean start G: %f, t-test against 16: %f\n', Expt_Num, mean(raw_data28(:, 1)), P);
        [~,P] = ttest(raw_data28(:, 2), 16);
        fprintf('Experiment: %g, mean 2nd trial G: %f, t-test against 16: %f\n', Expt_Num, mean(raw_data28(:, 2)), P);

        % first part-B trial column; part B columns only exist when FIT_BOTH
        % (a standalone 2.5 / 3.5 run is already covered by column 1 above)
        iB = Ntrials28 + 1;
        if(FIT_BOTH)
            %%% t-test for part B
            [~,P] = ttest(raw_data28(:, iB), 16);
            fprintf('Experiment: %g, **part B** mean start G: %f, t-test against 16: %f\n', Expt_Num, mean(raw_data28(:, iB)), P);
        end

        %%% t-test normalised first grasp against 0
        [~,P] = ttest(norm_data28(:, 1), 0);
        fprintf('Experiment: %g, mean start G norm: %f, t-test against 0: %f\n', Expt_Num, mean(norm_data28(:, 1)), P);

        if(FIT_BOTH)
            %%% t-test for part B
            [~,P] = ttest(norm_data28(:, iB), 0);
            fprintf('Experiment: %g, **part B** mean start G norm: %f, t-test against 0: %f\n', Expt_Num, mean(norm_data28(:, iB)), P);
        end

        figure('Name', strcat('Expt', num2str(Expt_Num), 'Model', num2str(ModelType)));
        % mean +/- SEM across subjects: std(X, 0, 1) is the sample std along
        % dim 1 (std(X, 1) would instead select population normalisation)
        meanFit_G = mean(normFit_G, 1);    stdmeanFit_G = std(normFit_G, 0, 1)/sqrt(nSubs);
        meanData_G = mean(norm_data28, 1); stdmeanData_G = std(norm_data28, 0, 1)/sqrt(nSubs);
        errorbar(X, meanFit_G, stdmeanFit_G, '-'); hold on;
        errorbar(X, meanData_G, stdmeanData_G, '*');
        yline(0, 'k--');
        yline(1, 'r--');
        if(FIT_BOTH)
            axis([0 numel(X)+1 -1.5 1.5]);
            yline(-1, 'g--');
        else
            axis([0 numel(X)+1 -1 2]);
        end

        %%% fitted density-ratio trajectory, one line per subject
        figure('Name', strcat('DR, Expt', num2str(Expt_Num), 'Model', num2str(ModelType)));
        hold on;
        for(ss=1:nSubs)
            plot(X, Fit_D(ss, :));
        end

    end
end