package org.deeplearning4j.optimize.listeners;

import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.callbacks.EvaluationCallback;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/listeners/EvaluativeListener.class */
/**
 * Training listener that periodically evaluates the model being trained against a fixed data source and logs the
 * results.
 *
 * <p>The data source is exactly one of a {@link DataSetIterator}, {@link MultiDataSetIterator}, {@link DataSet} or
 * {@link MultiDataSet}, supplied at construction. This class reacts to three triggers, selected by the configured
 * {@link InvocationType}:
 * <ul>
 *   <li>{@code EPOCH_START} via {@link #onEpochStart};</li>
 *   <li>{@code EPOCH_END} via {@link #onEpochEnd};</li>
 *   <li>{@code ITERATION_END} via {@link #onBackwardPass}.</li>
 * </ul>
 * Each trigger advances a per-thread counter. Evaluation runs on the first trigger and then on every
 * {@code frequency}-th one.
 *
 * <p>Before each run all evaluations are reset and the iterator (if any, and if it supports it) is rewound, so the
 * reported statistics cover that run only. After the statistics are logged, the optional {@link EvaluationCallback}
 * is invoked.
 *
 * <p>NOTE(review): the trigger counter is thread-local, but the {@link IEvaluation} instances and the data source
 * are shared. Confirm this listener is not driven from several training threads at once.
 */
public class EvaluativeListener implements TrainingListener {
    private static final Logger log = LoggerFactory.getLogger(EvaluativeListener.class);

    /** Per-thread number of triggers seen so far; combined with {@link #frequency} to decide when to evaluate. */
    protected transient ThreadLocal<AtomicLong> iterationCount = new ThreadLocal<>();
    /** Evaluate on every {@code frequency}-th trigger; validated to be strictly positive by the constructors. */
    protected int frequency;
    /** Number of evaluation runs started so far (shared by all threads). */
    protected AtomicLong invocationCount = new AtomicLong(0);
    /** Iterator data source, or {@code null} if another source type was supplied. */
    protected transient DataSetIterator dsIterator;
    /** Multi-input/output iterator data source, or {@code null} if another source type was supplied. */
    protected transient MultiDataSetIterator mdsIterator;
    /** Single in-memory data source, or {@code null} if another source type was supplied. */
    protected DataSet ds;
    /** Single in-memory multi-input/output data source, or {@code null} if another source type was supplied. */
    protected MultiDataSet mds;
    /** Evaluations filled on every run; reset at the start of each run. */
    protected IEvaluation[] evaluations;
    /** Which training event triggers this listener. */
    protected InvocationType invocationType;
    /** Optional hook invoked after each evaluation run; may be {@code null}. */
    protected transient EvaluationCallback callback;

