%%%% For a model (or models), plot the fit and the data as normalised grasping
%%%% position.
%%% Updated Feb 2025 to use different coloured markers (requested by reviewer 1). %%%
%%% ---- start from a clean workspace and no open figures ----
clear all;
close all;

%%% ---- run options ----
%Expt_Num = 1;
FIT_BOTH = 0;             % fit flag used when building the results file name
FIT_BOTH_LIST = [0 0];    % whether fit both for model 1 or 2
PLOT_IND = 0;             % 1 = one figure of data and fit per subject
PLOT_AV = 1;              % 1 = group-average figure per experiment

%%% ---- model selection ----
%ModelTypes = [1 17]; % both regular space
%ModelTypes = [1 7]; % both regular space
ModelTypes = 3; % log space

%%% ---- plotting colours (RGB) ----
lightgrey = 0.9*ones(1, 3);
darkgrey  = 0.8*ones(1, 3);
d2grey    = 0.3*ones(1, 3);

%%% ---- subjects to include ----
%subs = [1 2 3]; % sub numbers 1 to 20
subs = 1:20;      % sub numbers 1 to 20

% NOTE(review): ObsI and Ntrials28 are used below but not defined in this
% file; presumably Define_Vars sets them - confirm.
Define_Vars;

%%% For each experiment, build one figure with three panels:
%%%   (1) grasp location per object: swarm of all trials over a stacked bar,
%%%       one subplot per task part;
%%%   (2) mean normalised grasp location across subjects: data +/- SE and
%%%       model fit +/- SE;
%%%   (3) fitted log density ratio: each subject plus the mean.
%%% Experiments 2 and 3 have two parts (A then B). Part B's density curve
%%% starts from the density the part-A fit reached on its last trial.
%%% NOTE(review): ObsI, Ntrials28 and the helper functions (load_data,
%%% getSubData, setParams, ...) come from Define_Vars / the project path.
for Expt_Num = 1:3
%for Expt_Num = 2:3
    %%% load data and organise
    DataA = load_data(Expt_Num);
    subList = unique(DataA(:,1));

    switch Expt_Num
        case 1, COND_A = 1;
        case 2, S = load('../Data/Data_Exp2b.mat'); DataB = S.Data; COND_A = 1;  COND_B = -1;
        case 3, S = load('../Data/Data_Exp3b.mat'); DataB = S.Data; COND_A = -1; COND_B = 1;
        otherwise, fprintf('not a two-parter\n');
    end

    % number of task parts in this experiment (A only, or A then B)
    if Expt_Num == 1
        NPARTS = 1;
    else
        NPARTS = 2;
    end

    MeanHand(Expt_Num) = figure('Name', strcat('E', num2str(Expt_Num), 'Average'));
   % make_violin_plot

    %%% ---------- panel 1: grasp locations per object ----------
    nTrialsPerSub = 45;                 % rows per subject in DataP
    nSubPlot = 20;                      % subjects drawn in the swarm
    nt = nTrialsPerSub * nSubPlot;      % number of DataP rows plotted
    t1s = 1:nTrialsPerSub:(nt-nTrialsPerSub+1);  % each subject's first trial
    tends = nTrialsPerSub:nTrialsPerSub:nt;      % each subject's last trial
    col1 = [1 0 0]; % first trial: red
    coln = [0 1 0]; % last trial: green

    for pp = 1:NPARTS
        if Expt_Num == 1
            subplot(4, 5, [1 2 6 7 11 12]);
        else
            subplot(3, 2, pp);
        end
        if pp == 1
            DataP = DataA;
        else
            DataP = DataB;
        end
        hold on;
        bh = bar(1:9, [ObsI.Metal_Len; 32 - ObsI.Metal_Len], 'stacked');
        set(bh, 'FaceColor', 'Flat')
        bh(1).CData = lightgrey;  % first stacked segment
        bh(2).CData = darkgrey;   % second stacked segment
        xlabel('Object', 'FontSize', 14);
        ylabel('Grasp Location', 'FontSize', 14);

        %%% scatter on top: colour and drawing rank for every dot
        c = zeros(nt, 3);      % colour triple per dot
        place = zeros(1, nt);  % drawing rank per dot (higher = drawn later)
        for tt = 1:nTrialsPerSub
            % rows holding the tt-th trial of every subject
            t_indices = tt:nTrialsPerSub:(nt-(nTrialsPerSub-tt));
            if tt == 1                    % subjects' first trial
                place(t_indices) = nTrialsPerSub-1;   % plot in penultimate place
                c(t_indices, :) = repmat(col1, nSubPlot, 1);
            elseif tt == nTrialsPerSub    % subjects' last trial
                place(t_indices) = nTrialsPerSub;     % plot last
                c(t_indices, :) = repmat(coln, nSubPlot, 1);
            else
                % intermediate trials: random order, always beneath the end points
                place(t_indices) = randi(nTrialsPerSub-2, 1, nSubPlot);
                % colour on a continuum from first to last trial, mixed with grey
                col_tt = col1 + (coln-col1)*((tt-1)/nTrialsPerSub);
                col_tt = (col_tt + d2grey)/2;
                c(t_indices, :) = repmat(col_tt, nSubPlot, 1);
            end
        end

        %%% dot sizes: first and last trials slightly larger
        sz = 50*ones(1, nt);
        sz([t1s tends]) = 60;
        %%% draw in ascending rank so first/last trials end up on top
        [~, order] = sort(place);
        sc = swarmchart(DataP(order, 3), DataP(order, 5), sz(order), c(order, :), "filled");
        sc.MarkerFaceAlpha = 0.75;
        sc.MarkerEdgeAlpha = 1;
        xlim([0.4 9.6]);
        ylim([-0.1 32.1]);
        yline(16, 'k--');
        axis square;

        %%% true CoM: ObsI.TrueCoM_R for Expt 2 part B and Expt 3 part A
        if (Expt_Num == 2 && pp == 2) || (Expt_Num == 3 && pp == 1)
            CoM = ObsI.TrueCoM_R;
        else
            CoM = ObsI.TrueCoM;
        end
        for oo = 1:9
            plot([oo-0.4 oo+0.4], [CoM(oo) CoM(oo)], 'Color', [0.1 0.1 1], 'LineWidth', 3);
        end
    end
    %legend({'', '', 'First', 'Last', 'Intermediate', 'Centre of Mass'})
    legend({'', '', 'Intermediate', 'Geometric Centre', 'Centre of Mass'})

    %%% ---------- model fits ----------
    % Part A uses model 3; part B (Expts 2 and 3) uses model 20. The two MODEL
    % structs are kept separate so part-A fits never use the part-B model.
    ModelN_A = 3;
    MODEL_A = model_setup(ModelN_A);
    R = load(get_resultsFileName(Expt_Num, FIT_BOTH, MODEL_A), 'bestError', 'bestParams');
    bestParamsA = R.bestParams;
    bestErrorA = R.bestError;
    if Expt_Num > 1
        %%%% load best params and set up model for part B (loop-invariant)
        ModelN_B = 20;
        MODEL_B = model_setup(ModelN_B);
        R = load(get_resultsFileName(Expt_Num+0.5, FIT_BOTH, MODEL_B), 'bestError', 'bestParams');
        bestParamsB = R.bestParams;
        bestErrorB = R.bestError;
    end

    % fresh result arrays for this experiment; row = subject number in subs.
    % Rows not listed in subs stay NaN and are excluded from all summaries.
    nRows = max(subs);
    nCols = NPARTS * Ntrials28;
    norm_data28 = nan(nRows, nCols);
    raw_data28  = nan(nRows, nCols);
    Est_D       = nan(nRows, nCols);
    Fit_G       = nan(nRows, nCols);
    normFit_G   = nan(nRows, nCols);
    LRate_A = nan(1, nRows);
    LRate_B = nan(1, nRows);

    X  = 0:(Ntrials28-1);              % trial offsets for the fitted curves
    X1 = 1:Ntrials28;                  % columns for part A
    X2 = Ntrials28+1:2*Ntrials28;      % columns for part B
    % exponential approach from DensStart to DensEnd at rate LearnRate
    decay = @(P) (P.DensStart - P.DensEnd) * exp(-P.LearnRate * X) + P.DensEnd;

    for ss = subs
        %%% part A: data, then fitted density and the grasp it predicts
        [data19_A, data28_A] = getSubData(DataA, subList(ss));
        norm_data28(ss, X1) = grasp2norm2(data28_A(:,2), data28_A(:,1), ObsI, COND_A);
        raw_data28(ss, X1) = data28_A(:,2);
        ParSet_A = setParams(bestParamsA(ss, :), MODEL_A, data19_A, 1);
        curveA = decay(ParSet_A);
        if MODEL_A.LSpace ~= 1   % log-space model: curve is log density
            curveA = exp(curveA);
        end
        Est_D(ss, X1) = curveA;
        LRate_A(ss) = ParSet_A.LearnRate;
        Fit_G(ss, X1) = dens2grasp2(Est_D(ss, X1), data28_A(:,1), ObsI, COND_A);
        normFit_G(ss, X1) = grasp2norm2(Fit_G(ss, X1)', data28_A(:,1), ObsI, COND_A);

        %%% part B (Expts 2 and 3)
        if Expt_Num > 1
            [data19_B, data28_B] = getSubData(DataB, subList(ss));
            norm_data28(ss, X2) = grasp2norm2(data28_B(:,2), data28_B(:,1), ObsI, COND_B);
            raw_data28(ss, X2) = data28_B(:,2);
            ParSet_B = setParams(bestParamsB(ss, :), MODEL_B, data19_B, 1);

            %%% start part B from the density reached at the end of part A
            if MODEL_B.LSpace == 1   % regular
                ParSet_B.DensStart = Est_D(ss, Ntrials28);
            else                     % log space
                ParSet_B.DensStart = log(Est_D(ss, Ntrials28));
            end
            LRate_B(ss) = ParSet_B.LearnRate;
            curveB = decay(ParSet_B);
            if MODEL_B.LSpace ~= 1
                curveB = exp(curveB);
            end
            Est_D(ss, X2) = curveB;
            Fit_G(ss, X2) = dens2grasp2(Est_D(ss, X2), data28_B(:,1), ObsI, COND_B);
            normFit_G(ss, X2) = grasp2norm2(Fit_G(ss, X2)', data28_B(:,1), ObsI, COND_B);
        end

        if PLOT_IND
            figHand(ss) = figure('Name', strcat('sub', num2str(ss)));
            %%% plot in two halves with different markers
            hold on;
            plot(X1, norm_data28(ss, X1), '*');
            plot(X1, normFit_G(ss, X1), 'r-');
            if Expt_Num > 1
                plot(X2, norm_data28(ss, X2), 'o');
                plot(X2, normFit_G(ss, X2), 'r-');
            end
            xlabel('trial number', 'FontSize', 14);
            ylabel('Normalised Grasping Pos', 'FontSize', 14);
            yline(0, 'r--', 'LineWidth', 2);
            yline(1, 'k--', 'LineWidth', 2);
        end
    end

    %%% ---------- t-tests on the first grasps (data, not fit) ----------
    % raw grasp locations are tested against 16 (the geometric-centre line drawn
    % in panel 1); normalised locations against 0
    firstB = Ntrials28 + 1;   % column of the first part-B trial
    [~, P] = ttest(raw_data28(subs, 1), 16);
    fprintf('Experiment: %d, mean start G: %f, t-test against 16: %f\n', Expt_Num, mean(raw_data28(subs, 1)), P);
    [~, P] = ttest(raw_data28(subs, 2), 16);
    fprintf('Experiment: %d, mean 2nd trial G: %f, t-test against 16: %f\n', Expt_Num, mean(raw_data28(subs, 2)), P);
    if Expt_Num > 1
        [~, P] = ttest(raw_data28(subs, firstB), 16);
        fprintf('Experiment: %d, **part B** mean start G: %f, t-test against 16: %f\n', Expt_Num, mean(raw_data28(subs, firstB)), P);
    end

    [~, P] = ttest(norm_data28(subs, 1), 0);
    fprintf('Experiment: %d, mean start G norm: %f, t-test against 0: %f\n', Expt_Num, mean(norm_data28(subs, 1)), P);
    if Expt_Num > 1
        [~, P] = ttest(norm_data28(subs, firstB), 0);
        fprintf('Experiment: %d, **part B** mean start G norm: %f, t-test against 0: %f\n', Expt_Num, mean(norm_data28(subs, firstB)), P);
    end

    if PLOT_AV
        %%% mean and standard error across the selected subjects
        nSubs = length(subs);
        MeanData = mean(norm_data28(subs, :), 1);
        SEData = std(norm_data28(subs, :), [], 1)/sqrt(nSubs);
        MeanFit = mean(normFit_G(subs, :), 1);
        SE_Fit = std(normFit_G(subs, :), [], 1)/sqrt(nSubs);
        fitLo = MeanFit - SE_Fit;
        fitHi = MeanFit + SE_Fit;

        %%% ---------- panel 2: normalised grasp location ----------
        figure(MeanHand(Expt_Num));
        if Expt_Num == 1
            subplot(4, 5, [3 4 5 8 9 10])
        else
            subplot(3, 2, 3:4);
        end
        hold on;
        xlabel('Trial Number', 'FontSize', 14);
        ylabel('Normalised Grasp Location', 'FontSize', 14)
        yline(0, '--', 'LineWidth', 2);
        yline(1, '--', 'LineWidth', 2);

        for pp = 1:NPARTS
            if pp == 1
                Xp = X1;
            else
                Xp = X2;
            end
            %%%% fit: shaded region for +/- 1 SE, then the mean fit
            plot(Xp, fitLo(Xp), 'r', 'LineWidth', 0.8);
            plot(Xp, fitHi(Xp), 'b', 'LineWidth', 0.8);
            fill([Xp, fliplr(Xp)], [fitLo(Xp), fliplr(fitHi(Xp))], lightgrey);
            plot(Xp, MeanFit(Xp), 'k', 'LineWidth', 3)
            %%%% data: mean +/- 1 SE
            errorbar(Xp, MeanData(Xp), SEData(Xp), 'r*', 'LineWidth', 2);
        end
        % axis limits and labels, drawn once per panel
        switch Expt_Num
            case 1
                axis([0 36 -0.2 1.3]);
                text(15, 0.01, 'Geometric Centre', 'FontSize', 14);
                text(22, 1.1, 'Centre of Mass', 'FontSize', 14);
            case 2
                axis([0 71 -0.9 1.6]);
                text(9, 0.01, 'Geometric Centre', 'FontSize', 14);
                text(28, 1.1, 'Centre of Mass', 'FontSize', 14);
            otherwise
                axis([0 71 -0.9 1.6]);
                text(18, 0.1, 'Geometric Centre', 'FontSize', 14);
                text(1, 1.1, 'Centre of Mass', 'FontSize', 14);
        end

        %%% ---------- panel 3: fitted log density ratio ----------
        if Expt_Num == 1
            subplot(4, 5, [13 14 15 18 19 20])
        else
            subplot(3, 2, 5:6);
        end
        hold on;
        logD = log(Est_D(subs, :));
        meanD = mean(logD, 1);

        xlabel('Trial Number', 'FontSize', 14);
        ylabel('Log Density Ratio', 'FontSize', 14);
        yline(0, '--', 'LineWidth', 2);
        yline(log(3), '--', 'LineWidth', 2);

        for pp = 1:NPARTS
            if pp == 1
                Xp = X1;
            else
                Xp = X2;
            end
            %%%% one grey line per subject (columns of the transposed matrix)
            plot(Xp, logD(:, Xp).', 'LineWidth', 0.8, 'Color', 0.5*[1 1 1]);
            %%% mean fit in red
            plot(Xp, meanD(Xp), 'LineWidth', 2.5, 'Color', [1 0 0])
        end
        switch Expt_Num
            case 1
                axis([0 36 -0.5 2.2]);
                text(17.1, 1.5, 'True Density Ratio', 'FontSize', 14);
                plot([15 17], [log(3) 1.5], 'k-', 'LineWidth', 2);
                text(15, -0.1, 'Equal Density', 'FontSize', 14);
            case 2
                axis([0 71 -2.5 2.5]);
                yline(log(1/3), '--');
                text(45, 1.2, 'Natural Density Ratio', 'FontSize', 14);
                text(18, -1.2, 'Inverted Density Ratio', 'FontSize', 14);
                text(5, -0.2, 'Equal Density', 'FontSize', 14);
            case 3
                axis([0 71 -2.5 2.5]);
                yline(log(1/3), '--');
                text(20, 1.2, 'Natural Density Ratio', 'FontSize', 14);
                text(40, -1.2, 'Inverted Density Ratio', 'FontSize', 14);
                text(5, 0.2, 'Equal Density', 'FontSize', 14);
        end
    end
end
