package com.aliasi.test.unit.classify;

import com.aliasi.classify.ConfusionMatrix;
import com.aliasi.classify.PrecisionRecallEvaluation;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:com/aliasi/test/unit/classify/ConfusionMatrixTest.class */
/**
 * Unit tests for {@link ConfusionMatrix}: argument validation, construction,
 * incrementing cell counts, and the summary statistics derived from a fixed
 * three-category matrix.
 */
public class ConfusionMatrixTest {

    /** Category labels for a two-category (binary) matrix. */
    private static final String[] BINARY_CATS = {"0", "1"};

    /** Category labels for the three-category wine example; index order is Cab=0, Syr=1, Pin=2. */
    private static final String[] WINE_CATS = {"Cab", "Syr", "Pin"};

    /** Incrementing a cell by a negative amount must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testExcIncrement() {
        new ConfusionMatrix(WINE_CATS).incrementByN(0, 0, -1);
    }

    /**
     * A freshly constructed matrix exposes its categories in order, has all-zero
     * counts, and reports the expected size and degrees of freedom.
     */
    @Test
    public void testInit() {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(BINARY_CATS);
        Assert.assertArrayEquals(BINARY_CATS, confusionMatrix.categories());

        int[][] matrix = confusionMatrix.matrix();
        Assert.assertEquals(2, matrix.length);
        Assert.assertEquals(2, matrix[0].length);
        for (int row = 0; row < 2; row++) {
            for (int col = 0; col < 2; col++) {
                Assert.assertEquals(0, matrix[row][col]);
            }
        }

        Assert.assertEquals(0, confusionMatrix.getIndex("0"));
        Assert.assertEquals(1, confusionMatrix.getIndex("1"));
        // Unknown categories map to -1.
        Assert.assertEquals(-1, confusionMatrix.getIndex("2"));
        Assert.assertEquals(1, confusionMatrix.chiSquaredDegreesOfFreedom());
        Assert.assertEquals(2, confusionMatrix.numCategories());
    }

    /**
     * Incrementing by category name and by index update the same cells, and
     * {@link ConfusionMatrix#matrix()} reflects every increment.
     */
    @Test
    public void testIncrement() {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(WINE_CATS);
        confusionMatrix.increment("Cab", "Cab");
        Assert.assertEquals(1, confusionMatrix.count(0, 0));
        confusionMatrix.increment(0, 0);
        Assert.assertEquals(2, confusionMatrix.count(0, 0));
        confusionMatrix.increment(1, 2);
        confusionMatrix.increment("Syr", "Pin");
        Assert.assertEquals(2, confusionMatrix.count(1, 2));
        // The matrix is not symmetric: (Syr, Pin) does not touch (Pin, Syr).
        Assert.assertEquals(0, confusionMatrix.count(2, 1));

        // (Cab, Cab) was incremented twice, (Syr, Pin) twice; every other cell is zero.
        int[][] expected = {
            {2, 0, 0},
            {0, 0, 2},
            {0, 0, 0}
        };
        int[][] matrix = confusionMatrix.matrix();
        Assert.assertEquals(expected.length, matrix.length);
        for (int row = 0; row < matrix.length; row++) {
            Assert.assertArrayEquals(expected[row], matrix[row]);
        }
    }