    /**
     * Evaluates with a default {@link Evaluation} at the end of every {@code frequency}-th iteration.
     *
     * @param iterator  data to evaluate on
     * @param frequency evaluate on every {@code frequency}-th trigger; must be positive
     * @throws NullPointerException     if {@code iterator} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull DataSetIterator iterator, int frequency) {
        this(iterator, frequency, InvocationType.ITERATION_END, new Evaluation());
    }

    /**
     * Evaluates with a default {@link Evaluation} on every {@code frequency}-th occurrence of {@code type}.
     *
     * @param iterator  data to evaluate on
     * @param frequency evaluate on every {@code frequency}-th trigger; must be positive
     * @param type      training event that triggers this listener
     * @throws NullPointerException     if {@code iterator} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull DataSetIterator iterator, int frequency, @NonNull InvocationType type) {
        this(iterator, frequency, type, new Evaluation());
    }

    /**
     * Evaluates with a default {@link Evaluation} at the end of every {@code frequency}-th iteration.
     *
     * @param iterator  data to evaluate on (only usable with a {@link ComputationGraph})
     * @param frequency evaluate on every {@code frequency}-th trigger; must be positive
     * @throws NullPointerException     if {@code iterator} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency) {
        this(iterator, frequency, InvocationType.ITERATION_END, new Evaluation());
    }

    /**
     * Evaluates with a default {@link Evaluation} on every {@code frequency}-th occurrence of {@code type}.
     *
     * @param iterator  data to evaluate on (only usable with a {@link ComputationGraph})
     * @param frequency evaluate on every {@code frequency}-th trigger; must be positive
     * @param type      training event that triggers this listener
     * @throws NullPointerException     if {@code iterator} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency, @NonNull InvocationType type) {
        this(iterator, frequency, type, new Evaluation());
    }

    /**
     * Evaluates with the given evaluations at the end of every {@code frequency}-th iteration.
     *
     * @param iterator    data to evaluate on
     * @param frequency   evaluate on every {@code frequency}-th trigger; must be positive
     * @param evaluations evaluations to fill on each run (stored as-is, not copied)
     * @throws NullPointerException     if {@code iterator} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull DataSetIterator iterator, int frequency, IEvaluation... evaluations) {
        this(iterator, frequency, InvocationType.ITERATION_END, evaluations);
    }

    /**
     * Evaluates with the given evaluations on every {@code frequency}-th occurrence of {@code type}.
     *
     * @param iterator    data to evaluate on
     * @param frequency   evaluate on every {@code frequency}-th trigger; must be positive
     * @param type        training event that triggers this listener
     * @param evaluations evaluations to fill on each run (stored as-is, not copied)
     * @throws NullPointerException     if {@code iterator} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull DataSetIterator iterator, int frequency, @NonNull InvocationType type,
                    IEvaluation... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        if (type == null) {
            throw new NullPointerException("type");
        }
        this.dsIterator = iterator;
        this.frequency = requirePositiveFrequency(frequency);
        this.evaluations = evaluations;
        this.invocationType = type;
    }

    /**
     * Evaluates with the given evaluations at the end of every {@code frequency}-th iteration.
     *
     * @param iterator    data to evaluate on (only usable with a {@link ComputationGraph})
     * @param frequency   evaluate on every {@code frequency}-th trigger; must be positive
     * @param evaluations evaluations to fill on each run (stored as-is, not copied)
     * @throws NullPointerException     if {@code iterator} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency, IEvaluation... evaluations) {
        this(iterator, frequency, InvocationType.ITERATION_END, evaluations);
    }

    /**
     * Evaluates with the given evaluations on every {@code frequency}-th occurrence of {@code type}.
     *
     * @param iterator    data to evaluate on (only usable with a {@link ComputationGraph})
     * @param frequency   evaluate on every {@code frequency}-th trigger; must be positive
     * @param type        training event that triggers this listener
     * @param evaluations evaluations to fill on each run (stored as-is, not copied)
     * @throws NullPointerException     if {@code iterator} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency, @NonNull InvocationType type,
                    IEvaluation... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        if (type == null) {
            throw new NullPointerException("type");
        }
        this.mdsIterator = iterator;
        this.frequency = requirePositiveFrequency(frequency);
        this.evaluations = evaluations;
        this.invocationType = type;
    }

    /**
     * Evaluates with a default {@link Evaluation} on every {@code frequency}-th occurrence of {@code type}.
     *
     * @param dataSet   data to evaluate on
     * @param frequency evaluate on every {@code frequency}-th trigger; must be positive
     * @param type      training event that triggers this listener
     * @throws NullPointerException     if {@code dataSet} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull DataSet dataSet, int frequency, @NonNull InvocationType type) {
        this(dataSet, frequency, type, new Evaluation());
    }

    /**
     * Evaluates with a default {@link Evaluation} on every {@code frequency}-th occurrence of {@code type}.
     *
     * @param multiDataSet data to evaluate on (only usable with a {@link ComputationGraph})
     * @param frequency    evaluate on every {@code frequency}-th trigger; must be positive
     * @param type         training event that triggers this listener
     * @throws NullPointerException     if {@code multiDataSet} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull MultiDataSet multiDataSet, int frequency, @NonNull InvocationType type) {
        this(multiDataSet, frequency, type, new Evaluation());
    }

    /**
     * Evaluates with the given evaluations on every {@code frequency}-th occurrence of {@code type}.
     *
     * @param dataSet     data to evaluate on
     * @param frequency   evaluate on every {@code frequency}-th trigger; must be positive
     * @param type        training event that triggers this listener
     * @param evaluations evaluations to fill on each run (stored as-is, not copied)
     * @throws NullPointerException     if {@code dataSet} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull DataSet dataSet, int frequency, @NonNull InvocationType type,
                    IEvaluation... evaluations) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        if (type == null) {
            throw new NullPointerException("type");
        }
        this.ds = dataSet;
        this.frequency = requirePositiveFrequency(frequency);
        this.evaluations = evaluations;
        this.invocationType = type;
    }

    /**
     * Evaluates with the given evaluations on every {@code frequency}-th occurrence of {@code type}.
     *
     * @param multiDataSet data to evaluate on (only usable with a {@link ComputationGraph})
     * @param frequency    evaluate on every {@code frequency}-th trigger; must be positive
     * @param type         training event that triggers this listener
     * @param evaluations  evaluations to fill on each run (stored as-is, not copied)
     * @throws NullPointerException     if {@code multiDataSet} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public EvaluativeListener(@NonNull MultiDataSet multiDataSet, int frequency, @NonNull InvocationType type,
                    IEvaluation... evaluations) {
        if (multiDataSet == null) {
            throw new NullPointerException("multiDataSet");
        }
        if (type == null) {
            throw new NullPointerException("type");
        }
        this.mds = multiDataSet;
        this.frequency = requirePositiveFrequency(frequency);
        this.evaluations = evaluations;
        this.invocationType = type;
    }

    /**
     * Validates the evaluation frequency.
     *
     * <p>A value of 0 would make {@link #invokeListener} fail with a division by zero during training, and a
     * negative value has no sensible meaning.
     */
    private static int requirePositiveFrequency(int frequency) {
        if (frequency < 1) {
            throw new IllegalArgumentException("Evaluation frequency must be positive, got: " + frequency);
        }
        return frequency;
    }

