package com.aliasi.test.unit.stats;

import com.aliasi.stats.OnlineNormalEstimator;
import com.aliasi.util.Math;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

/* loaded from: input_file:com/aliasi/test/unit/stats/OnlineNormalEstimatorTest.class */
/**
 * Unit tests for {@link OnlineNormalEstimator}.
 *
 * <p>The tests exercise sample counting, adding samples with {@code handle},
 * removing samples with {@code unHandle}, and compare the estimator's mean,
 * variance and standard deviation (both plain and unbiased) against
 * straightforward two-pass reference computations defined at the bottom of
 * this class.
 */
public class OnlineNormalEstimatorTest {

    /** Number of independent pseudo-random data sets checked per statistical test. */
    private static final int NUM_TRIALS = 10;

    /** Number of samples in each pseudo-random data set. */
    private static final int SAMPLE_SIZE = 500;

    /** Seed of the first trial; trial {@code i} uses {@code BASE_SEED + i}. */
    private static final long BASE_SEED = 42L;

    /** Absolute tolerance used for all floating-point comparisons. */
    private static final double TOLERANCE = 1.0E-4d;

    /**
     * Removing a sample from a freshly constructed estimator must fail.
     *
     * <p>The estimator holds zero samples here, which is the same state in
     * which {@link #testBadUnhandle2()} expects the failure.
     */
    @Test(expected = IllegalStateException.class)
    public void testBadUnHandle1() {
        new OnlineNormalEstimator().unHandle(2.0d);
    }

    /**
     * Removing more samples than were added must fail: the first
     * {@code unHandle} empties the estimator, the second is expected to throw.
     */
    @Test(expected = IllegalStateException.class)
    public void testBadUnhandle2() {
        OnlineNormalEstimator onlineNormalEstimator = new OnlineNormalEstimator();
        onlineNormalEstimator.handle(2.0d);
        onlineNormalEstimator.unHandle(2.0d);
        onlineNormalEstimator.unHandle(2.0d);
    }

    /**
     * Interleaves {@code handle} and {@code unHandle} calls and checks that
     * the count, mean and variance always reflect exactly the samples that
     * are currently held.
     */
    @Test
    public void testUnhandle() {
        OnlineNormalEstimator onlineNormalEstimator = new OnlineNormalEstimator();

        // Add and remove a single sample: back to empty.
        onlineNormalEstimator.handle(1.0d);
        Assert.assertEquals(1L, onlineNormalEstimator.numSamples());
        onlineNormalEstimator.unHandle(1.0d);
        Assert.assertEquals(0L, onlineNormalEstimator.numSamples());

        // Held samples: {2}
        onlineNormalEstimator.handle(2.0d);
        Assert.assertEquals(1L, onlineNormalEstimator.numSamples());
        Assert.assertEquals(2.0d, onlineNormalEstimator.mean(), TOLERANCE);
        Assert.assertEquals(0.0d, onlineNormalEstimator.variance(), TOLERANCE);

        // Held samples: {2, 1}
        onlineNormalEstimator.handle(1.0d);
        Assert.assertEquals(2L, onlineNormalEstimator.numSamples());
        Assert.assertEquals(1.5d, onlineNormalEstimator.mean(), TOLERANCE);
        Assert.assertEquals(0.25d, onlineNormalEstimator.variance(), TOLERANCE);

        // Held samples: {1}
        onlineNormalEstimator.unHandle(2.0d);
        Assert.assertEquals(1L, onlineNormalEstimator.numSamples());
        Assert.assertEquals(1.0d, onlineNormalEstimator.mean(), TOLERANCE);
        Assert.assertEquals(0.0d, onlineNormalEstimator.variance(), TOLERANCE);

        // Held samples: {1, 3} (2 is added and then removed again)
        onlineNormalEstimator.handle(2.0d);
        onlineNormalEstimator.handle(3.0d);
        onlineNormalEstimator.unHandle(2.0d);
        Assert.assertEquals(2L, onlineNormalEstimator.numSamples());
        Assert.assertEquals(2.0d, onlineNormalEstimator.mean(), TOLERANCE);
        Assert.assertEquals(1.0d, onlineNormalEstimator.variance(), TOLERANCE);
    }

