package com.aliasi.cluster;

import com.aliasi.classify.PrecisionRecallEvaluation;
import com.aliasi.util.Distance;
import com.aliasi.util.Tuple;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:com/aliasi/cluster/ClusterScore.class */
/**
 * Scores a response partition of a set of elements against a reference
 * partition of the same elements.
 *
 * <p>Three families of scores are provided:
 * <ul>
 * <li>an equivalence (pairwise) evaluation, computed over all
 *     {@code n * n} element pairs (including each element paired with
 *     itself) and returned as a {@link PrecisionRecallEvaluation};</li>
 * <li>MUC precision, recall and F(1);</li>
 * <li>B-cubed precision, recall and F(1), averaged either uniformly over
 *     elements or uniformly over reference clusters.</li>
 * </ul>
 *
 * <p>Every precision score is computed by swapping the roles of reference
 * and response in the corresponding recall computation.
 *
 * <p>The partitions are stored by reference, not copied. The equivalence
 * evaluation is computed once in the constructor, while the MUC and
 * B-cubed scores are recomputed on every call. Mutating either partition
 * after construction therefore leaves the scores mutually inconsistent.
 *
 * @param <E> the type of the clustered elements
 */
public class ClusterScore<E> {
    private final PrecisionRecallEvaluation mPrEval;
    private final Set<? extends Set<? extends E>> mReferencePartition;
    private final Set<? extends Set<? extends E>> mResponsePartition;

    /**
     * Constructs a score for the response partition against the reference
     * partition.
     *
     * @param referencePartition the reference partition
     * @param responsePartition the response partition
     * @throws IllegalArgumentException if either partition places an
     *     element in more than one cluster, or if the two partitions do not
     *     cover exactly the same elements
     */
    public ClusterScore(Set<? extends Set<? extends E>> referencePartition,
                        Set<? extends Set<? extends E>> responsePartition) {
        assertPartitionSameSets(referencePartition, responsePartition);
        mReferencePartition = referencePartition;
        mResponsePartition = responsePartition;
        mPrEval = calculateConfusionMatrix();
    }

    /**
     * Returns the pairwise equivalence evaluation computed at construction.
     *
     * @return the equivalence evaluation
     */
    public PrecisionRecallEvaluation equivalenceEvaluation() {
        return mPrEval;
    }

    /**
     * Returns MUC precision, which is MUC recall with the reference and
     * response partitions swapped.
     *
     * @return MUC precision in [0, 1]
     */
    public double mucPrecision() {
        return mucRecall(mResponsePartition, mReferencePartition);
    }

    /**
     * Returns MUC recall of the response against the reference.
     *
     * @return MUC recall in [0, 1]
     */
    public double mucRecall() {
        return mucRecall(mReferencePartition, mResponsePartition);
    }

    /**
     * Returns the harmonic mean of MUC precision and recall.
     *
     * @return MUC F(1), or 0 if precision and recall are both 0
     */
    public double mucF() {
        return f(mucPrecision(), mucRecall());
    }

    /**
     * Returns cluster-averaged B-cubed precision.
     *
     * @return B-cubed cluster-averaged precision in [0, 1]
     */
    public double b3ClusterPrecision() {
        return b3ClusterRecall(mResponsePartition, mReferencePartition);
    }

    /**
     * Returns cluster-averaged B-cubed recall.
     *
     * @return B-cubed cluster-averaged recall in [0, 1]
     */
    public double b3ClusterRecall() {
        return b3ClusterRecall(mReferencePartition, mResponsePartition);
    }

    /**
     * Returns the harmonic mean of cluster-averaged B-cubed precision and
     * recall.
     *
     * @return B-cubed cluster-averaged F(1), or 0 if both inputs are 0
     */
    public double b3ClusterF() {
        return f(b3ClusterPrecision(), b3ClusterRecall());
    }

    /**
     * Returns element-averaged B-cubed precision.
     *
     * @return B-cubed element-averaged precision in [0, 1]
     */
    public double b3ElementPrecision() {
        return b3ElementRecall(mResponsePartition, mReferencePartition);
    }