    /** Legacy listener API; this listener never reports itself as invoked. */
    @Override
    public boolean invoked() {
        return false;
    }

    /** Legacy listener API; intentionally a no-op. */
    @Override
    public void invoke() {
    }

    /** Intentionally a no-op: iteration-end evaluation is driven from {@link #onBackwardPass(Model)}. */
    @Override
    public void iterationDone(Model model, int iteration) {
    }

    /** Runs {@link #invokeListener(Model)} when configured for {@link InvocationType#EPOCH_START}. */
    @Override
    public void onEpochStart(Model model) {
        if (invocationType == InvocationType.EPOCH_START) {
            invokeListener(model);
        }
    }

    /** Runs {@link #invokeListener(Model)} when configured for {@link InvocationType#EPOCH_END}. */
    @Override
    public void onEpochEnd(Model model) {
        if (invocationType == InvocationType.EPOCH_END) {
            invokeListener(model);
        }
    }

    /** Intentionally a no-op. */
    @Override
    public void onForwardPass(Model model, List<INDArray> activations) {
    }

    /** Intentionally a no-op. */
    @Override
    public void onForwardPass(Model model, Map<String, INDArray> activations) {
    }

    /** Intentionally a no-op. */
    @Override
    public void onGradientCalculation(Model model) {
    }

    /** Runs {@link #invokeListener(Model)} when configured for {@link InvocationType#ITERATION_END}. */
    @Override
    public void onBackwardPass(Model model) {
        if (invocationType == InvocationType.ITERATION_END) {
            invokeListener(model);
        }
    }

