package com.aliasi.test.unit.stats;

import com.aliasi.stats.RegressionPrior;
import com.aliasi.test.unit.Asserts;
import com.aliasi.util.AbstractExternalizable;
import java.io.IOException;
import junit.framework.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link RegressionPrior}: mode shifting via {@code shiftMeans},
 * elastic-net mixing of Laplace and Gaussian priors, argument validation for
 * {@code elasticNet}, and serialization round trips.
 */
public class RegressionPriorTest {

    /** Absolute tolerance for comparing prior values computed along different paths. */
    private static final double TOLERANCE = 1.0E-4d;

    /** Absolute tolerance for comparing a prior with its deserialized copy. */
    private static final double SERIALIZATION_TOLERANCE = 1.0E-5d;

    /** Sentinel passed to {@link #assertSerialization} for priors with no fixed dimensionality. */
    private static final int UNBOUNDED_DIMENSIONS = -1;

    /** How many dimensions to compare when a prior has no fixed dimensionality. */
    private static final int UNBOUNDED_PROBE_DIMENSIONS = 10;

    /** Weight of the Laplace component in the elastic-net prior under test. */
    private static final double LAPLACE_WEIGHT = 0.3d;

    /** Second argument passed to {@code elasticNet} in {@link #testElasticNet}. */
    private static final double ELASTIC_NET_SCALE = 2.0d;

    /**
     * Checks that {@code shiftMeans} moves each dimension's mode by its offset,
     * that nested shifts add up, and that the gradient vanishes at each
     * shifted mode.
     */
    @Test
    public void testMeans() {
        RegressionPrior gaussian = RegressionPrior.gaussian(1.0d, true);
        assertSameDouble(0.0d, gaussian.mode(0));
        assertSameDouble(0.0d, gaussian.mode(1));

        RegressionPrior shifted =
            RegressionPrior.shiftMeans(new double[] {1.0d, 2.0d, -3.0d}, gaussian);
        assertSameDouble(1.0d, shifted.mode(0));
        assertSameDouble(2.0d, shifted.mode(1));
        assertSameDouble(-3.0d, shifted.mode(2));
        Assert.assertEquals(0.0d, shifted.gradient(1.0d, 0), TOLERANCE);
        Assert.assertEquals(0.0d, shifted.gradient(2.0d, 1), TOLERANCE);
        Assert.assertEquals(0.0d, shifted.gradient(-3.0d, 2), TOLERANCE);

        // Shifting an already-shifted prior: offsets accumulate per dimension.
        RegressionPrior shiftedTwice =
            RegressionPrior.shiftMeans(new double[] {2.0d, 1.0d, 3.0d}, shifted);
        assertSameDouble(3.0d, shiftedTwice.mode(0));
        assertSameDouble(3.0d, shiftedTwice.mode(1));
        assertSameDouble(0.0d, shiftedTwice.mode(2));
        Assert.assertEquals(0.0d, shiftedTwice.gradient(3.0d, 0), TOLERANCE);
        Assert.assertEquals(0.0d, shiftedTwice.gradient(3.0d, 1), TOLERANCE);
        Assert.assertEquals(0.0d, shiftedTwice.gradient(0.0d, 2), TOLERANCE);
    }

    /**
     * Checks that the elastic-net prior's log prior and gradient equal the
     * {@link #LAPLACE_WEIGHT}-weighted mixture of a Laplace and a Gaussian
     * prior. Both settings of the boolean flag are covered.
     */
    @Test
    public void testElasticNet() {
        double laplaceVariance = 1.0d / Math.sqrt(ELASTIC_NET_SCALE);
        double gaussianVariance = Math.sqrt(ELASTIC_NET_SCALE) / 2.0d;

        // Flag true: dimension 0 (the intercept) is expected to have zero log prior and gradient.
        RegressionPrior elasticNet =
            RegressionPrior.elasticNet(LAPLACE_WEIGHT, ELASTIC_NET_SCALE, true);
        RegressionPrior laplace = RegressionPrior.laplace(laplaceVariance, true);
        RegressionPrior gaussian = RegressionPrior.gaussian(gaussianVariance, true);
        for (int beta = -5; beta < 5; beta++) {
            for (int dim = 0; dim < 3; dim++) {
                assertMixture(elasticNet, laplace, gaussian, beta, dim);
            }
            Assert.assertEquals(0.0d, elasticNet.log2Prior(beta, 0), TOLERANCE);
            Assert.assertEquals(0.0d, elasticNet.gradient(beta, 0), TOLERANCE);
        }

        // Flag false: the intercept is expected to follow the same mixture as every other dimension.
        RegressionPrior informativeNet =
            RegressionPrior.elasticNet(LAPLACE_WEIGHT, ELASTIC_NET_SCALE, false);
        RegressionPrior informativeLaplace = RegressionPrior.laplace(laplaceVariance, false);
        RegressionPrior informativeGaussian = RegressionPrior.gaussian(gaussianVariance, false);
        for (int beta = -5; beta < 5; beta++) {
            for (int dim = 0; dim < 3; dim++) {
                assertMixture(informativeNet, informativeLaplace, informativeGaussian, beta, dim);
            }
        }
    }

