% Simulate categorisation data for tasks with two or more categories.
% This section resets the session, fixes the random seed and defines the
% simulation parameters and the output table used by the code that follows.
clc;
close all;
clear;

rng(pi); % fixed seed so that repeated runs draw the same random numbers

% --- simulation parameters ---
ncat        = 6;        % maximum number of categories in any simulated task
ndim        = ncat - 1; % dimensions needed to place ncat equally spaced points
sigma       = 1;        % noise standard deviation, shared by all categories
sam_per_cat = 3e5;      % number of simulated samples per category
method      = 3;        % way of computing dprime - details in Supp Mat
                        % (not referenced by the code visible in this script)

% --- category separations (d) to sweep ---
maxd   = 6;
d_incs = 60;
ds     = linspace(0, maxd, d_incs); % d_incs evenly spaced values in [0, maxd]

% --- categorical agreement levels (pA) to sweep ---
As     = [0.715416666666667, 0.769166666666667, 0.822916666666667, 1];
A_incs = numel(As);

% Output table: dp(k-1, i, j) holds the d' for a k-category task at
% separation ds(i) and agreement As(j).
dp = zeros(ndim, d_incs, A_incs);

% Main sweep. For every category separation d = ds(i) and every agreement
% level pA = As(j):
%   1. place ncat category locations so that every pair is d apart,
%   2. draw sam_per_cat Gaussian-noise samples around each location,
%   3. for each task size, classify every sample by its nearest location,
%   4. optionally relabel a fraction (1 - pA) of the ground-truth labels,
%   5. store the value returned by mafc_dprime_new(resp, gt) in
%      dp(ncat_task-1, i, j).
% Only the task sizes visited by the ncat_task loop are written; all other
% rows of dp are left untouched.

% Sample buffer, ndim-by-sam_per_cat-by-ncat; fully overwritten on every
% (i, j) iteration, so it is allocated once.
sam = zeros(ndim, sam_per_cat, ncat);

for i = 1:d_incs

    d = ds(i); % distance between categories

    % --- category locations (depend only on d, so built once per d) ---
    % locn is ncat-by-ndim, one row per category. The first two categories
    % sit at -d/2 and +d/2 on the first axis.
    locn = zeros(ncat, ndim);
    locn(1, 1) = -d/2;
    locn(2, 1) = d/2;

    for catn = 3:ncat

        % start the next category at the centre of the preceding ones
        centre = mean(locn(1:(catn-1), :), 1);
        locn(catn, :) = centre;

        % the centre is closer than d to the preceding locations; put the
        % missing squared distance on the (catn-1)-th axis, which none of
        % the preceding locations use
        newdist_squared = d*d - sum((locn(1, :) - centre).^2);
        locn(catn, catn-1) = sqrt(newdist_squared);

        % verify that the new location is d away from every preceding one
        gaps = sqrt(sum((locn(1:(catn-1), :) - locn(catn, :)).^2, 2));
        assert(all(abs(gaps - d) <= 1e-9*max(1, d)), ...
            'Category %d is not at distance d from all preceding categories.', catn);

    end

    for j = 1:A_incs

        disp(['Trying d=',num2str(i), '/',num2str(d_incs),', pA=',num2str(j), '/',num2str(A_incs),'...'])
        pA = As(j);

        % --- samples: category location plus isotropic Gaussian noise ---
        % (drawn per category, in category order, for every (i, j) pair)
        for catn = 1:ncat
            sam(:, :, catn) = locn(catn, :)' + sigma*randn(ndim, sam_per_cat);
        end

        % --- simulate tasks with 4, 6, ... categories (even sizes up to ncat) ---
        for ncat_task = 4:2:ncat

            disp(['Trying task with ', num2str(ncat_task),' Categories...'])
            n_trials = ncat_task*sam_per_cat;
            gt   = zeros(1, n_trials); % ground-truth label of every trial
            resp = zeros(1, n_trials); % simulated response on every trial

            for catn = 1:ncat_task

                % trials that we know belong to ground-truth category catn
                sam_subset = sam(:, :, catn);

                % Euclidean distance from each sample to each of the
                % category locations used in this task
                dist_subset = zeros(ncat_task, sam_per_cat);
                for comp_cat = 1:ncat_task
                    dist_vectors = locn(comp_cat, :)' - sam_subset;
                    dist_subset(comp_cat, :) = sqrt(sum(dist_vectors.^2, 1));
                end

                % the nearest category location is the responded category
                [~, resp_subset] = min(dist_subset, [], 1);
                gt_subset = catn*ones(1, sam_per_cat);

                % if simulating pA<1, relabel the first round(N*(1-pA))
                % ground-truth entries with categories other than catn,
                % drawn with replacement
                if pA < 1
                    n_diff = round(sam_per_cat*(1-pA));
                    GT_list = setdiff(1:ncat_task, catn);
                    REPLACE = true;
                    gt_subset(1:n_diff) = randsample(GT_list, n_diff, REPLACE); % bit slow...
                end

                % store this category's trials in its slot of the task vectors
                cols = (catn-1)*sam_per_cat + (1:sam_per_cat);
                gt(cols)   = gt_subset;
                resp(cols) = resp_subset;

            end

            dp(ncat_task-1, i, j) = mafc_dprime_new(resp, gt);
            % [dp_1(ncat_task), hit_1(ncat_task), fa_1(ncat_task), dp_2(ncat_task), hit_2(ncat_task), fa_2(ncat_task), ~] = wjafc_dprime(resp,gt);

        end

    end

end

% Write the lookup table to disk (only the variable dp is saved).
save('dprime_lookuptable_compressed', 'dp');

% Plot category separation (ds, y-axis) against the simulated k-way d'
% (x-axis) for four selected curves. Each row of `curves` is
% {row of dp, agreement index into As, line spec}:
%   dp row 3 -> 4-category task, dp row 5 -> 6-category task;
%   agreement index 4 -> As(4) = 1, 1 -> As(1) = 0.7154, 3 -> As(3) = 0.8229.
% NOTE(review): ds is labelled "2-way sensitivity" - presumably because with
% sigma = 1 the separation d equals the two-category d'; confirm.
figure;
hold('on');
curves = { ...
    3, 4, '-b'; ...
    3, 1, ':b'; ...
    5, 4, '-r'; ...
    5, 3, ':r'};
for c = 1:size(curves, 1)
    plot(dp(curves{c, 1}, :, curves{c, 2}), ds, curves{c, 3}, 'LineWidth', 2);
end

xlabel('k-way sensitivity d_k^\prime');
ylabel('2-way sensitivity d_2^\prime');
leg = legend('Spatial, \alpha = 1','Spatial, \alpha = 0.7154',...
    'Semantic, \alpha = 1','Semantic, \alpha = 0.8229','Location','SE');

% Axes cosmetics, applied in the same order as before.
set(gca, 'FontSize', 18, 'XLim', [0 6]);
xticks(0:6);
yticks(0:6);
legend('boxoff');
box('off');
set(gca, 'LineWidth', 2);
axis('square');

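% NOTE: everything below is disabled plotting code. It loads
% dprime_lookuptable.mat (a different file name from the
% dprime_lookuptable_compressed file saved by this script) and loops over
% 15 agreement levels, so it does not match the 4-level table built here.
% load dprime_lookuptable.mat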
% 
% diffs = squeeze(dp(3,:,:)-dp(5,:,:));
% % diffs = flipud(diffs);
% % figure;
% % imagesc(diffs);
% % colormap copper
% % xlabel('Inter-Observer Agreement (%)')
% % ylabel('d');
% % set(gca,'LineWidth',2);
% % box off
% % set(gca,'FontSize',19)
% % xticks(3:3:15)
% % xticklabels(40:15:100)
% % yticks(0:10:60)
% % yticklabels(6:-1:0)
% 
% figure; hold on;
% cmap = copper(15);
% for i = 1:15
%     plot(dp(5,:,i),'LineWidth',2,'Color',cmap(i,:))
% end
% ylim([0,6])
% xticks(0:20:60)
% xticklabels(0:2:6)
% xlabel('Simulated Category Distance')
% ylabel('6-Category d''')
% set(gca,'LineWidth',2)
% set(gca,'FontSize',19)
% h = colorbar;
% % h.Limits = [.30,1];
% colormap copper
% box off;
% axis square
% 
% figure; hold on;
% for i = 1:15
%     plot(dp(3,:,i),'LineWidth',2,'Color',cmap(i,:))
% end
% ylim([0,6])
% xticks(0:20:60)
% xticklabels(0:2:6)
% xlabel('Simulated Category Distance')
% ylabel('4-Category d''')
% set(gca,'LineWidth',2)
% set(gca,'FontSize',19)
% h = colorbar;
% % h.Limits = [.30,1];
% colormap copper
% box off;
% axis square
% 
% 
% figure; hold on;
% for i = 1:15
%     plot(diffs(:,i),'LineWidth',2,'Color',cmap(i,:))
% end
% ylim([-.4,.4])
% xticks(0:20:60)
% xticklabels(0:2:6)
% xlabel('Simulated Category Distance')
% ylabel('\Deltad''')
% set(gca,'LineWidth',2)
% set(gca,'FontSize',19)
% h = colorbar;
% % h.Limits = [.30,1];
% colormap copper
% box off;
% axis square