/*
 * Decompiled with CFR 0.152.
 */
package infodynamics.measures.mixed.kraskov;

import infodynamics.measures.mixed.MutualInfoCalculatorMultiVariateWithDiscrete;
import infodynamics.utils.EmpiricalMeasurementDistribution;
import infodynamics.utils.EuclideanUtils;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;
import infodynamics.utils.RandomGenerator;

/**
 * Estimates the mutual information between a multivariate continuous variable and a discrete
 * variable with a Kraskov-style k-nearest-neighbour estimator.
 *
 * <p>For each sample t, {@code eps} is obtained from {@code MatrixUtils.kthMinSubjectTo} applied
 * to the distances from t, presumably restricted to samples sharing t's discrete value (confirm
 * against MatrixUtils). {@code n_x} is the number of other samples, of any discrete value, whose
 * distance from t is at most {@code eps}. {@code n_y} is the number of other samples with t's
 * discrete value. The average estimate is
 * {@code digamma(k) - 1/k - <digamma(n_x) + digamma(n_y)> + digamma(N)}.
 *
 * <p>Typical usage: {@link #initialise(int, int)}, optional {@link #setProperty(String, String)}
 * calls, {@link #setObservations(double[][], int[])}, then one of the compute methods.
 * Instances hold mutable state.
 */