    /**
     * Checks that a shifted prior evaluated at {@code beta} in a dimension
     * equals the base prior evaluated at {@code beta} minus that dimension's
     * offset. Both the log prior and the gradient are compared.
     */
    @Test
    public void testMeanOffsets() {
        RegressionPrior gaussian = RegressionPrior.gaussian(1.0d, false);
        RegressionPrior shifted =
            RegressionPrior.shiftMeans(new double[] {1.0d, -2.0d, 3.0d}, gaussian);

        // Dimension 0 is offset by +1.
        assertSameDouble(gaussian.log2Prior(0.0d, 0), shifted.log2Prior(1.0d, 0));
        assertSameDouble(gaussian.log2Prior(1.0d, 0), shifted.log2Prior(2.0d, 0));
        assertSameDouble(gaussian.log2Prior(-1.0d, 0), shifted.log2Prior(0.0d, 0));
        assertSameDouble(gaussian.gradient(0.0d, 0), shifted.gradient(1.0d, 0));
        assertSameDouble(gaussian.gradient(1.0d, 0), shifted.gradient(2.0d, 0));
        assertSameDouble(gaussian.gradient(-2.0d, 0), shifted.gradient(-1.0d, 0));

        // Dimension 1 is offset by -2, dimension 2 by +3.
        assertSameDouble(gaussian.log2Prior(3.0d, 1), shifted.log2Prior(1.0d, 1));
        assertSameDouble(gaussian.log2Prior(7.0d, 2), shifted.log2Prior(10.0d, 2));
        assertSameDouble(gaussian.gradient(7.0d, 2), shifted.gradient(10.0d, 2));
    }

    /** A negative first argument (the Laplace mixing weight) must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx1() {
        RegressionPrior.elasticNet(-1.0d, 2.0d, true);
    }

    /** A NaN first argument (the Laplace mixing weight) must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx2() {
        RegressionPrior.elasticNet(Double.NaN, 2.0d, true);
    }

    /** An infinite first argument (the Laplace mixing weight) must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx3() {
        RegressionPrior.elasticNet(Double.POSITIVE_INFINITY, 2.0d, true);
    }

    /** A negative second argument must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx4() {
        RegressionPrior.elasticNet(0.5d, -1.0d, true);
    }

    /** A NaN second argument must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx5() {
        RegressionPrior.elasticNet(0.5d, Double.NaN, true);
    }

    /** An infinite second argument must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx6() {
        RegressionPrior.elasticNet(0.5d, Double.POSITIVE_INFINITY, true);
    }

    /** A zero second argument must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx7() {
        RegressionPrior.elasticNet(0.5d, 0.0d, true);
    }

    /** Round-trips every kind of prior the factory methods produce. */
    @Test
    public void testSerialization() throws IOException, ClassNotFoundException {
        double[] perDimension = {1.0d, 2.0d, 3.0d};
        assertSerialization(
            RegressionPrior.shiftMeans(new double[] {1.0d, -2.0d, 3.0d},
                                       RegressionPrior.gaussian(1.0d, false)),
            3);
        assertSerialization(RegressionPrior.elasticNet(0.95d, 2.0d, false), UNBOUNDED_DIMENSIONS);
        assertSerialization(RegressionPrior.cauchy(perDimension), perDimension.length);
        assertSerialization(RegressionPrior.cauchy(1.0d, true), UNBOUNDED_DIMENSIONS);
        assertSerialization(RegressionPrior.cauchy(1.0d, false), UNBOUNDED_DIMENSIONS);
        assertSerialization(RegressionPrior.gaussian(perDimension), perDimension.length);
        assertSerialization(RegressionPrior.gaussian(1.0d, true), UNBOUNDED_DIMENSIONS);
        assertSerialization(RegressionPrior.gaussian(1.0d, false), UNBOUNDED_DIMENSIONS);
        assertSerialization(RegressionPrior.laplace(perDimension), perDimension.length);
        assertSerialization(RegressionPrior.laplace(1.0d, true), UNBOUNDED_DIMENSIONS);
        assertSerialization(RegressionPrior.laplace(1.0d, false), UNBOUNDED_DIMENSIONS);
        assertSerialization(RegressionPrior.noninformative(), UNBOUNDED_DIMENSIONS);
    }

