% Simulate categorisation data for tasks with 2 or more categories and build
% a lookup table relating the simulated category separation d to the d'
% returned by kafc_dprime, across several levels of categorical agreement.
clear; close all; clc;

rng(pi); % fixed seed so the simulation is reproducible

% --- simulation size ---------------------------------------------------------
ncat = 6;                % largest number of categories used in any simulated task
ndim = ncat - 1;         % dimensions needed to place ncat equidistant categories
sigma = 1;               % noise s.d., shared by all categories
sam_per_cat = 1000000;   % number of samples drawn for each category
method = 3;              % d' computation method (details in Supp Mat); NOTE(review): not referenced below

% --- category separations to sweep -------------------------------------------
d_incs = 60;
maxd = 6;
ds = linspace(0, maxd, d_incs);

% --- categorical agreement levels to sweep -----------------------------------
As = [0.715416666666667, 0.769166666666667, 0.822916666666667, 1];
A_incs = numel(As);

% --- hard-coded prior distributions over categories (order doesn't matter) ---
psem = [.3736, .0708, .1542, .1056, .0736, .2222];   % 6-category (semantic) task
pstr = [.2625, .3153, .2611, .1611];                 % 4-category (spatial) task
nsem = round(sam_per_cat * psem);   % trials per category for the 6-category task
nstr = round(sam_per_cat * pstr);   % trials per category for the 4-category task

% --- output arrays, indexed (ncat_task-1, d index, agreement index) ----------
dp = zeros(ncat-1, d_incs, A_incs);
dp2 = zeros(size(dp));   % NOTE(review): never assigned by the simulation loop

% Main simulation: for every category separation d (ds) and agreement level
% pA (As), place ncat equidistant category locations, draw noisy samples
% around each, classify samples to the nearest location, optionally relabel
% a fraction (1-pA) of the ground truth, and store kafc_dprime's estimate in
% dp(ncat_task-1, i, j).

% Preallocate once; both arrays are fully overwritten on every (d, pA)
% iteration. At the default sizes sam holds ncat*ndim*sam_per_cat doubles.
sam = zeros(ncat, ndim, sam_per_cat);
locn = zeros(ncat, ndim);

for i = 1:d_incs
    
    for j = 1:A_incs

        disp(['Trying d=',num2str(i), '/',num2str(d_incs),', pA=',num2str(j), '/',num2str(A_incs),'...'])
        d = ds(i); % distance between categories
        pA = As(j);

        % Category locations: the first two sit at -d/2 and +d/2 on dimension 1;
        % each further one starts at the centroid of the preceding ones and is
        % lifted along a fresh dimension (catn-1) until it is distance d away.
        locn(1, :) = [-d/2 zeros(1, ndim-1)];
        locn(2, :) = [d/2 zeros(1, ndim-1)];

        for catn=3:ncat

            % start from the centroid of the preceding locations
            locn(catn, :) = mean(locn(1:(catn-1), :), 1);

            % squared distance still missing between the centroid and location 1
            newdist_squared = d*d - sum((locn(1, :)-locn(catn, :)).^2);

            % lift along dimension catn-1; clamp a tiny negative rounding
            % residual so sqrt can never return a complex value
            locn(catn, catn-1) = sqrt(max(newdist_squared, 0));

            % sanity check: the new location must be distance d from the previous one
            assert(abs(norm(locn(catn-1, :)-locn(catn, :), 2) - d) <= 1e-9*max(1, d), ...
                'Category location %d is not at distance d from location %d.', catn, catn-1);
        end

        % samples for each category: its location plus gaussian noise of s.d. sigma
        for catn=1:ncat
            sam(catn, :, :) = locn(catn, :)' + sigma*randn(ndim, sam_per_cat);
        end

        % Simulate tasks with 4 and 6 categories (the task sizes that have a
        % hard-coded prior distribution below).
        for ncat_task=4:2:ncat

            disp(['Trying task with ', num2str(ncat_task),' Categories...'])
            gt = [];
            resp = [];

            for catn=1:ncat_task
                
                % number of trials for ground-truth category catn, taken from
                % the task's prior distribution
                switch ncat_task
                    case 4
                        nt = nstr(catn);
                    case 6
                        nt = nsem(catn);
                    otherwise
                        error('No prior distribution defined for a %d-category task.', ncat_task);
                end

                % trials known to belong to ground-truth category catn (ndim x nt);
                % reshape (not squeeze) keeps this orientation even if ndim or nt is 1
                sam_subset = reshape(sam(catn, :, 1:nt), ndim, nt);

                % euclidean distance of every sample from each category location
                dist_subset = zeros(ncat_task, nt);
                for comp_cat=1:ncat_task
                    dist_subset(comp_cat, :) = sqrt(sum((locn(comp_cat, :)'-sam_subset).^2, 1));
                end

                % the category with the shortest distance is the responded category
                [~, resp_subset] = min(dist_subset, [], 1);
                gt_subset = catn*ones(1, nt);

                % if simulating pA<1, relabel the appropriate proportion of
                % gt_subset with a different (uniformly chosen) category
                if pA<1
                    n_diff = round(nt*(1-pA));
                    GT_list = setdiff(1:ncat_task, catn);
                    % sample with replacement by indexing: unlike randsample, this
                    % stays correct when GT_list has a single element and needs no
                    % Statistics Toolbox
                    gt_subset(1:n_diff) = GT_list(randi(numel(GT_list), 1, n_diff));
                end

                gt = [gt gt_subset];
                resp = [resp resp_subset];

            end
            
            dp(ncat_task-1,i,j) = kafc_dprime(resp,gt,ncat_task);

        end
        
    end

end

% Save the d' lookup table together with the grids it was computed on, so a
% later load does not need to rebuild ds / As by hand.
save('dprime_lookuptable_compressed_wprior2','dp','ds','As','d_incs','maxd')

% load dprime_lookuptable_compressed_wprior.mat 
% dp(:,1,:) = 0;
% d_incs = 60;
% maxd = 6;
% ds = linspace(0,maxd,d_incs);

% Plot the lookup table: kafc_dprime's estimate (dp) on x against the simulated
% separation ds on y, for the 4-category (spatial) and 6-category (semantic)
% tasks at full and at reduced categorical agreement.
cols = {[0.5660 0.7740 0.2880],[0.7350 0.1780 0.2840],[0.5940 0.2840 0.6560]};

% dp is indexed (ncat_task-1, d index, agreement index)
row_spatial = 4-1;    % 4-category task
row_semantic = 6-1;   % 6-category task
iA_full = 4;          % As(4) == 1
iA_spatial = 1;       % reduced agreement shown for the spatial task
iA_semantic = 3;      % reduced agreement shown for the semantic task

figure; hold on;
plot(dp(row_spatial,:,iA_full),ds,'Color',cols{1},'LineWidth',3);
plot(dp(row_spatial,:,iA_spatial),ds,'Color',cols{1},'LineWidth',3,'LineStyle',':');
plot(dp(row_semantic,:,iA_full),ds,'Color',cols{2},'LineWidth',3);
plot(dp(row_semantic,:,iA_semantic),ds,'Color',cols{2},'LineWidth',3,'LineStyle',':');

xlabel('Empirical Sensitivity d_e_m_p^\prime');
ylabel('Simulated Sensitivity d_s_i_m^\prime');

% Build legend entries from As so the alpha values cannot drift out of sync
% with the plotted slices; num2str(x,4) yields e.g. '0.7154', '0.8229', '1'.
alpha_label = @(task, iA) [task, ', \alpha = ', num2str(As(iA), 4)];
leg = legend(alpha_label('Spatial', iA_full), alpha_label('Spatial', iA_spatial), ...
    alpha_label('Semantic', iA_full), alpha_label('Semantic', iA_semantic), 'Location','SE');
set(gca,'FontSize',18);
set(gca,'XLim',[0 maxd]);
xticks(0:maxd);
yticks(0:maxd);
legend boxoff
box off;
set(gca,'LineWidth',2)
axis square;

% Correlate the four plotted dp slices with the matching slices of dp2.
% NOTE(review): dp2 is preallocated with zeros and is not assigned by the
% simulation loop, so corrcoef against it contains NaN unless dp2 is filled
% in elsewhere. Confirm what dp2 was meant to hold (a second d' method?).
if ~any(dp2(:))
    warning('dp2 contains only zeros (never populated); the correlations below will contain NaN.');
end
corrcoef(dp(3,:,4),dp2(3,:,4))
corrcoef(dp(3,:,1),dp2(3,:,1))
corrcoef(dp(5,:,4),dp2(5,:,4))
corrcoef(dp(5,:,3),dp2(5,:,3))