    /**
     * Returns element-averaged B-cubed recall.
     *
     * @return B-cubed element-averaged recall in [0, 1]
     */
    public double b3ElementRecall() {
        return b3ElementRecall(mReferencePartition, mResponsePartition);
    }

    /**
     * Returns the harmonic mean of element-averaged B-cubed precision and
     * recall.
     *
     * @return B-cubed element-averaged F(1), or 0 if both inputs are 0
     */
    public double b3ElementF() {
        return f(b3ElementPrecision(), b3ElementRecall());
    }

    /**
     * Returns the element pairs that are in the same cluster in both the
     * reference and the response partition.
     *
     * @return a fresh, mutable set of true-positive pairs
     */
    public Set<Tuple<E>> truePositives() {
        Set<Tuple<E>> shared = toEquivalences(mReferencePartition);
        shared.retainAll(toEquivalences(mResponsePartition));
        return shared;
    }

    /**
     * Returns the element pairs that share a response cluster but not a
     * reference cluster.
     *
     * @return a fresh, mutable set of false-positive pairs
     */
    public Set<Tuple<E>> falsePositives() {
        Set<Tuple<E>> responseOnly = toEquivalences(mResponsePartition);
        responseOnly.removeAll(toEquivalences(mReferencePartition));
        return responseOnly;
    }

    /**
     * Returns the element pairs that share a reference cluster but not a
     * response cluster.
     *
     * @return a fresh, mutable set of false-negative pairs
     */
    public Set<Tuple<E>> falseNegatives() {
        Set<Tuple<E>> referenceOnly = toEquivalences(mReferencePartition);
        referenceOnly.removeAll(toEquivalences(mResponsePartition));
        return referenceOnly;
    }

    /**
     * Builds the pairwise confusion counts for the two partitions.
     *
     * <p>The total number of pairs is {@code n * n}, where {@code n} is the
     * number of distinct elements. True negatives are the remainder after
     * subtracting true positives, false negatives and false positives.
     */
    private PrecisionRecallEvaluation calculateConfusionMatrix() {
        Set<Tuple<E>> referenceEquivalences = toEquivalences(mReferencePartition);
        // Matched pairs are removed from this set, so afterwards it holds
        // exactly the response-only (false positive) pairs.
        Set<Tuple<E>> unmatchedResponse = toEquivalences(mResponsePartition);
        long truePositives = 0L;
        long falseNegatives = 0L;
        for (Tuple<E> pair : referenceEquivalences) {
            if (unmatchedResponse.remove(pair)) {
                ++truePositives;
            } else {
                ++falseNegatives;
            }
        }
        long falsePositives = unmatchedResponse.size();
        // Widen to long before squaring so that large element sets do not
        // overflow int arithmetic.
        long elementCount = elementsOf(mReferencePartition).size();
        long totalPairs = elementCount * elementCount;
        long trueNegatives = totalPairs - truePositives - falseNegatives - falsePositives;
        return new PrecisionRecallEvaluation(truePositives, falseNegatives,
                                             falsePositives, trueNegatives);
    }

    /**
     * Returns a multi-line, human-readable report of every score.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("CLUSTER SCORE");
        sb.append("\nEquivalence Evaluation\n");
        sb.append(mPrEval.toString());
        sb.append("\nMUC Evaluation");
        sb.append("\n  MUC Precision = " + mucPrecision());
        sb.append("\n  MUC Recall = " + mucRecall());
        sb.append("\n  MUC F(1) = " + mucF());
        sb.append("\nB-Cubed Evaluation");
        sb.append("\n  B3 Cluster Averaged Precision = " + b3ClusterPrecision());
        sb.append("\n  B3 Cluster Averaged Recall = " + b3ClusterRecall());
        sb.append("\n  B3 Cluster Averaged F(1) = " + b3ClusterF());
        sb.append("\n  B3 Element Averaged Precision = " + b3ElementPrecision());
        sb.append("\n  B3 Element Averaged Recall = " + b3ElementRecall());
        sb.append("\n  B3 Element Averaged F(1) = " + b3ElementF());
        return sb.toString();
    }

    /**
     * Returns the sum of {@link #scatter(Set, Distance)} over all clusters of
     * a partition.
     *
     * @param partition the clusters to score
     * @param distance the distance between elements
     * @return the total within-cluster scatter
     */
    public static <E> double withinClusterScatter(Set<? extends Set<? extends E>> partition,
                                                  Distance<? super E> distance) {
        double total = 0.0;
        for (Set<? extends E> cluster : partition) {
            total += scatter(cluster, distance);
        }
        return total;
    }