    /**
     * Round-trips {@code prior} through
     * {@link AbstractExternalizable#serializeDeserialize} and checks that the
     * copy agrees with the original on log priors and gradients.
     *
     * @param prior the prior to serialize and deserialize
     * @param numDimensions the number of dimensions to compare, or {@code -1}
     *     for a prior without fixed dimensionality. In the {@code -1} case the
     *     first ten dimensions are compared. When positive, both the original
     *     and the copy must also throw {@link ArrayIndexOutOfBoundsException}
     *     for dimension {@code numDimensions}, the first index past the end.
     */
    void assertSerialization(RegressionPrior prior, int numDimensions)
            throws IOException, ClassNotFoundException {
        RegressionPrior copy =
            (RegressionPrior) AbstractExternalizable.serializeDeserialize(prior);
        int dimensionsToCheck = (numDimensions == UNBOUNDED_DIMENSIONS)
            ? UNBOUNDED_PROBE_DIMENSIONS
            : numDimensions;
        for (int dim = 0; dim < dimensionsToCheck; dim++) {
            Assert.assertEquals(prior.log2Prior(2.0d, dim), copy.log2Prior(2.0d, dim),
                                SERIALIZATION_TOLERANCE);
            Assert.assertEquals(prior.log2Prior(-1.0d, dim), copy.log2Prior(-1.0d, dim),
                                SERIALIZATION_TOLERANCE);
            Assert.assertEquals(prior.gradient(5.0d, dim), copy.gradient(5.0d, dim),
                                SERIALIZATION_TOLERANCE);
            Assert.assertEquals(prior.gradient(-2.0d, dim), copy.gradient(-2.0d, dim),
                                SERIALIZATION_TOLERANCE);
        }
        if (numDimensions > 0) {
            // The copy must keep the original's dimensionality, not just its values.
            assertDimensionOutOfRange(prior, numDimensions);
            assertDimensionOutOfRange(copy, numDimensions);
        }
    }

    /** Asserts that evaluating {@code prior}'s gradient at {@code dimension} throws. */
    private static void assertDimensionOutOfRange(RegressionPrior prior, int dimension) {
        try {
            prior.gradient(2.0d, dimension);
            Assert.fail("expected ArrayIndexOutOfBoundsException for dimension " + dimension);
        } catch (ArrayIndexOutOfBoundsException expected) {
            Asserts.succeed();
        }
    }

    /**
     * Asserts that {@code net}'s log prior and gradient at {@code beta} in
     * {@code dim} equal the weighted mixture of {@code laplace} and
     * {@code gaussian}.
     */
    private static void assertMixture(RegressionPrior net, RegressionPrior laplace,
                                      RegressionPrior gaussian, double beta, int dim) {
        Assert.assertEquals(mix(laplace.log2Prior(beta, dim), gaussian.log2Prior(beta, dim)),
                            net.log2Prior(beta, dim), TOLERANCE);
        Assert.assertEquals(mix(laplace.gradient(beta, dim), gaussian.gradient(beta, dim)),
                            net.gradient(beta, dim), TOLERANCE);
    }

    /** Combines a Laplace and a Gaussian value using {@link #LAPLACE_WEIGHT}. */
    private static double mix(double laplaceValue, double gaussianValue) {
        return LAPLACE_WEIGHT * laplaceValue + (1.0d - LAPLACE_WEIGHT) * gaussianValue;
    }

    /**
     * Asserts bit-exact equality of two doubles via {@link Double#equals}.
     * Under this check 0.0 and -0.0 differ, and NaN equals NaN.
     */
    private static void assertSameDouble(double expected, double actual) {
        Assert.assertEquals(Double.valueOf(expected), Double.valueOf(actual));
    }
}
