package com.aliasi.test.unit.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.LogisticRegressionClassifier;
import com.aliasi.corpus.XValidatingObjectCorpus;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.tokenizer.RegExTokenizerFactory;
import com.aliasi.tokenizer.TokenFeatureExtractor;
import java.io.IOException;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

/* loaded from: input_file:com/aliasi/test/unit/classify/LogisticRegressionClassifierTest.class */
/**
 * Unit test for {@link LogisticRegressionClassifier}.
 *
 * <p>Builds a synthetic four-category corpus of whitespace-separated
 * single-letter tokens. About half of each example's tokens are the
 * category's own letter; the rest are random letters. The test trains a
 * classifier on it, checks that fresh examples are classified correctly,
 * and then exercises hot-start retraining from the trained model.
 */
public class LogisticRegressionClassifierTest {

    /** Number of categories; category {@code k} is named {@code "cat_"} + ('a' + k). */
    private static final int NUM_CATEGORIES = 4;

    /** Training examples generated for each category. */
    private static final int TRAINING_EXAMPLES_PER_CATEGORY = 100;

    /** Fresh examples classified for each category when checking a model. */
    private static final int TEST_EXAMPLES_PER_CATEGORY = 10;

    /** Number of tokens in every generated example. */
    private static final int TOKENS_PER_EXAMPLE = 100;

    /** Noise tokens are drawn uniformly from the letters 'a' .. 'a' + this - 1. */
    private static final int NOISE_ALPHABET_SIZE = 10;

    /** Fixed seed so the generated data, and therefore the test, is reproducible. */
    private static final long SEED = 42L;

    @Test
    public void test1() throws IOException {
        // One seeded source of randomness for corpus generation, permutation
        // and evaluation keeps every run identical.
        Random random = new Random(SEED);

        XValidatingObjectCorpus<Classified<CharSequence>> corpus
            = new XValidatingObjectCorpus<Classified<CharSequence>>(10);
        for (int cat = 0; cat < NUM_CATEGORIES; cat++) {
            Classification classification = new Classification(categoryName(cat));
            for (int k = 0; k < TRAINING_EXAMPLES_PER_CATEGORY; k++) {
                corpus.handle(new Classified<CharSequence>(generateExample(cat, random),
                                                           classification));
            }
        }
        corpus.permuteCorpus(random);

        TokenFeatureExtractor featureExtractor
            = new TokenFeatureExtractor(new RegExTokenizerFactory("\\S+"));
        RegressionPrior prior = RegressionPrior.noninformative();
        AnnealingSchedule annealingSchedule = AnnealingSchedule.inverse(0.01d, 500.0d);
        double minImprovement = 0.001d;

        LogisticRegressionClassifier<CharSequence> classifier
            = LogisticRegressionClassifier.<CharSequence>train(
                corpus, featureExtractor, 2, true, prior, 4,
                null,               // no hot start
                annealingSchedule, minImprovement, 5, 2, 10000,
                null, null);        // no per-epoch handler, no reporter
        assertClassifiesAll(classifier, random);

        // Retrain, warm-started from the model above, with a different prior
        // block size and a much tighter convergence threshold.
        LogisticRegressionClassifier<CharSequence> hotStarted
            = LogisticRegressionClassifier.<CharSequence>train(
                corpus, featureExtractor, 2, true, prior, 2,
                classifier,         // hot start from the first model
                annealingSchedule, minImprovement / 1000.0d, 5, 2, 10000,
                null, null);
        assertClassifiesAll(hotStarted, random);
    }

    /**
     * Asserts that {@code classifier} assigns freshly generated examples of
     * every category to that category.
     *
     * @param classifier model under test
     * @param random source of randomness for generating the examples
     */
    private static void assertClassifiesAll(LogisticRegressionClassifier<CharSequence> classifier,
                                            Random random) {
        for (int cat = 0; cat < NUM_CATEGORIES; cat++) {
            String expected = categoryName(cat);
            for (int k = 0; k < TEST_EXAMPLES_PER_CATEGORY; k++) {
                Assert.assertEquals(expected,
                                    classifier.classify(generateExample(cat, random))
                                              .bestCategory());
            }
        }
    }

    /**
     * Returns the category name for index {@code i}: "cat_a", "cat_b", ...
     *
     * @param i zero-based category index
     * @return the category name
     */
    private static String categoryName(int i) {
        return "cat_" + (char) ('a' + i);
    }

    /**
     * Generates one example for category {@code i} using a fresh, unseeded
     * {@link Random}. Kept for compatibility; the test itself uses the seeded
     * overload {@link #generateExample(int, Random)}.
     *
     * @param i zero-based category index
     * @return a space-separated sequence of single-letter tokens
     */
    static StringBuilder generateExample(int i) {
        return generateExample(i, new Random());
    }

    /**
     * Generates one example for category {@code i}.
     *
     * <p>Each token is the category letter ('a' + {@code i}) with probability
     * 1/2. Otherwise it is a letter drawn uniformly from the first
     * {@value #NOISE_ALPHABET_SIZE} letters, which may also be the category
     * letter.
     *
     * @param i zero-based category index
     * @param random source of randomness
     * @return a space-separated sequence of {@value #TOKENS_PER_EXAMPLE} single-letter tokens
     */
    static StringBuilder generateExample(int i, Random random) {
        StringBuilder sb = new StringBuilder(2 * TOKENS_PER_EXAMPLE);
        for (int tok = 0; tok < TOKENS_PER_EXAMPLE; tok++) {
            if (tok > 0) {
                sb.append(' ');
            }
            if (random.nextBoolean()) {
                sb.append((char) ('a' + i));
            } else {
                sb.append((char) ('a' + random.nextInt(NOISE_ALPHABET_SIZE)));
            }
        }
        return sb;
    }
}
