%%%% For a model or models, find the final fitted density.
%%%% Script configuration: the flags below are read by the experiment/model
%%%% loop and by the summary figure further down in this file.
clear all;
close all;

Expt_List = [2 3]; % experiment numbers iterated by the main loop
% NOTE(review): FIT_BOTH is reassigned from FIT_BOTH_LIST inside the model loop
% before it is next read in this file; it may still be read by Define_Vars -- confirm.
FIT_BOTH = 1;
%ModelTypes = [1 17]; % both regular space
ModelTypes = [1 3]; % both regular space

FIT_BOTH_LIST = [0 0]; % whether fit both for model 1 or 2 (one entry per ModelTypes entry)

PLOT_IND = 0; % non-zero: draw one data-vs-fit figure per subject
PLOT_AV = 1; % non-zero: draw the across-subject mean/SE figure

%subs = [1 2 3]; % sub numbers 1 to 20
subs = 1:20; % sub numbers 1 to 20
% Define_Vars is a script defined elsewhere. Ntrials28 and ObsI are used below
% but never assigned in this file, so they are presumably set here -- confirm.
Define_Vars;

%%% load data and organise
% For every experiment in Expt_List and every model in ModelTypes:
%   1. load the stored best-fit parameters (bestParams / bestError),
%   2. rebuild each subject's fitted density curve (exponential approach from
%      DensStart to DensEnd at rate LearnRate) over Ntrials28 trials, and over
%      a second block of Ntrials28 trials when Expt_Num > 1,
%   3. convert the density curve to grasp values and normalise them,
%   4. optionally plot per-subject and across-subject figures,
%   5. save the density at trial Ntrials28 of part A to EndD_Expt<N>_Model<M>.mat.
% Ntrials28 and ObsI are not assigned in this file (presumably set by
% Define_Vars -- confirm). All per-subject arrays are indexed by the subject
% number ss itself, not by its position inside subs.
for Expt_Num = Expt_List
    [DataA] = load_data(Expt_Num);
    subList = unique(DataA(:,1)); % distinct values of column 1, used as subject identifiers
    %if(FIT_BOTH)
    % Part-B data set and the condition codes handed to the grasp helpers.
    switch Expt_Num
        case 1, COND_A=1;
        case 2, load('../Data/Data_Exp2b.mat'); DataB = Data; COND_A=1;  COND_B=-1;
        case 3, load('../Data/Data_Exp3b.mat'); DataB = Data; COND_A=-1;  COND_B=1;
        otherwise, fprintf('not a two-parter\n');
    end
    %end


    for mm = 1:length(ModelTypes)
        FIT_BOTH = FIT_BOTH_LIST(mm);
        ModelN = ModelTypes(mm);
        MODEL = model_setup(ModelN);
        FileName = get_resultsFileName(Expt_Num, FIT_BOTH, MODEL);
        if(FIT_BOTH==0 && Expt_Num>1)
            % Parts fitted separately: part-B results are looked up under
            % experiment number Expt_Num+0.5.
            FileNameB = get_resultsFileName(Expt_Num+0.5, FIT_BOTH, MODEL);
            load(FileNameB, 'bestError', 'bestParams');
            bestParamsB = bestParams;
            bestErrorB = bestError;
        end
        % Loaded after the part-B file so bestParams/bestError end up holding
        % the part-A (or joint-fit) results.
        load(FileName, 'bestError', 'bestParams');
        bestParamsA = bestParams;
        bestErrorA = bestError;
        for ss = subs
            % FIX: ss already IS the subject number (the loop iterates over the
            % values of subs), so index subList with ss directly. The previous
            % subList(subs(ss)) was only correct when subs == 1:N.
            [data19_A, data28_A] = getSubData(DataA, subList(ss));
            norm_data28(ss,1:Ntrials28) = grasp2norm2(data28_A(:,2),data28_A(:,1), ObsI, COND_A);

            %%% create fits
            X = 0:(Ntrials28-1); % trial offsets within a block, starting at 0
            ParSet_A = setParams(bestParamsA(ss, :), MODEL, data19_A, 1);
            if(MODEL.LSpace==1)
                % parameters are densities: evaluate the curve directly
                Est_D(mm, ss, 1:Ntrials28) = (ParSet_A.DensStart - ParSet_A.DensEnd) * exp(-ParSet_A.LearnRate * X) + ParSet_A.DensEnd;
            else
                % parameters are log-densities: evaluate, then exponentiate
                Est_LogD = (ParSet_A.DensStart - ParSet_A.DensEnd) * exp(-ParSet_A.LearnRate * X) + ParSet_A.DensEnd;
                Est_D(mm, ss, 1:Ntrials28) = exp(Est_LogD);
            end
            %%% store learning rates %%%%
            %%% save final density value (last trial of part A)
            FinalDensity(ss) = Est_D(mm, ss, Ntrials28);

            LRate_A(mm, ss) = ParSet_A.LearnRate;
            Fit_G(mm, ss,1:Ntrials28) = dens2grasp2(squeeze(Est_D(mm, ss,1:Ntrials28))',data28_A(:,1),ObsI, COND_A);
            normFit_G(mm, ss, 1:Ntrials28) = grasp2norm2(squeeze(Fit_G(mm, ss,1:Ntrials28)),data28_A(:,1), ObsI, COND_A);

            % X = 1:Ntrials28; % just for plotting
            if(Expt_Num>1)
                % Second block: stored in trial slots Ntrials28+1 : 2*Ntrials28.
                [data19_B, data28_B] = getSubData(DataB, subList(ss)); % same indexing fix as part A
                norm_data28(ss,Ntrials28+1:2*Ntrials28) = grasp2norm2(data28_B(:,2),data28_B(:,1), ObsI, COND_B);
                if(FIT_BOTH)
                    % joint fit: same parameter row as part A, part index 2
                    ParSet_B = setParams(bestParamsA(ss, :), MODEL, data19_B, 2);
                else
                    ParSet_B = setParams(bestParamsB(ss, :), MODEL, data19_B, 1);
                end


                LRate_B(mm, ss) = ParSet_B.LearnRate;

                %%%% only doing this carry over if specified by model
                % A NaN DensStart in a joint fit means part B starts from
                % wherever part A ended (in the model's own space).
                if(FIT_BOTH && isnan(ParSet_B.DensStart))
                    if(MODEL.LSpace==1) % regular
                        ParSet_B.DensStart = Est_D(mm, ss, Ntrials28);
                    else % log space
                        ParSet_B.DensStart = log(Est_D(mm, ss, Ntrials28));
                    end
                end
                if(MODEL.LSpace==1)
                    Est_D(mm, ss, Ntrials28+1:2*Ntrials28) = (ParSet_B.DensStart - ParSet_B.DensEnd) * exp(-ParSet_B.LearnRate * X) + ParSet_B.DensEnd;
                else
                    Est_LogD = (ParSet_B.DensStart - ParSet_B.DensEnd) * exp(-ParSet_B.LearnRate * X) + ParSet_B.DensEnd;
                    Est_D(mm, ss, Ntrials28+1:2*Ntrials28) = exp(Est_LogD);
                end
                Fit_G(mm, ss,Ntrials28+1:2*Ntrials28) = dens2grasp2(squeeze(Est_D(mm, ss,Ntrials28+1:2*Ntrials28))',data28_B(:,1),ObsI, COND_B);
                normFit_G(mm, ss, Ntrials28+1:2*Ntrials28) = grasp2norm2(squeeze(Fit_G(mm, ss, Ntrials28+1:2*Ntrials28)),data28_B(:,1), ObsI, COND_B);
            end
            % trial axes for part A (X1) and part B (X2); also used by the average plot
            X1 = 1:Ntrials28; X2 = Ntrials28+1:Ntrials28*2;
            if(PLOT_IND)
                % one figure per subject, created on the first model and
                % reused so later models overlay their fit
                if(mm==1)
                    figHand(ss) = figure('Name', strcat('sub', num2str(ss)));
                else
                    figure(figHand(ss))
                end
                %%% plot in two halves with different colours

                hold on;
                plot(X1,norm_data28(ss,X1),'*');
                switch mm
                    case 1, plot(X1, squeeze(normFit_G(mm, ss, X1)), 'r-');
                    case 2, plot(X1, squeeze(normFit_G(mm, ss, X1)), 'g-');
                end
                if(Expt_Num>1)
                    plot(X2,norm_data28(ss,X2),'o');
                    switch mm
                        case 1, plot(X2, squeeze(normFit_G(mm, ss, X2)), 'r-');
                        case 2, plot(X2, squeeze(normFit_G(mm, ss, X2)), 'g-');
                    end
                end

                xlabel('trial number');
                ylabel('Normalised Grasping Pos');
                yline(0, 'r--');
                yline(1, 'k--');

            end
        end
        if(PLOT_AV)
            %%% get mean and std across subs
            MeanData = mean(norm_data28(subs,:), 1);
            SEData = std(norm_data28(subs,:), [], 1)/sqrt(length(subs));

            % reduce over the subject dimension (dim 2); result is model x 1 x trial
            MeanFit = mean(normFit_G(:, subs, :), 2);
            SE_Fit = std(normFit_G(:, subs, :), [], 2)/sqrt(length(subs));

            % data are drawn once (first model); later models only add their fit
            if(mm==1)
                MeanHand = figure('Name', 'Average'); hold on;
                errorbar(X1,  MeanData(X1), SEData(X1), 'k');
            else
                figure(MeanHand);
            end
            switch mm
                case 1, errorbar(X1, MeanFit(mm, X1), SE_Fit(mm, X1), 'r')
                case 2, errorbar(X1, MeanFit(mm, X1), SE_Fit(mm, X1), 'g')
            end
            if(Expt_Num>1)
                errorbar(X2,  MeanData(X2), SEData(X2), 'k');
                switch mm
                    case 1, errorbar(X2, MeanFit(mm, X2), SE_Fit(mm, X2), 'r')
                    case 2, errorbar(X2, MeanFit(mm, X2), SE_Fit(mm, X2), 'g')
                end
            end

        end
        % one output file per (experiment, model) pair
        DensityFile = strcat('EndD_Expt', num2str(Expt_Num), '_Model', num2str(ModelN), '.mat');
        FinalDensity % intentionally unsuppressed: echoes the values to the console
        save(DensityFile, 'FinalDensity');
    end
end

%%% Model-vs-model comparison: scatter of per-subject final density and
%%% part-A learning rate (model 1 on x, model 2 on y) against an identity line.
% Est_D and LRate_A are overwritten on every pass of the experiment loop, so
% the values shown here belong to the last experiment in Expt_List.
% Requires at least two entries in ModelTypes (rows 1 and 2 are read).
figure('Name', 'Learning Stats');
subplot(1, 2, 1)
% Use Ntrials28 (the same trial that is stored in FinalDensity) instead of a
% hard-coded 35, and restrict to the analysed subjects so that zero-filled
% rows of subjects outside subs are not plotted.
plot(Est_D(1, subs, Ntrials28), Est_D(2, subs, Ntrials28), 'r*');
hold on;
plot([0 1], [0 1], 'k:') % identity line
xlabel('final density, model 1');
ylabel('final density, model 2');
subplot(1, 2, 2)
plot(LRate_A(1, subs), LRate_A(2, subs), 'r*');
hold on;
plot([0 3], [0 3], 'k:') % identity line
xlabel('learn rate, model 1');
ylabel('learn rate, model 2');