package com.aliasi.test.unit.stats;

import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.util.AbstractExternalizable;
import java.io.IOException;
import junit.framework.Assert;
import org.junit.Test;

/* loaded from: input_file:com/aliasi/test/unit/stats/LogisticRegressionTest.class */
public class LogisticRegressionTest {
    /**
     * Outcome labels passed as the category array to {@code LogisticRegression.estimate}
     * in {@code assertCorrectRegression}. Every entry is 0, 1 or 2.
     * NOTE(review): presumably parallel to WALLET_DATA_MATRIX (one label per data row) -
     * confirm that the two lengths match.
     */
    static final int[] WALLET_OUTCOME_VECTOR = {1, 1, 2, 2, 0, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 0, 1, 1, 2, 2, 2, 2, 1, 1, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 1, 0, 0, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 0, 0, 1, 0, 1, 0, 1, 0, 2, 2, 1, 2, 0, 2, 1, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 0, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 0, 2, 1, 0, 1, 2, 1, 2, 1, 1, 0, 1, 1, 0, 1, 1, 2, 2, 1, 0, 1, 2, 1, 2, 0, 1, 2, 1, 2, 2, 2, 2, 2, 1};
    static final double[][] WALLET_DATA_MATRIX = {new double[]{1.0d, 0.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 3.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 3.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 3.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new 
double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 1.0d, 3.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 3.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 3.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 3.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 3.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 2.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 3.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 
1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 3.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 3.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 3.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 1.0d, 3.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 3.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 3.0d, 0.0d}, new double[]{1.0d, 0.0d, 1.0d, 2.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 
0.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 3.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 1.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 2.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 1.0d, 2.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 3.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 3.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 1.0d, 3.0d, 0.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new 
double[]{1.0d, 1.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 0.0d, 1.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 2.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 0.0d, 1.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 3.0d, 0.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 0.0d}, new double[]{1.0d, 0.0d, 0.0d, 3.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 2.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 3.0d, 1.0d}, new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}, new double[]{1.0d, 1.0d, 0.0d, 1.0d, 1.0d}};
    /**
     * Reference coefficients used by {@code assertCorrectRegression}: row {@code i} is
     * compared, element by element and with a tolerance of 0.1, against the i-th weight
     * vector of the fitted model. Each row has five entries, matching the five columns of
     * the rows in WALLET_DATA_MATRIX. The last row is all zeros.
     * NOTE(review): the all-zero row presumably stands for the reference outcome whose
     * weights are implicit; it is only checked if the model returns a third weight vector.
     */
    static final double[][] WALLET_EXPECTED_FEATURES = {new double[]{-3.4712d, 1.2673d, 1.1804d, 1.0817d, -1.6006d}, new double[]{-1.2917d, 1.1699d, 0.4179d, 0.1957d, -0.804d}, new double[]{0.0d, 0.0d, 0.0d, 0.0d, 0.0d}};

    /**
     * Verifies {@code classify} on a hand-built model with two explicit weight vectors.
     * The dot products of the input (1, -1, 2) with the weight vectors (1, 2, 3) and
     * (-2, 1, -1) are 5 and -5, so the expected three-way distribution is the
     * normalised triple exp(5), exp(-5), exp(0).
     */
    @Test
    public void testClass() {
        Vector[] weights = new Vector[]{
            new DenseVector(new double[]{1.0d, 2.0d, 3.0d}),
            new DenseVector(new double[]{-2.0d, 1.0d, -1.0d})
        };
        LogisticRegression model = new LogisticRegression(weights);
        DenseVector input = new DenseVector(new double[]{1.0d, -1.0d, 2.0d});

        double expFirst = Math.exp(5.0d);
        double expSecond = Math.exp(-5.0d);
        double expThird = Math.exp(0.0d);
        // Sanity check on the unnormalised term for the third outcome.
        Assert.assertEquals(1.0d, expThird, 1.0E-4d);

        // Same summation order as each individual denominator would use.
        double normalizer = (expFirst + expSecond) + expThird;
        double[] expected = {expFirst / normalizer, expSecond / normalizer, expThird / normalizer};

        double[] actual = model.classify(input);
        Assert.assertEquals(expected.length, actual.length);
        int outcome = 0;
        while (outcome < expected.length) {
            Assert.assertEquals(expected[outcome], actual[outcome], 1.0E-7d);
            outcome++;
        }
    }

    /**
     * Builds a new array of the same length in which every element is the
     * single-vector {@code sparseCopy} of the corresponding input element.
     *
     * @param vectorArr vectors to convert; the array itself is not modified
     * @return a freshly allocated array of converted vectors, in the same order
     */
    static Vector[] sparseCopy(Vector[] vectorArr) {
        final int count = vectorArr.length;
        Vector[] converted = new Vector[count];
        int slot = 0;
        for (Vector original : vectorArr) {
            converted[slot++] = sparseCopy(original);
        }
        return converted;
    }