    /**
     * Checks the summary statistics of a fixed 3x3 wine matrix against
     * precomputed reference values. Row sums give each category's
     * {@code positiveReference} and column sums its {@code positiveResponse}.
     */
    @Test
    public void testStats() {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(WINE_CATS, new int[][] {
            {9, 3, 0},
            {3, 5, 1},
            {1, 1, 4}
        });

        // Raw counts and overall accuracy.
        Assert.assertEquals(9, confusionMatrix.count(0, 0));
        Assert.assertEquals(3, confusionMatrix.count(0, 1));
        Assert.assertEquals(1, confusionMatrix.count(2, 0));
        Assert.assertEquals(27, confusionMatrix.totalCount());
        Assert.assertEquals(18, confusionMatrix.totalCorrect());
        Assert.assertEquals(0.6667d, confusionMatrix.totalAccuracy(), 0.005d);
        Assert.assertEquals(0.1778d, confusionMatrix.confidence95(), 0.005d);
        Assert.assertEquals(0.2341d, confusionMatrix.confidence99(), 0.005d);

        // Chance-corrected agreement.
        Assert.assertEquals(0.3663d, confusionMatrix.randomAccuracy(), 0.005d);
        Assert.assertEquals(0.3669d, confusionMatrix.randomAccuracyUnbiased(), 0.005d);
        Assert.assertEquals(0.474d, confusionMatrix.kappa(), 0.005d);
        Assert.assertEquals(0.4735d, confusionMatrix.kappaUnbiased(), 0.005d);
        Assert.assertEquals(0.3333d, confusionMatrix.kappaNoPrevalence(), 0.005d);

        // Marginal entropies.
        Assert.assertEquals(1.5305d, confusionMatrix.referenceEntropy(), 0.005d);
        Assert.assertEquals(1.4865d, confusionMatrix.responseEntropy(), 0.005d);
        Assert.assertEquals(1.5376d, confusionMatrix.crossEntropy(), 0.005d);

        // One-versus-all evaluations, one per category.
        PrecisionRecallEvaluation cabVsAll = confusionMatrix.oneVsAll(0);
        PrecisionRecallEvaluation syrVsAll = confusionMatrix.oneVsAll(1);
        PrecisionRecallEvaluation pinVsAll = confusionMatrix.oneVsAll(2);

        Assert.assertEquals(12L, cabVsAll.positiveReference());
        Assert.assertEquals(9L, syrVsAll.positiveReference());
        Assert.assertEquals(6L, pinVsAll.positiveReference());
        Assert.assertEquals(0.4414d, cabVsAll.referenceLikelihood(), 0.005d);
        Assert.assertEquals(0.3333d, syrVsAll.referenceLikelihood(), 0.005d);
        Assert.assertEquals(0.2222d, pinVsAll.referenceLikelihood(), 0.005d);

        Assert.assertEquals(13L, cabVsAll.positiveResponse());
        Assert.assertEquals(9L, syrVsAll.positiveResponse());
        Assert.assertEquals(5L, pinVsAll.positiveResponse());
        Assert.assertEquals(0.4815d, cabVsAll.responseLikelihood(), 0.005d);
        Assert.assertEquals(0.3333d, syrVsAll.responseLikelihood(), 0.005d);
        Assert.assertEquals(0.1852d, pinVsAll.responseLikelihood(), 0.005d);

        Assert.assertEquals(0.6923d, cabVsAll.precision(), 0.005d);
        Assert.assertEquals(0.5555d, syrVsAll.precision(), 0.005d);
        Assert.assertEquals(0.8d, pinVsAll.precision(), 0.005d);
        Assert.assertEquals(0.75d, cabVsAll.recall(), 0.005d);
        Assert.assertEquals(0.5555d, syrVsAll.recall(), 0.005d);
        Assert.assertEquals(0.6666d, pinVsAll.recall(), 0.005d);
        Assert.assertEquals(0.72d, cabVsAll.fMeasure(), 0.005d);
        Assert.assertEquals(0.5555d, syrVsAll.fMeasure(), 0.005d);
        Assert.assertEquals(0.7273d, pinVsAll.fMeasure(), 0.005d);

        Assert.assertEquals(0.7333d, cabVsAll.rejectionRecall(), 0.005d);
        Assert.assertEquals(0.7778d, syrVsAll.rejectionRecall(), 1.0E-4d);
        Assert.assertEquals(0.9524d, pinVsAll.rejectionRecall(), 1.0E-4d);
        Assert.assertEquals(0.7857d, cabVsAll.rejectionPrecision(), 1.0E-4d);
        Assert.assertEquals(0.7778d, syrVsAll.rejectionPrecision(), 1.0E-4d);
        Assert.assertEquals(0.9091d, pinVsAll.rejectionPrecision(), 1.0E-4d);

        Assert.assertEquals(0.5625d, cabVsAll.jaccardCoefficient(), 1.0E-4d);
        Assert.assertEquals(0.3846d, syrVsAll.jaccardCoefficient(), 1.0E-4d);
        Assert.assertEquals(0.5714d, pinVsAll.jaccardCoefficient(), 1.0E-4d);
        Assert.assertEquals(0.7407d, cabVsAll.accuracy(), 1.0E-4d);
        Assert.assertEquals(0.7037d, syrVsAll.accuracy(), 1.0E-4d);
        Assert.assertEquals(0.8889d, pinVsAll.accuracy(), 1.0E-4d);

        // Chi-squared statistics for the full matrix and each one-vs-all split.
        Assert.assertEquals(4, confusionMatrix.chiSquaredDegreesOfFreedom());
        Assert.assertEquals(3, confusionMatrix.numCategories());
        Assert.assertEquals(15.5256d, confusionMatrix.chiSquared(), 0.005d);
        Assert.assertEquals(6.2382d, cabVsAll.chiSquared(), 0.005d);
        Assert.assertEquals(3.0d, syrVsAll.chiSquared(), 0.005d);
        Assert.assertEquals(11.8519d, pinVsAll.chiSquared(), 0.005d);

        // Macro- and micro-averaged precision/recall/F.
        Assert.assertEquals(0.6826d, confusionMatrix.macroAvgPrecision(), 0.005d);
        Assert.assertEquals(0.6574d, confusionMatrix.macroAvgRecall(), 0.005d);
        Assert.assertEquals(0.6676d, confusionMatrix.macroAvgFMeasure(), 0.005d);
        PrecisionRecallEvaluation microAverage = confusionMatrix.microAverage();
        Assert.assertEquals(0.6666d, microAverage.precision(), 0.005d);
        Assert.assertEquals(0.6666d, microAverage.recall(), 0.005d);
        Assert.assertEquals(0.6666d, microAverage.fMeasure(), 0.005d);

        // Joint and conditional entropies.
        Assert.assertEquals(2.6197d, confusionMatrix.jointEntropy(), 0.005d);
        Assert.assertEquals(0.8113d, confusionMatrix.conditionalEntropy(0), 0.005d);
        Assert.assertEquals(1.3516d, confusionMatrix.conditionalEntropy(1), 0.005d);
        Assert.assertEquals(1.2516d, confusionMatrix.conditionalEntropy(2), 0.005d);
        Assert.assertEquals(1.0892d, confusionMatrix.conditionalEntropy(), 0.005d);

        // Association measures.
        Assert.assertEquals(0.575d, confusionMatrix.phiSquared(), 0.005d);
        Assert.assertEquals(0.5362d, confusionMatrix.cramersV(), 0.005d);
        Assert.assertEquals(0.7838d, cabVsAll.yulesQ(), 0.005d);
        Assert.assertEquals(0.6279d, syrVsAll.yulesQ(), 0.005d);
        Assert.assertEquals(0.9512d, pinVsAll.yulesQ(), 0.005d);
        Assert.assertEquals(0.4835d, cabVsAll.yulesY(), 0.005d);
        Assert.assertEquals(0.3531d, syrVsAll.yulesY(), 0.005d);
        Assert.assertEquals(0.7269d, pinVsAll.yulesY(), 0.005d);
        Assert.assertEquals(12.49d, cabVsAll.fowlkesMallows(), 0.05d);
        Assert.assertEquals(9.0d, syrVsAll.fowlkesMallows(), 0.05d);
        Assert.assertEquals(5.48d, pinVsAll.fowlkesMallows(), 0.05d);
        Assert.assertEquals(0.4d, confusionMatrix.lambdaA(), 0.005d);
        Assert.assertEquals(0.3571d, confusionMatrix.lambdaB(), 0.005d);

        // Mutual information must equal response entropy minus conditional entropy.
        Assert.assertEquals(
                confusionMatrix.responseEntropy() - confusionMatrix.conditionalEntropy(),
                confusionMatrix.mutualInformation(),
                0.005d);
        Assert.assertEquals(0.007129d, confusionMatrix.klDivergence(), 5.0E-5d);
    }
}