public class MutualInfoCalculatorMultiVariateWithDiscreteKraskov
implements MutualInfoCalculatorMultiVariateWithDiscrete {
    /** Constant kept for API compatibility; not referenced within this class. */
    protected static final double CUTOFF_MULTIPLIER = 1.5;
    /** Number of nearest neighbours used by the estimator (property {@value #PROP_K}). */
    protected int k = 4;
    /** Continuous observations, one row per sample (a normalised copy when {@link #normalise}). */
    protected double[][] continuousData;
    /** Discrete observation for each sample; values must lie in {@code [0, base)}. */
    protected int[] discreteData;
    /** Number of stored samples observed with each discrete value. */
    protected int[] counts;
    /** Number of possible discrete values. */
    protected int base;
    /** Number of continuous variables (columns) expected per sample. */
    protected int dimensions;
    /** Whether to print diagnostic output to standard out. */
    protected boolean debug;
    /** Most recently computed average mutual information. */
    protected double mi;
    /** Whether {@link #mi} holds the estimate for the current observations. */
    protected boolean miComputed;
    /** Distance calculator between continuous samples (property {@value #PROP_NORM_TYPE}). */
    protected EuclideanUtils normCalculator = new EuclideanUtils(2);
    /** Cached all-pairs distance matrix with +infinity on the diagonal, or null if not computed. */
    protected double[][] xNorms;
    /** Whether to cache the all-pairs distance matrix when the data set is small enough. */
    public static boolean tryKeepAllPairsNorms = true;
    /** Largest number of samples for which the all-pairs distance matrix is cached. */
    public static int MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM = 2000;
    public static final String PROP_K = "k";
    public static final String PROP_NORM_TYPE = "NORM_TYPE";
    public static final String PROP_NORMALISE = "NORMALISE";
    /** Whether continuous observations are normalised before distances are computed. */
    protected boolean normalise = true;
    /** Per-variable means of the stored continuous data (set only when normalising). */
    protected double[] means;
    /** Per-variable standard deviations of the stored continuous data (set only when normalising). */
    protected double[] stds;

    /**
     * Resets the calculator, discarding any previous observations, caches and results.
     *
     * @param dimensions number of continuous variables per sample
     * @param base number of possible discrete values; observations must lie in {@code [0, base)}
     */
    @Override
    public void initialise(int dimensions, int base) {
        this.mi = 0.0;
        this.miComputed = false;
        this.xNorms = null;
        this.continuousData = null;
        this.means = null;
        this.stds = null;
        this.discreteData = null;
        this.counts = null;
        this.dimensions = dimensions;
        this.base = base;
    }

    /**
     * Sets a calculator property. Unrecognised property names are ignored.
     *
     * @param propertyName one of {@value #PROP_K}, {@value #PROP_NORM_TYPE} or
     *     {@value #PROP_NORMALISE} (case-insensitive)
     * @param propertyValue the value, parsed as an int, passed to the norm calculator, or parsed as
     *     a boolean, respectively
     */
    @Override
    public void setProperty(String propertyName, String propertyValue) {
        if (propertyName.equalsIgnoreCase(PROP_K)) {
            this.k = Integer.parseInt(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
            this.normCalculator.setNormToUse(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_NORMALISE)) {
            this.normalise = Boolean.parseBoolean(propertyValue);
        }
    }

    /** Not supported by this calculator. */
    public void addObservations(double[][] dArray, double[][] dArray2) throws Exception {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Not supported by this calculator. */
    public void addObservations(double[][] dArray, double[][] dArray2, int n, int n2) throws Exception {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Not supported by this calculator. */
    public void setObservations(double[][] dArray, double[][] dArray2, boolean[] blArray, boolean[] blArray2) throws Exception {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Not supported by this calculator. */
    public void setObservations(double[][] dArray, double[][] dArray2, boolean[][] blArray, boolean[][] blArray2) throws Exception {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Not supported by this calculator. */
    public void startAddObservations() {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Not supported by this calculator. */
    public void finaliseAddObservations() {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * Stores the paired observations, optionally normalising the continuous ones, and tallies how
     * many samples fall into each discrete value. Any cached distances or results are discarded.
     *
     * @param continuousObservations one row per sample, each with {@link #dimensions} columns
     * @param discreteObservations one discrete value per sample, each in {@code [0, base)}
     * @throws Exception if the lengths differ, the data is empty, the column count is wrong, or a
     *     discrete value is out of range
     * @throws RuntimeException if some discrete value has fewer than k samples
     */
    @Override
    public void setObservations(double[][] continuousObservations, int[] discreteObservations) throws Exception {
        if (continuousObservations.length != discreteObservations.length) {
            throw new Exception("Time steps for observations2 " + discreteObservations.length + " does not match the length " + "of observations1 " + continuousObservations.length);
        }
        if (continuousObservations.length == 0 || continuousObservations[0].length == 0) {
            throw new Exception("Computing MI with a null set of data");
        }
        if (continuousObservations[0].length != this.dimensions) {
            throw new Exception("The continuous observations do not have the expected number of variables (" + this.dimensions + ")");
        }
        int[] binCounts = new int[this.base];
        for (int value : discreteObservations) {
            if (value < 0 || value >= this.base) {
                throw new Exception("Discrete observation " + value + " is outside the range [0, " + this.base + ")");
            }
            ++binCounts[value];
        }
        // Cached distances and any previous estimate belong to the old data set.
        this.xNorms = null;
        this.miComputed = false;
        this.discreteData = discreteObservations;
        this.counts = binCounts;
        if (this.normalise) {
            this.means = MatrixUtils.means(continuousObservations);
            this.stds = MatrixUtils.stdDevs(continuousObservations, this.means);
            this.continuousData = MatrixUtils.normaliseIntoNewArray(continuousObservations, this.means, this.stds);
        } else {
            this.continuousData = continuousObservations;
        }
        // NOTE(review): a bin with exactly k samples leaves each of its samples only k-1 same-bin
        // neighbours, so eps becomes +infinity for them; requiring k+1 may be intended.
        for (int count : binCounts) {
            if (count < this.k) {
                throw new RuntimeException("This implementation assumes there are at least k items in each discrete bin");
            }
        }
    }

    /**
     * Fills {@link #xNorms} with the distance between every pair of stored samples. Each sample's
     * distance to itself is +infinity.
     */
    protected void computeNorms() {
        int numSamples = this.continuousData.length;
        this.xNorms = new double[numSamples][];
        for (int t = 0; t < numSamples; ++t) {
            this.xNorms[t] = this.normsFrom(this.continuousData[t], t);
        }
    }

    /**
     * Returns the distance from {@code point} to every stored continuous sample.
     *
     * @param point query point, expected to be on the same (normalised) scale as continuousData
     * @param self index of the stored sample that is {@code point} itself, whose distance is set to
     *     +infinity so it is never its own neighbour; -1 when the point is not a stored sample
     */
    private double[] normsFrom(double[] point, int self) {
        double[] norms = new double[this.continuousData.length];
        for (int j = 0; j < norms.length; ++j) {
            norms[j] = (j == self) ? Double.POSITIVE_INFINITY : this.normCalculator.norm(point, this.continuousData[j]);
        }
        return norms;
    }

    /** Counts the entries of {@code norms} that are at most {@code eps}, skipping index {@code exclude}. */
    private static int countWithin(double[] norms, double eps, int exclude) {
        int count = 0;
        for (int j = 0; j < norms.length; ++j) {
            if (j != exclude && norms[j] <= eps) {
                ++count;
            }
        }
        return count;
    }

    /** Returns whether the all-pairs distance matrix should be cached for the stored data. */
    private boolean canCacheAllNorms() {
        return tryKeepAllPairsNorms && this.continuousData.length <= MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM;
    }

    /**
     * Runs the estimator loop over all stored samples, using the given discrete labels.
     *
     * @param discrete discrete value for each stored sample; either discreteData or a reordering
     *     of it, so that {@link #counts} still applies
     * @param useCachedNorms read distances from {@link #xNorms} (must already be computed) instead
     *     of computing each row on the fly
     * @param perSampleDebug when debug is on, print one diagnostic line per sample
     * @return {average of digamma(n_x) + digamma(n_y), average n_x, average n_y}
     * @throws Exception if a helper such as MathsUtils.digamma fails
     */
    private double[] averageDigammas(int[] discrete, boolean useCachedNorms, boolean perSampleDebug) throws Exception {
        int numSamples = this.continuousData.length;
        double digammaSum = 0.0;
        double nxSum = 0.0;
        double nySum = 0.0;
        double runningTotal = 0.0;
        for (int t = 0; t < numSamples; ++t) {
            double[] norms = useCachedNorms ? this.xNorms[t] : this.normsFrom(this.continuousData[t], t);
            double eps = MatrixUtils.kthMinSubjectTo(norms, this.k, discrete, discrete[t]);
            // Exclude the sample itself consistently, even when eps is +infinity.
            int nx = countWithin(norms, eps, t);
            int ny = this.counts[discrete[t]] - 1;
            nxSum += (double) nx;
            nySum += (double) ny;
            double digammas = MathsUtils.digamma(nx) + MathsUtils.digamma(ny);
            digammaSum += digammas;
            if (!(this.debug && perSampleDebug)) {
                continue;
            }
            double local = MathsUtils.digamma(this.k) - 1.0 / (double) this.k - digammas + MathsUtils.digamma(numSamples);
            runningTotal += local;
            if (this.dimensions == 1) {
                System.out.printf("t=%d: x=%.3f, eps_x=%.3f, n_x=%d, n_y=%d, local=%.3f, running total = %.5f\n", t, this.continuousData[t][0], eps, nx, ny, local, runningTotal);
            } else {
                System.out.printf("t=%d: eps_x=%.3f, n_x=%d, n_y=%d, local=%.3f, running total = %.5f\n", t, eps, nx, ny, local, runningTotal);
            }
        }
        return new double[] {digammaSum / (double) numSamples, nxSum / (double) numSamples, nySum / (double) numSamples};
    }

    /** Converts an average digamma sum into the MI estimate, records it, and returns it. */
    private double storeResult(double averageDigammas, int numSamples) throws Exception {
        this.mi = MathsUtils.digamma(this.k) - 1.0 / (double) this.k - averageDigammas + MathsUtils.digamma(numSamples);
        this.miComputed = true;
        return this.mi;
    }

    /**
     * Computes the MI with the discrete labels reordered, for example as a surrogate for
     * significance testing. The stored labels are left unchanged on return, even if an exception
     * is thrown.
     *
     * @param newOrdering time indices selecting the surrogate labels from discreteData
     * @return the surrogate MI estimate (also stored in {@link #mi})
     * @throws Exception if the estimator fails
     */
    public double computeAverageLocalOfObservations(int[] newOrdering) throws Exception {
        if (!this.canCacheAllNorms()) {
            int[] originalDiscrete = this.discreteData;
            this.discreteData = MatrixUtils.extractSelectedTimePoints(originalDiscrete, newOrdering);
            try {
                return this.computeAverageLocalOfObservationsWhileComputingDistances();
            } finally {
                // Restore the real labels even if the estimate fails.
                this.discreteData = originalDiscrete;
            }
        }
        int[] surrogateDiscrete = MatrixUtils.extractSelectedTimePoints(this.discreteData, newOrdering);
        if (this.xNorms == null) {
            this.computeNorms();
        }
        double[] averages = this.averageDigammas(surrogateDiscrete, true, false);
        if (this.debug) {
            System.out.println(String.format("Average n_x=%.3f, Average n_y=%.3f", averages[1], averages[2]));
        }
        return this.storeResult(averages[0], this.continuousData.length);
    }

    /**
     * Computes the average MI over the stored observations. The all-pairs distance matrix is cached
     * when allowed, and distances are computed on the fly otherwise.
     *
     * @return the MI estimate (also stored in {@link #mi})
     * @throws Exception if the estimator fails
     */
    @Override
    public double computeAverageLocalOfObservations() throws Exception {
        if (!this.canCacheAllNorms()) {
            return this.computeAverageLocalOfObservationsWhileComputingDistances();
        }
        if (this.xNorms == null) {
            this.computeNorms();
        }
        int numSamples = this.continuousData.length;
        double[] averages = this.averageDigammas(this.discreteData, true, true);
        double avgDigammas = averages[0];
        if (this.debug) {
            double avgNx = averages[1];
            double avgNy = averages[2];
            System.out.println(String.format("Average n_x=%.3f (-> digam=%.3f %.3f), Average n_y=%.3f (-> digam=%.3f)", avgNx, MathsUtils.digamma((int) avgNx), MathsUtils.digamma((int) avgNx - 1), avgNy, MathsUtils.digamma((int) avgNy)));
            System.out.printf("Independent average num in joint box is %.3f\n", avgNx * avgNy / (double) numSamples);
            System.out.println(String.format("digamma(k)=%.3f - 1/k=%.3f - averageDiGammas=%.3f + digamma(N)=%.3f\n", MathsUtils.digamma(this.k), 1.0 / (double) this.k, avgDigammas, MathsUtils.digamma(numSamples)));
        }
        return this.storeResult(avgDigammas, numSamples);
    }

    /**
     * Computes the average MI without caching the all-pairs distance matrix. Each sample's row of
     * distances is recomputed when it is needed.
     *
     * @return the MI estimate (also stored in {@link #mi})
     * @throws Exception if the estimator fails
     */
    public double computeAverageLocalOfObservationsWhileComputingDistances() throws Exception {
        int numSamples = this.continuousData.length;
        double[] averages = this.averageDigammas(this.discreteData, false, true);
        if (this.debug) {
            System.out.println(String.format("Average n_x=%.3f, Average n_y=%.3f", averages[1], averages[2]));
        }
        return this.storeResult(averages[0], numSamples);
    }

    /**
     * Estimates significance against {@code n} randomly reordered surrogates of the discrete data.
     *
     * @param n number of surrogate orderings to generate
     * @return the surrogate distribution together with the p-value and the actual MI
     * @throws Exception if an estimate fails
     */
    @Override
    public synchronized EmpiricalMeasurementDistribution computeSignificance(int n) throws Exception {
        int[][] orderings = new RandomGenerator().generateRandomPerturbations(this.continuousData.length, n);
        return this.computeSignificance(orderings);
    }

    /**
     * Estimates significance against the given reorderings of the discrete data. The p-value is the
     * fraction of surrogates whose MI is at least the actual MI. The actual MI is restored into
     * {@link #mi} afterwards, even if a surrogate computation fails.
     *
     * @param newOrderings one array of time indices per surrogate
     * @return the surrogate distribution together with the p-value and the actual MI
     * @throws Exception if an estimate fails
     */
    @Override
    public EmpiricalMeasurementDistribution computeSignificance(int[][] newOrderings) throws Exception {
        int numPermutations = newOrderings.length;
        if (!this.miComputed) {
            this.computeAverageLocalOfObservations();
        }
        double actualMi = this.mi;
        EmpiricalMeasurementDistribution measDistribution = new EmpiricalMeasurementDistribution(numPermutations);
        int countAtLeastActual = 0;
        try {
            for (int p = 0; p < numPermutations; ++p) {
                double surrogateMi = this.computeAverageLocalOfObservations(newOrderings[p]);
                measDistribution.distribution[p] = surrogateMi;
                if (this.debug) {
                    System.out.println("New MI was " + surrogateMi);
                }
                if (surrogateMi >= actualMi) {
                    ++countAtLeastActual;
                }
            }
        } finally {
            // Surrogate computations overwrite this.mi; put the real estimate back.
            this.mi = actualMi;
        }
        measDistribution.pValue = (double) countAtLeastActual / (double) numPermutations;
        measDistribution.actualValue = this.mi;
        return measDistribution;
    }

    /**
     * Computes local MI values for new observations, with neighbours counted against the stored
     * (training) observations.
     *
     * @param newContinuous one row per new sample, normalised with the stored means and standard
     *     deviations when {@link #normalise} is set
     * @param newDiscrete discrete value of each new sample, each in {@code [0, base)}
     * @return the local MI of each new sample
     * @throws Exception if the lengths differ, a discrete value is out of range, or the estimator
     *     fails
     */
    @Override
    public double[] computeLocalUsingPreviousObservations(double[][] newContinuous, int[] newDiscrete) throws Exception {
        if (newContinuous.length != newDiscrete.length) {
            throw new Exception("Time steps for new discrete observations " + newDiscrete.length + " does not match the length of new continuous observations " + newContinuous.length);
        }
        if (this.normalise) {
            newContinuous = MatrixUtils.normaliseIntoNewArray(newContinuous, this.means, this.stds);
        }
        int numNew = newContinuous.length;
        // n_x and n_y are counted over the stored samples, so N is the stored sample count.
        int numStored = this.continuousData.length;
        double[] locals = new double[numNew];
        double fixedPart = MathsUtils.digamma(this.k) - 1.0 / (double) this.k + MathsUtils.digamma(numStored);
        if (this.debug) {
            System.out.printf("digamma(k)=%.3f - 1/k=%.3f + digamma(N)=%.3f\n", MathsUtils.digamma(this.k), 1.0 / (double) this.k, MathsUtils.digamma(numStored));
        }
        double runningTotal = 0.0;
        double nxSum = 0.0;
        double nySum = 0.0;
        for (int t = 0; t < numNew; ++t) {
            int value = newDiscrete[t];
            if (value < 0 || value >= this.counts.length) {
                throw new Exception("Discrete observation " + value + " is outside the range [0, " + this.counts.length + ")");
            }
            // The new point is not one of the stored samples, so no index is excluded.
            double[] norms = this.normsFrom(newContinuous[t], -1);
            double eps = MatrixUtils.kthMinSubjectTo(norms, this.k, this.discreteData, value);
            int nx = countWithin(norms, eps, -1);
            // Every stored sample with this new point's discrete value; the point itself is not among them.
            int ny = this.counts[value];
            nxSum += (double) nx;
            nySum += (double) ny;
            locals[t] = fixedPart - MathsUtils.digamma(nx) - MathsUtils.digamma(ny);
            if (!this.debug) {
                continue;
            }
            runningTotal += locals[t];
            if (this.dimensions == 1) {
                System.out.printf("t=%d: x=%.3f, eps_x=%.3f, n_x=%d, n_y=%d, local=%.3f, running total = %.5f\n", t, newContinuous[t][0], eps, nx, ny, locals[t], runningTotal);
            } else {
                System.out.printf("t=%d: eps_x=%.3f, n_x=%d, n_y=%d, local=%.3f, running total = %.5f\n", t, eps, nx, ny, locals[t], runningTotal);
            }
        }
        if (this.debug) {
            System.out.printf("Average n_x=%.3f, Average n_y=%.3f\n", nxSum / (double) numNew, nySum / (double) numNew);
        }
        return locals;
    }

    /** Enables or disables diagnostic printing to standard out. */
    @Override
    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /** Returns the most recently computed average MI (0.0 after {@link #initialise(int, int)}). */
    @Override
    public double getLastAverage() {
        return this.mi;
    }

    /** Returns the number of stored observations; observations must have been set first. */
    @Override
    public int getNumObservations() {
        return this.continuousData.length;
    }
}

