package com.aliasi.test.unit.lm;

import com.aliasi.lm.LanguageModel;
import com.aliasi.lm.NGramProcessLM;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Math;
import com.aliasi.util.Strings;
import com.aliasi.xml.XHtmlWriter;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import junit.framework.Assert;
import org.junit.Test;

/* loaded from: input_file:com/aliasi/test/unit/lm/NGramProcessLMTest.class */
/**
 * Unit tests for {@link NGramProcessLM}.
 *
 * <p>The expected values are computed by hand from the interpolation formula
 * these tests assume:
 *
 * <pre>
 *   lambda(ctx)   = count(ctx) / (count(ctx) + lambdaFactor * numOutcomes(ctx))
 *   P(c | ctx)    = lambda(ctx) * mlEstimate(c | ctx)
 *                   + (1 - lambda(ctx)) * P(c | shorter ctx)
 * </pre>
 *
 * with the recursion bottoming out in a uniform estimate of
 * {@code 1 / alphabetSize}.
 *
 * <p>Note that {@code Math} in this file is {@code com.aliasi.util.Math}
 * (imported explicitly), not {@code java.lang.Math}.
 */
public class NGramProcessLMTest {
    /** Default tolerance for comparing log (base 2) estimates. */
    private static final double TOLERANCE = 5.0E-4d;

    // ---- Model parameters used by testOne ---------------------------------
    static double lambdaFactor = 4.0d;
    static int alphabetSize = 255;

    /** Training text for {@link #testOne()}: 11 characters, 5 distinct (a, b, r, c, d). */
    static char[] ABRACADABRA = "abracadabra".toCharArray();

    /** Not referenced inside this class; retained so the set of fields is unchanged. */
    static double count = 0.0d;

    // ---- Expected estimate for "a" with the empty context ------------------
    static char[] A = "a".toCharArray();
    static double numOutcomesNull = 5.0d;   // distinct characters in "abracadabra"
    static double aCount = 5.0d;            // occurrences of 'a'
    static double numEventsNull = 11.0d;    // total characters
    static double mlEstimateA = aCount / numEventsNull;
    static double uniformEstimate = 1.0d / alphabetSize;
    static double lambdaNull = numEventsNull / (numEventsNull + (lambdaFactor * numOutcomesNull));
    static double estimateA = (lambdaNull * mlEstimateA) + ((1.0d - lambdaNull) * uniformEstimate);

    // ---- Expected estimate for "b" with the empty context ------------------
    static char[] B = "b".toCharArray();
    static double bCount = 2.0d;            // occurrences of 'b'
    static double mlEstimateB = bCount / numEventsNull;
    static double estimateB = (lambdaNull * mlEstimateB) + ((1.0d - lambdaNull) * uniformEstimate);

    // ---- Expected estimate for 'b' given context "a" -----------------------
    static char[] AB = "ab".toCharArray();
    static double aContextCount = 4.0d;     // 'a' is followed by b, c, d, b
    static double abCount = 2.0d;
    static double numOutcomesA = 3.0d;      // distinct followers of 'a': b, c, d
    static double lambdaA = aContextCount / (aContextCount + (lambdaFactor * numOutcomesA));
    static double mlEstimateAB = abCount / aContextCount;
    static double estimateAB = (lambdaA * mlEstimateAB) + ((1.0d - lambdaA) * estimateB);

    // ---- Expected estimate for 'b' given context "da" ----------------------
    static char[] DAB = "dab".toCharArray();
    static double daContextCount = 1.0d;    // "da" occurs once, followed by 'b'
    static double dabCount = 1.0d;
    static double numOutcomesDA = 1.0d;
    static double lambdaDA = daContextCount / (daContextCount + (lambdaFactor * numOutcomesDA));
    static double mlEstimateDAB = 1.0d;
    static double estimateDAB = (lambdaDA * mlEstimateDAB) + ((1.0d - lambdaDA) * estimateAB);

    /** "ab" preceded by a character that never occurs in the training text. */
    static char[] ZAB = "zab".toCharArray();
    /** "dab" preceded by one extra character, making it longer than the 3-gram order used in testOne. */
    static char[] XDAB = "xdab".toCharArray();

    /**
     * Computes the interpolation weight the tests expect for a context.
     *
     * @param numEvents   number of training events observed for the context
     * @param numOutcomes number of distinct outcomes observed for the context
     * @param factor      the lambda factor the model was constructed with
     * @return {@code numEvents / (numEvents + factor * numOutcomes)}
     */
    private static double interpolationWeight(double numEvents, double numOutcomes, double factor) {
        return numEvents / (numEvents + (factor * numOutcomes));
    }