    /**
     * Re-creates the given vector as a {@code SparseFloatVector} that lists every
     * dimension explicitly: key {@code d} maps to {@code vector.value(d)} narrowed to
     * {@code float}, for each {@code d} from 0 up to the vector's dimension count.
     *
     * @param vector vector to copy
     * @return a sparse float vector with the same number of dimensions
     */
    static Vector sparseCopy(Vector vector) {
        final int dims = vector.numDimensions();
        int[] keys = new int[dims];
        float[] values = new float[dims];
        for (int d = 0; d < dims; ++d) {
            keys[d] = d;
            // Narrowing to float is intentional: the target vector stores floats.
            values[d] = (float) vector.value(d);
        }
        return new SparseFloatVector(keys, values, dims);
    }

    /**
     * Runs the regression checks twice over the wallet data: once with each row
     * wrapped in a {@code DenseVector}, and once with the sparse copies of those rows.
     */
    @Test
    public void testEstimation() throws IOException, ClassNotFoundException {
        final int rowCount = WALLET_DATA_MATRIX.length;
        Vector[] denseRows = new Vector[rowCount];
        int row = 0;
        for (double[] features : WALLET_DATA_MATRIX) {
            denseRows[row++] = new DenseVector(features);
        }
        // Build the sparse inputs before any estimation runs, as the dense check comes first.
        Vector[] sparseRows = sparseCopy(denseRows);

        assertCorrectRegression(denseRows);
        assertCorrectRegression(sparseRows);
    }

    /**
     * Fits a logistic regression to the given inputs and WALLET_OUTCOME_VECTOR and checks
     * three things:
     * <ol>
     *   <li>the fitted weights are within 0.1 of WALLET_EXPECTED_FEATURES;</li>
     *   <li>the model returned by {@code AbstractExternalizable.compile} reports the same
     *       number of outcomes, the same number of input dimensions and equal weight
     *       vectors;</li>
     *   <li>a second estimation that is handed the first model, with a tighter threshold
     *       of 1.0E-7, still lands within 0.1 of the expected weights.</li>
     * </ol>
     * NOTE(review): per the LingPipe API the fourth and fifth arguments of
     * {@code estimate} are presumably the block size (3, then 2) and a hot-start model -
     * confirm against the library Javadoc.
     *
     * @param vectorArr input vectors, one per entry of WALLET_OUTCOME_VECTOR
     * @throws IOException propagated from {@code LogisticRegression.estimate} or
     *         {@code AbstractExternalizable.compile}
     * @throws ClassNotFoundException propagated from {@code AbstractExternalizable.compile}
     */
    void assertCorrectRegression(Vector[] vectorArr) throws IOException, ClassNotFoundException {
        LogisticRegression estimate = LogisticRegression.estimate(vectorArr, WALLET_OUTCOME_VECTOR, RegressionPrior.noninformative(), 3, null, AnnealingSchedule.inverse(0.05d, 100.0d), 1.0E-5d, 5, 10, 500000, null, null);
        Vector[] fittedWeights = estimate.weightVectors();
        assertWeightsNearExpected(fittedWeights);

        LogisticRegression compiled = (LogisticRegression) AbstractExternalizable.compile(estimate);
        Assert.assertEquals(estimate.numOutcomes(), compiled.numOutcomes());
        // Compare against the compiled model; comparing the estimate with itself can never fail.
        Assert.assertEquals(estimate.numInputDimensions(), compiled.numInputDimensions());
        Vector[] compiledWeights = compiled.weightVectors();
        Assert.assertEquals(fittedWeights.length, compiledWeights.length);
        for (int k = 0; k < fittedWeights.length; k++) {
            Assert.assertEquals(fittedWeights[k], compiledWeights[k]);
        }

        LogisticRegression refit = LogisticRegression.estimate(vectorArr, WALLET_OUTCOME_VECTOR, RegressionPrior.noninformative(), 2, estimate, AnnealingSchedule.inverse(0.05d, 100.0d), 1.0E-7d, 5, 10, 500000, null, null);
        assertWeightsNearExpected(refit.weightVectors());
    }

    /** Asserts that every component of every weight vector is within 0.1 of WALLET_EXPECTED_FEATURES. */
    private static void assertWeightsNearExpected(Vector[] weightVectors) {
        for (int outcome = 0; outcome < weightVectors.length; outcome++) {
            Vector weights = weightVectors[outcome];
            for (int dim = 0; dim < weights.numDimensions(); dim++) {
                Assert.assertEquals(WALLET_EXPECTED_FEATURES[outcome][dim], weights.value(dim), 0.1d);
            }
        }
    }
}