    /**
     * Returns the sum of distances over all unordered pairs of distinct
     * elements in a cluster.
     *
     * @param cluster the elements to score
     * @param distance the distance between elements
     * @return the cluster's scatter, which is 0 for clusters of size 0 or 1
     */
    public static <E> double scatter(Set<? extends E> cluster, Distance<? super E> distance) {
        // Unchecked but safe: E erases to Object, and the array is only read
        // back as E to pass to the distance function.
        @SuppressWarnings("unchecked")
        E[] elements = (E[]) cluster.toArray();
        double total = 0.0;
        for (int i = 0; i < elements.length; ++i) {
            for (int j = i + 1; j < elements.length; ++j) {
                total += distance.distance(elements[i], elements[j]);
            }
        }
        return total;
    }

    /**
     * Returns every pair {@code (a, b)} whose elements lie in the same
     * cluster, including the pair of each element with itself.
     *
     * @param partition the partition to expand into pairs
     * @return a fresh, mutable set of equivalence pairs
     */
    Set<Tuple<E>> toEquivalences(Set<? extends Set<? extends E>> partition) {
        Set<Tuple<E>> equivalences = new HashSet<Tuple<E>>();
        for (Set<? extends E> cluster : partition) {
            for (E first : cluster) {
                for (E second : cluster) {
                    equivalences.add(Tuple.create(first, second));
                }
            }
        }
        return equivalences;
    }

    /**
     * B-cubed recall in which every element carries equal weight
     * {@code 1 / n}.
     */
    private static <F> double b3ElementRecall(Set<? extends Set<? extends F>> referencePartition,
                                              Set<? extends Set<? extends F>> responsePartition) {
        Set<F> elements = elementsOf(referencePartition);
        // The weight is the same for every element, so compute it once.
        double weight = uniformElementWeight(elements);
        double recall = 0.0;
        for (Set<? extends F> referenceCluster : referencePartition) {
            for (F element : referenceCluster) {
                recall += weight * b3Recall(element, referenceCluster, responsePartition);
            }
        }
        return recall;
    }

    /** Returns the per-element weight {@code 1 / |elements|}. */
    private static <F> double uniformElementWeight(Set<? extends F> elements) {
        return 1.0 / elements.size();
    }

    /**
     * Returns the per-element weight
     * {@code 1 / (|cluster| * |partition|)}, which gives every cluster equal
     * total weight.
     */
    private static <F> double uniformClusterWeight(Set<? extends F> cluster,
                                                   Set<? extends Set<? extends F>> partition) {
        // Multiply in double so that the product cannot overflow int.
        return 1.0 / ((double) cluster.size() * partition.size());
    }

    /**
     * B-cubed recall in which every reference cluster carries equal weight,
     * split evenly among that cluster's elements.
     */
    private static <F> double b3ClusterRecall(Set<? extends Set<? extends F>> referencePartition,
                                              Set<? extends Set<? extends F>> responsePartition) {
        double recall = 0.0;
        for (Set<? extends F> referenceCluster : referencePartition) {
            double weight = uniformClusterWeight(referenceCluster, referencePartition);
            for (F element : referenceCluster) {
                recall += weight * b3Recall(element, referenceCluster, responsePartition);
            }
        }
        return recall;
    }

    /**
     * Returns the fraction of an element's reference cluster that is also in
     * that element's response cluster.
     *
     * @throws IllegalArgumentException if the element is in no response
     *     cluster
     */
    private static <F> double b3Recall(F element, Set<? extends F> referenceCluster,
                                       Set<? extends Set<? extends F>> responsePartition) {
        return recallSets(referenceCluster, getEquivalenceClass(element, responsePartition));
    }