    /**
     * Counts one trigger for the current thread and, on the first and every {@code frequency}-th trigger, runs a full
     * evaluation of {@code model}.
     *
     * <p>A run proceeds in four steps:
     * <ol>
     *   <li>Reset all evaluations.</li>
     *   <li>Rewind the iterator source if it supports resetting.</li>
     *   <li>Evaluate.</li>
     *   <li>Log every evaluation's {@code stats()} and invoke the callback, if set.</li>
     * </ol>
     *
     * @param model the model being trained; must be a {@link MultiLayerNetwork} or {@link ComputationGraph}
     * @throws DL4JInvalidInputException if the model type is unsupported, or a {@link MultiLayerNetwork} is paired
     *                                   with a multi-dataset source
     */
    protected void invokeListener(Model model) {
        AtomicLong counter = iterationCount.get();
        if (counter == null) {
            counter = new AtomicLong(0);
            iterationCount.set(counter);
        }
        // Evaluate on trigger 0, frequency, 2*frequency, ... (counted per thread).
        if (counter.getAndIncrement() % frequency != 0) {
            return;
        }

        // Each run starts from a clean slate so the reported stats describe this run only.
        for (IEvaluation evaluation : evaluations) {
            evaluation.reset();
        }
        // NOTE(review): an iterator that does not support reset is used from its current position.
        if (dsIterator != null && dsIterator.resetSupported()) {
            dsIterator.reset();
        } else if (mdsIterator != null && mdsIterator.resetSupported()) {
            mdsIterator.reset();
        }

        log.info("Starting evaluation nr. {}", invocationCount.incrementAndGet());
        if (model instanceof MultiLayerNetwork) {
            evaluateMultiLayerNetwork((MultiLayerNetwork) model);
        } else if (model instanceof ComputationGraph) {
            evaluateComputationGraph((ComputationGraph) model);
        } else {
            throw new DL4JInvalidInputException("Model is unknown: " + model.getClass().getCanonicalName());
        }

        log.info("Reporting evaluation results:");
        for (IEvaluation evaluation : evaluations) {
            log.info("{}:\n{}", evaluation.getClass().getSimpleName(), evaluation.stats());
        }
        if (callback != null) {
            callback.call(this, model, invocationCount.get(), evaluations);
        }
    }

    /** Fills {@link #evaluations} from the configured source using a {@link MultiLayerNetwork}. */
    private void evaluateMultiLayerNetwork(MultiLayerNetwork network) {
        if (dsIterator != null) {
            network.doEvaluation(dsIterator, evaluations);
        } else if (ds != null) {
            // Run the forward pass once and share the predictions across all evaluations.
            INDArray labels = ds.getLabels();
            INDArray predictions = network.output(ds.getFeatureMatrix());
            for (IEvaluation evaluation : evaluations) {
                evaluation.eval(labels, predictions);
            }
        } else if (mdsIterator != null || mds != null) {
            // Previously this case was silently skipped and empty results were reported as a completed run.
            throw new DL4JInvalidInputException("MultiLayerNetwork can't be evaluated on a MultiDataSet or "
                            + "MultiDataSetIterator; use a ComputationGraph or a DataSet/DataSetIterator source");
        }
    }

    /** Fills {@link #evaluations} from the configured source using a {@link ComputationGraph}. */
    private void evaluateComputationGraph(ComputationGraph graph) {
        if (dsIterator != null) {
            graph.doEvaluation(dsIterator, evaluations);
        } else if (mdsIterator != null) {
            graph.doEvaluation(mdsIterator, evaluations);
        } else if (ds != null) {
            // Run the forward pass once and share the predictions across all evaluations.
            INDArray[] labels = new INDArray[] {ds.getLabels()};
            INDArray[] predictions = graph.output(ds.getFeatureMatrix());
            for (IEvaluation evaluation : evaluations) {
                evalAtIndex(evaluation, labels, predictions, 0);
            }
        } else if (mds != null) {
            INDArray[] labels = mds.getLabels();
            INDArray[] predictions = graph.output(mds.getFeatures());
            for (IEvaluation evaluation : evaluations) {
                evalAtIndex(evaluation, labels, predictions, 0);
            }
        }
    }

    /**
     * Feeds the {@code index}-th labels/predictions pair into {@code evaluation}. Subclasses may override this to
     * evaluate a different output.
     *
     * @param evaluation  evaluation to update
     * @param labels      label arrays, one per network output
     * @param predictions prediction arrays, one per network output
     * @param index       which output to evaluate
     */
    protected void evalAtIndex(IEvaluation evaluation, INDArray[] labels, INDArray[] predictions, int index) {
        evaluation.eval(labels[index], predictions[index]);
    }

    /** Returns the live evaluations array (not a copy); its contents are reset at the start of each run. */
    public IEvaluation[] getEvaluations() {
        return evaluations;
    }

    /** Returns the training event that triggers this listener. */
    public InvocationType getInvocationType() {
        return invocationType;
    }

    /** Returns the post-evaluation callback, or {@code null} if none is set. */
    public EvaluationCallback getCallback() {
        return callback;
    }

    /** Sets the callback invoked after each evaluation run; {@code null} disables it. */
    public void setCallback(EvaluationCallback callback) {
        this.callback = callback;
    }
}