    /**
     * Asserts that asking for a conditional estimate of the empty string
     * throws {@link IllegalArgumentException}.
     */
    @Test
    public void testExs() {
        // Construct outside the try so that an exception thrown by the
        // constructor cannot be mistaken for the expected one.
        NGramProcessLM model = new NGramProcessLM(3, 128);
        try {
            model.log2ConditionalEstimate(Strings.EMPTY_STRING);
        } catch (IllegalArgumentException expected) {
            return;
        }
        Assert.fail("expected IllegalArgumentException for empty input");
    }

    /**
     * Trains a 3-gram model on "abracadabra" and checks the hand-computed
     * estimates on the model and on its compiled, written/read and
     * serialized copies.
     *
     * @throws ClassNotFoundException propagated from {@link #assertModel(NGramProcessLM)}
     * @throws IOException propagated from {@link #assertModel(NGramProcessLM)}
     */
    @Test
    public void testOne() throws ClassNotFoundException, IOException {
        NGramProcessLM nGramProcessLM = new NGramProcessLM(3, alphabetSize, lambdaFactor);
        nGramProcessLM.train(ABRACADABRA, 0, ABRACADABRA.length);
        assertModel(nGramProcessLM);
    }

    /**
     * Trains on "a" once and then twice, checking the estimate of "a"
     * after each training call.
     */
    @Test
    public void testA() {
        double factor = 4.0d;
        double uniform = 1.0d / 128.0d;
        NGramProcessLM model = new NGramProcessLM(4, 128, factor);

        // One event, one distinct outcome: lambda = 1 / (1 + 4 * 1) = 0.2.
        model.train("a");
        double lambdaAfterOne = interpolationWeight(1.0d, 1.0d, factor);
        double expectedAfterOne = (lambdaAfterOne * 1.0d) + ((1.0d - lambdaAfterOne) * uniform);
        Assert.assertEquals(Math.log2(expectedAfterOne), model.log2ConditionalEstimate("a"), 0.005d);

        // Two events, one distinct outcome: lambda = 2 / (2 + 4 * 1) = 1/3.
        model.train("a");
        double lambdaAfterTwo = interpolationWeight(2.0d, 1.0d, factor);
        double expectedAfterTwo = (lambdaAfterTwo * 1.0d) + ((1.0d - lambdaAfterTwo) * uniform);
        Assert.assertEquals(Math.log2(expectedAfterTwo), model.log2ConditionalEstimate("a"), 0.005d);
    }

    /**
     * Trains on "a" and then "ab", and checks the estimates of "a", "b"
     * and of 'b' given context "a".
     */
    @Test
    public void testA_AB() {
        double factor = 4.0d;
        double uniform = 1.0d / 128.0d;
        NGramProcessLM model = new NGramProcessLM(4, 128, factor);
        model.train("a");
        model.train("ab");

        // Empty context: three events (a, a, b), two distinct outcomes -> lambda = 3/11.
        double lambdaEmpty = interpolationWeight(3.0d, 2.0d, factor);
        double expectedA = (lambdaEmpty * (2.0d / 3.0d)) + ((1.0d - lambdaEmpty) * uniform);
        Assert.assertEquals(Math.log2(expectedA), model.log2ConditionalEstimate("a"), TOLERANCE);

        double expectedB = (lambdaEmpty * (1.0d / 3.0d)) + ((1.0d - lambdaEmpty) * uniform);
        Assert.assertEquals(Math.log2(expectedB), model.log2ConditionalEstimate("b"), TOLERANCE);

        // Context "a": one event (followed by 'b'), one distinct outcome -> lambda = 0.2.
        double lambdaContextA = interpolationWeight(1.0d, 1.0d, factor);
        double expectedAB = (lambdaContextA * 1.0d) + ((1.0d - lambdaContextA) * expectedB);
        Assert.assertEquals(Math.log2(expectedAB), model.log2ConditionalEstimate("ab"), TOLERANCE);
    }