    /** The sample count starts at zero and grows by one per handled value. */
    @Test
    public void testNumSamples() {
        OnlineNormalEstimator estimator = estimator(new double[0]);
        Assert.assertEquals(0L, estimator.numSamples());
        estimator.handle(5.0d);
        Assert.assertEquals(1L, estimator.numSamples());
        estimator.handle(6.0d);
        Assert.assertEquals(2L, estimator.numSamples());
    }

    /**
     * The estimator's mean matches the reference mean on several
     * pseudo-random data sets.
     */
    @Test
    public void testMean() {
        for (int trial = 0; trial < NUM_TRIALS; trial++) {
            // A distinct but fixed seed per trial: every iteration checks
            // different data while the test stays reproducible.
            double[] samples = randomArray(BASE_SEED + trial, SAMPLE_SIZE);
            Assert.assertEquals(mean(samples), estimator(samples).mean(), TOLERANCE);
        }
    }

    /**
     * The estimator's variance and standard deviation, in both the plain and
     * the unbiased form, match the reference values on several pseudo-random
     * data sets.
     */
    @Test
    public void testVariance() {
        for (int trial = 0; trial < NUM_TRIALS; trial++) {
            double[] samples = randomArray(BASE_SEED + trial, SAMPLE_SIZE);
            OnlineNormalEstimator estimator = estimator(samples);

            double expectedVariance = variance(samples);
            // n / (n - 1) rescales the divide-by-n variance to the
            // divide-by-(n - 1) form expected from the "unbiased" accessors.
            double unbiasedScale = (double) samples.length / (samples.length - 1);
            double expectedUnbiasedVariance = unbiasedScale * expectedVariance;

            Assert.assertEquals(expectedVariance, estimator.variance(), TOLERANCE);
            // java.lang.Math is spelled out because the simple name Math is
            // bound to the imported com.aliasi.util.Math in this file.
            Assert.assertEquals(java.lang.Math.sqrt(expectedVariance),
                                estimator.standardDeviation(), TOLERANCE);
            Assert.assertEquals(expectedUnbiasedVariance,
                                estimator.varianceUnbiased(), TOLERANCE);
            Assert.assertEquals(java.lang.Math.sqrt(expectedUnbiasedVariance),
                                estimator.standardDeviationUnbiased(), TOLERANCE);
        }
    }

    /**
     * Generates a reproducible array of pseudo-random values.
     *
     * @param seed seed for the {@link Random} generator
     * @param length number of values to generate
     * @return array of {@code length} values drawn with {@link Random#nextDouble()}
     */
    static double[] randomArray(long seed, int length) {
        Random random = new Random(seed);
        double[] values = new double[length];
        for (int k = 0; k < values.length; k++) {
            values[k] = random.nextDouble();
        }
        return values;
    }

    /**
     * Builds an estimator that has handled every value of the array in order.
     *
     * @param dArr samples to feed to the estimator
     * @return a new estimator holding all the samples
     */
    static OnlineNormalEstimator estimator(double[] dArr) {
        OnlineNormalEstimator onlineNormalEstimator = new OnlineNormalEstimator();
        for (double sample : dArr) {
            onlineNormalEstimator.handle(sample);
        }
        return onlineNormalEstimator;
    }

    /**
     * Reference mean: sum of the values divided by their count.
     *
     * @param dArr samples
     * @return arithmetic mean of the samples
     */
    static double mean(double[] dArr) {
        return Math.sum(dArr) / dArr.length;
    }

    /**
     * Reference variance: sum of squared differences from the mean, divided
     * by the number of samples (not by the number of samples minus one).
     *
     * @param dArr samples
     * @return divide-by-n variance of the samples
     */
    static double variance(double[] dArr) {
        return sumSquareDiffs(dArr, mean(dArr)) / dArr.length;
    }

    /**
     * Sums the squared differences between each value and a center point.
     *
     * @param dArr samples
     * @param center value subtracted from each sample before squaring
     * @return sum over all samples of {@code (sample - center)^2}
     */
    static double sumSquareDiffs(double[] dArr, double center) {
        double total = 0.0d;
        for (double sample : dArr) {
            double diff = sample - center;
            total += diff * diff;
        }
        return total;
    }
}
