% Simulate categorisation data for tasks with two or more categories.
clear;
close all;
clc;

% --- simulation parameters ---------------------------------------------
ncat        = 6;         % largest number of categories used in any simulated task
ndim        = ncat - 1;  % dimensions needed to place ncat categories
sigma       = 1;         % noise standard deviation, shared by all categories
sam_per_cat = 1e6;       % number of noisy samples drawn per category
method      = 3;         % d-prime computation variant (details in Supp Mat)

% --- grid of category separations d -------------------------------------
maxd   = 6;
d_incs = 60;
ds     = linspace(0, maxd, d_incs);

% --- grid of categorical agreement levels --------------------------------
minA   = 0.3;
maxA   = 1;
A_incs = 15;
As     = linspace(minA, maxA, A_incs);

% Pull every variable stored in the lookup table into the workspace.
% NOTE(review): the simulation loop later passes a variable named `dp` to
% adj_dprime, which presumably comes from this file - confirm its contents.
load('dprime_lookuptable.mat')

% Result arrays, indexed as (number of categories - 1, d index, agreement index).
dp2      = zeros(ncat-1, d_incs, A_incs);
dp2_corr = zeros(size(dp2));

% Sweep a coarse grid (every 5th value) over category separation d and
% categorical agreement pA. For each grid point:
%   1. place ncat category centres so that all are a distance d apart,
%   2. draw sam_per_cat Gaussian-noise samples around each centre,
%   3. for tasks with 4 and with 6 categories, classify every sample by its
%      nearest centre, optionally relabel a proportion (1-pA) of the ground
%      truth, and store the resulting d-prime estimate in dp2.

task_sizes = 4:2:6; % number of categories in each simulated task (4 and 6)
assert(max(task_sizes) <= ncat, ...
    'Tasks with up to %d categories requested, but only ncat=%d categories are simulated.', ...
    max(task_sizes), ncat);

% Allocate once; every element is overwritten at each grid point.
locn = zeros(ncat, ndim);
sam = zeros(ncat, ndim, sam_per_cat);

for i = 1:5:d_incs

    for j = 1:5:A_incs

        disp(['Trying d=',num2str(i), '/',num2str(d_incs),', pA=',num2str(j), '/',num2str(A_incs),'...'])
        d = ds(i); % distance between categories
        pA = As(j); % proportion of trials whose ground truth is left intact

        % set up category locations: the first two sit on the first axis, d apart
        locn(1, :) = [-d/2 zeros(1, ndim-1)];
        locn(2, :) = [d/2 zeros(1, ndim-1)];

        for catn=3:ncat

            % start the next category at the centroid of the preceding ones
            locn(catn, :) = mean(locn(1:(catn-1), :), 1);

            % squared offset still needed so that the new point ends up at
            % distance d from location 1 (the centroid is equally far from
            % all preceding locations)
            newdist_squared = d*d - sum((locn(1, :)-locn(catn, :)).^2);

            % move along the previously unused dimension catn-1 by that amount
            locn(catn, catn-1) = sqrt(newdist_squared);

            % verify the new location is at distance d from every earlier one
            gaps = sqrt(sum((locn(1:(catn-1), :) - locn(catn, :)).^2, 2));
            assert(all(abs(gaps - d) <= 1e-9*max(d, 1)), ...
                'Category %d is not at distance d=%g from all preceding categories.', catn, d);

        end

        % set up samples for each category location plus gaussian noise
        for catn=1:ncat
            sam(catn, :, :) = locn(catn, :)' + sigma*randn(ndim, sam_per_cat);
        end


        % run one task per entry of task_sizes, using the first ncat_task categories
        for ncat_task = task_sizes

            disp(['Trying task with ', num2str(ncat_task),' Categories...'])
            ntrials = ncat_task*sam_per_cat;
            gt = zeros(1, ntrials);   % ground-truth category per trial
            resp = zeros(1, ntrials); % responded category per trial

            for catn=1:ncat_task

                % grab trials that we know belong to ground-truth category catn
                % (reshape keeps the ndim-by-sam_per_cat layout even if ndim is 1)
                sam_subset = reshape(sam(catn, :, :), ndim, sam_per_cat);

                % Euclidean distance of every sample from each category location
                dist_subset = zeros(ncat_task, sam_per_cat);
                for comp_cat=1:ncat_task
                    dist_vectors = locn(comp_cat, :)'-sam_subset;
                    dist_subset(comp_cat, :) = sqrt(sum(dist_vectors.^2, 1));
                end

                % The category with the shortest distance is the responded category
                [~, resp_subset] = min(dist_subset, [], 1);
                gt_subset = catn*ones(1, sam_per_cat);

                % if simulating pAgree<1, relabel the ground truth of the first
                % n_diff trials with a randomly drawn *other* category;
                % the responses are left untouched
                if pA<1
                    n_diff = round(sam_per_cat*(1-pA));
                    GT_list = setdiff(1:ncat_task, catn);
                    REPLACE = 1;
                    % NOTE: randsample treats a scalar population n as 1:n, so
                    % this is only correct while GT_list has 2+ elements
                    % (i.e. ncat_task >= 3).
                    gt_replace = randsample(GT_list,n_diff,REPLACE); % bit slow...
                    gt_subset(1:n_diff) = gt_replace;
                end

                trial_idx = (catn-1)*sam_per_cat + (1:sam_per_cat);
                gt(trial_idx) = gt_subset;
                resp(trial_idx) = resp_subset;

            end

            % number of categories of the task this one is compared against
            switch ncat_task
                case 4
                    catcmp = 6;
                case 6
                    catcmp = 4;
                otherwise
                    error('No comparison task defined for a task with %d categories.', ncat_task);
            end

            % d-prime estimate from responses vs. ground truth, then adjusted;
            % the adjusted value overwrites the raw one in dp2.
            % NOTE(review): dp2_corr is allocated earlier in the script but is
            % never filled - confirm whether the adjusted value was meant to go there.
            dp2(ncat_task-1,i,j) = mafc_dprime(resp, gt, method);
            dp2(ncat_task-1,i,j) = adj_dprime(dp2(ncat_task-1,i,j), ncat_task, catcmp, dp, As,pA);

        end

    end

end

% Compare the estimates from the 4- and 6-category tasks: when d and
% agreement are held constant they should be more or less identical, i.e.
% the points should fall on the dashed identity line.

% Keep only the d indices that were simulated (every 5th); NaNs are plotted as 0.
d4 = squeeze(dp2(3, 1:5:end, :)); % 4-category task
d4(isnan(d4)) = 0;

d6 = squeeze(dp2(5, 1:5:end, :)); % 6-category task
d6(isnan(d6)) = 0;

maxval = max(d6(:));

figure; hold on;
plot([0,maxval],[0,maxval],'Color','k','LineWidth',1,'LineStyle','--')

% one scatter series per simulated agreement level (columns 1, 6 and 11 of As)
for col = [1 6 11]
    scatter(d4(:,col),d6(:,col),70,'filled')
end

% redraw the identity line so it sits on top of the points
plot([0,maxval],[0,maxval],'Color','k','LineWidth',1,'LineStyle','--')

% Expected outcome: the points lie on the identity line.