    /**
     * Runs {@link #assertConditionalLM(LanguageModel.Conditional)} against the
     * given model and against three copies of it: the result of
     * {@code AbstractExternalizable.compile}, the result of
     * {@link #readWrite(NGramProcessLM)}, and the result of
     * {@code AbstractExternalizable.serializeDeserialize}.
     *
     * <p>Exceptions are allowed to propagate so the test report carries the
     * original stack trace.
     *
     * @param nGramProcessLM the trained model to check
     * @throws IOException if compiling, writing/reading or serializing fails
     * @throws ClassNotFoundException if a copy cannot be reconstituted
     */
    public void assertModel(NGramProcessLM nGramProcessLM) throws IOException, ClassNotFoundException {
        assertConditionalLM(nGramProcessLM);
        assertConditionalLM((LanguageModel.Conditional) AbstractExternalizable.compile(nGramProcessLM));
        assertConditionalLM(readWrite(nGramProcessLM));
        assertConditionalLM((LanguageModel.Conditional) AbstractExternalizable.serializeDeserialize(nGramProcessLM));
    }

    /**
     * Writes the model to an in-memory byte buffer with {@code writeTo} and
     * reads it back with {@code NGramProcessLM.readFrom}.
     *
     * @param nGramProcessLM the model to round-trip
     * @return the model read back from the written bytes
     * @throws IOException if writing or reading fails
     */
    public static NGramProcessLM readWrite(NGramProcessLM nGramProcessLM) throws IOException {
        ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
        nGramProcessLM.writeTo(bytesOut);
        ByteArrayInputStream bytesIn = new ByteArrayInputStream(bytesOut.toByteArray());
        return NGramProcessLM.readFrom(bytesIn);
    }

    /**
     * Checks a conditional language model trained on "abracadabra" against
     * the hand-computed expectations held in this class's static fields, and
     * checks that joint estimates equal the sum of the corresponding
     * conditional estimates.
     *
     * @param conditional the model (or a copy of it) to check
     * @throws IOException declared for callers; not thrown directly here
     */
    public void assertConditionalLM(LanguageModel.Conditional conditional) throws IOException {
        // Single characters with an empty context.
        Assert.assertEquals(Math.log2(estimateA), conditional.log2ConditionalEstimate(A, 0, 1), TOLERANCE);
        Assert.assertEquals(Math.log2(estimateA), conditional.log2Estimate(A, 0, 1), TOLERANCE);
        Assert.assertEquals(Math.log2(estimateB), conditional.log2ConditionalEstimate(B, 0, 1), TOLERANCE);

        // 'b' given "a"; a leading 'z' (absent from training) must not change the estimate.
        Assert.assertEquals("AB", Math.log2(estimateAB), conditional.log2ConditionalEstimate(AB, 0, 2), TOLERANCE);
        Assert.assertEquals(
                conditional.log2ConditionalEstimate(ZAB, 0, 3),
                conditional.log2ConditionalEstimate(AB, 0, 2),
                TOLERANCE);

        // 'b' given "da"; this check keeps the tighter 5.0E-5 tolerance of the original test.
        Assert.assertEquals("DAB", Math.log2(estimateDAB), conditional.log2ConditionalEstimate(DAB, 0, 3), 5.0E-5d);
        // A context longer than "da" is expected to give the same estimate.
        Assert.assertEquals(Math.log2(estimateDAB), conditional.log2ConditionalEstimate(XDAB, 0, 4), TOLERANCE);

        // Joint estimates equal the sum of the per-character conditional estimates.
        Assert.assertEquals(
                conditional.log2ConditionalEstimate(A, 0, 1),
                conditional.log2Estimate(A, 0, 1),
                TOLERANCE);

        double sumAB = conditional.log2ConditionalEstimate(AB, 0, 1)
                + conditional.log2ConditionalEstimate(AB, 0, 2);
        Assert.assertEquals(sumAB, conditional.log2Estimate(AB, 0, 2), TOLERANCE);

        double sumDAB = conditional.log2ConditionalEstimate(DAB, 0, 1)
                + conditional.log2ConditionalEstimate(DAB, 0, 2)
                + conditional.log2ConditionalEstimate(DAB, 0, 3);
        Assert.assertEquals(sumDAB, conditional.log2Estimate(DAB, 0, 3), TOLERANCE);

        // Same identity over a slice starting at offset 1 ("ab" within "dab").
        double sumSlice = conditional.log2ConditionalEstimate(DAB, 1, 2)
                + conditional.log2ConditionalEstimate(DAB, 1, 3);
        Assert.assertEquals(sumSlice, conditional.log2Estimate(DAB, 1, 3), TOLERANCE);
    }
}