    /**
     * Returns {@code |reference ∩ response| / |reference|}, or 1 for an
     * empty reference set.
     */
    private static <F> double recallSets(Set<? extends F> referenceSet,
                                         Set<? extends F> responseSet) {
        if (referenceSet.isEmpty()) {
            return 1.0;
        }
        // Cast before dividing: long / int would truncate the result to 0 or 1.
        return (double) intersectionSize(referenceSet, responseSet) / (double) referenceSet.size();
    }

    /** Returns the number of members of the first set contained in the second. */
    private static <F> long intersectionSize(Set<? extends F> first, Set<? extends F> second) {
        long count = 0L;
        for (F member : first) {
            if (second.contains(member)) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Checks that both arguments are valid partitions of the same elements.
     *
     * @throws IllegalArgumentException if either partition has overlapping
     *     clusters, or if their element sets differ
     */
    private static <F> void assertPartitionSameSets(Set<? extends Set<? extends F>> referencePartition,
                                                    Set<? extends Set<? extends F>> responsePartition) {
        assertValidPartition(referencePartition);
        assertValidPartition(responsePartition);
        if (!elementsOf(referencePartition).equals(elementsOf(responsePartition))) {
            throw new IllegalArgumentException("Partitions must be of same sets.");
        }
    }

    /**
     * Checks that no element appears in more than one cluster.
     *
     * @throws IllegalArgumentException naming the first repeated element found
     */
    private static <F> void assertValidPartition(Set<? extends Set<? extends F>> partition) {
        Set<F> seen = new HashSet<F>();
        for (Set<? extends F> cluster : partition) {
            for (F element : cluster) {
                if (!seen.add(element)) {
                    throw new IllegalArgumentException("Partitions must not contain overlapping members. Found overlapping element=" + element);
                }
            }
        }
    }

    /**
     * Returns the first cluster of the partition that contains the element.
     *
     * @throws IllegalArgumentException if no cluster contains the element
     */
    private static <F> Set<? extends F> getEquivalenceClass(F element,
                                                            Set<? extends Set<? extends F>> partition) {
        for (Set<? extends F> cluster : partition) {
            if (cluster.contains(element)) {
                return cluster;
            }
        }
        throw new IllegalArgumentException("Element must be in an equivalence class in partition.");
    }

    /** Returns a fresh set holding the union of all clusters in the partition. */
    private static <F> Set<F> elementsOf(Set<? extends Set<? extends F>> partition) {
        Set<F> elements = new HashSet<F>();
        for (Set<? extends F> cluster : partition) {
            elements.addAll(cluster);
        }
        return elements;
    }

    /**
     * Returns the harmonic mean {@code 2PR / (P + R)}.
     *
     * <p>When P and R are both 0, the result is defined as 0 instead of the
     * NaN that 0/0 would produce.
     */
    private static double f(double precision, double recall) {
        double sum = precision + recall;
        if (sum == 0.0) {
            return 0.0;
        }
        return 2.0 * precision * recall / sum;
    }

    /**
     * Computes MUC recall of the response partition against the reference.
     *
     * <p>The result is {@code sum(|S| - |p(S)|) / sum(|S| - 1)} over the
     * reference clusters {@code S}, where {@code p(S)} is the set of
     * response clusters that share at least one element with {@code S}.
     * Returns 1 when the denominator is 0, for example when every reference
     * cluster is a singleton.
     */
    private static <F> double mucRecall(Set<? extends Set<? extends F>> referencePartition,
                                        Set<? extends Set<? extends F>> responsePartition) {
        long numerator = 0L;
        long denominator = 0L;
        for (Set<? extends F> referenceCluster : referencePartition) {
            long overlappingResponseClusters = 0L;
            for (Set<? extends F> responseCluster : responsePartition) {
                if (!Collections.disjoint(referenceCluster, responseCluster)) {
                    ++overlappingResponseClusters;
                }
            }
            numerator += referenceCluster.size() - overlappingResponseClusters;
            denominator += referenceCluster.size() - 1;
        }
        if (denominator == 0L) {
            return 1.0;
        }
        // Cast before dividing: long / long would truncate the result to 0 or 1.
        return (double) numerator / (double) denominator;
    }
}
