package org.deeplearning4j.nn.updater;

import java.util.Objects;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.layers.BaseLayer;

/* loaded from: input_file:org/deeplearning4j/nn/updater/UpdaterUtils.class */
public class UpdaterUtils {

    /**
     * Determines whether two parameters have equivalent updater configurations, i.e. whether
     * their updater state could be combined.
     *
     * <p>The parameters are considered equivalent only when all of the following hold:
     * <ol>
     *   <li>their {@code IUpdater} configurations (from {@code getIUpdaterByParam}) are equal;</li>
     *   <li>their learning rates and learning rate schedules are equal
     *       (see {@link #lrSchedulesEqual(Layer, String, Layer, String)});</li>
     *   <li>if either parameter is a pretrain parameter, both are pretrain parameters of the very
     *       same layer instance (reference equality).</li>
     * </ol>
     *
     * @param layer  layer owning the first parameter
     * @param str    name of the first parameter within {@code layer}
     * @param layer2 layer owning the second parameter
     * @param str2   name of the second parameter within {@code layer2}
     * @return {@code true} if the two parameters' updater configurations are equivalent
     */
    public static boolean updaterConfigurationsEquals(Layer layer, String str, Layer layer2, String str2) {
        // Objects.equals tolerates a null updater on either side instead of throwing an NPE.
        if (!Objects.equals(layer.conf().getLayer().getIUpdaterByParam(str),
                layer2.conf().getLayer().getIUpdaterByParam(str2))) {
            return false;
        }
        if (!lrSchedulesEqual(layer, str, layer2, str2)) {
            return false;
        }

        boolean pretrain1 = layer.conf().getLayer().isPretrainParam(str);
        boolean pretrain2 = layer2.conf().getLayer().isPretrainParam(str2);
        if (!pretrain1 && !pretrain2) {
            return true;
        }
        // Pretrain params are only equivalent to other pretrain params of the same layer instance;
        // never across layers and never paired with a non-pretrain param.
        return layer == layer2 && pretrain1 && pretrain2;
    }

    /**
     * Determines whether two parameters use the same learning rate and the same learning rate
     * decay policy configuration.
     *
     * <p>Both layers must use the same {@link LearningRatePolicy} and the same learning rate for
     * the given parameters. Beyond that, only the settings relevant to that policy are compared.
     * {@code Score} policies are never considered equal.
     *
     * @param layer  layer owning the first parameter
     * @param str    name of the first parameter within {@code layer}
     * @param layer2 layer owning the second parameter
     * @param str2   name of the second parameter within {@code layer2}
     * @return {@code true} if the learning rate schedules are equal
     * @throws UnsupportedOperationException if the policy is not one handled here
     */
    public static boolean lrSchedulesEqual(Layer layer, String str, Layer layer2, String str2) {
        LearningRatePolicy policy = layer.conf().getLearningRatePolicy();
        if (policy != layer2.conf().getLearningRatePolicy()) {
            return false;
        }

        // Read into primitives so the comparison is always by value.
        double lr1 = layer.conf().getLearningRateByParam(str);
        double lr2 = layer2.conf().getLearningRateByParam(str2);
        if (lr1 != lr2) {
            return false;
        }

        double decay1 = layer.conf().getLrPolicyDecayRate();
        double decay2 = layer2.conf().getLrPolicyDecayRate();
        boolean decayEqual = decay1 == decay2;

        switch (policy) {
            case None:
                return true;
            case Exponential:
                return decayEqual;
            case Inverse:
                return decayEqual && lrPolicyPowerEqual(layer, layer2);
            case Poly:
                return lrPolicyPowerEqual(layer, layer2);
            case Sigmoid:
            case Step:
                return decayEqual && lrPolicyStepsEqual(layer, layer2);
            case TorchStep:
                // NOTE(review): TorchStep is a step-style decay and presumably depends on the decay
                // rate and step count (as Step does); previously only the power was compared.
                // Comparing all three is conservative: an over-strict check can only prevent
                // sharing, never merge differently-scheduled parameters.
                return decayEqual && lrPolicyStepsEqual(layer, layer2) && lrPolicyPowerEqual(layer, layer2);
            case Schedule:
                return learningRateSchedulesEqual(layer, layer2);
            case Score:
                // Score-based schedules are conservatively never treated as equal (presumably
                // because the decay is driven by runtime state, not configuration - TODO confirm).
                return false;
            default:
                throw new UnsupportedOperationException("Unknown learning rate schedule: " + policy);
        }
    }

    /** Returns whether both layers' configurations specify the same learning rate policy power. */
    private static boolean lrPolicyPowerEqual(Layer layer, Layer layer2) {
        double power1 = layer.conf().getLrPolicyPower();
        double power2 = layer2.conf().getLrPolicyPower();
        return power1 == power2;
    }

    /** Returns whether both layers' configurations specify the same learning rate policy steps. */
    private static boolean lrPolicyStepsEqual(Layer layer, Layer layer2) {
        double steps1 = layer.conf().getLrPolicySteps();
        double steps2 = layer2.conf().getLrPolicySteps();
        return steps1 == steps2;
    }

    /**
     * Returns whether both layer configurations are {@code BaseLayer}s with equal learning rate
     * schedules. Any other layer configuration has no readable schedule, so it is reported as
     * "not equal" rather than failing with a {@code ClassCastException}.
     */
    private static boolean learningRateSchedulesEqual(Layer layer, Layer layer2) {
        Object conf1 = layer.conf().getLayer();
        Object conf2 = layer2.conf().getLayer();
        if (!(conf1 instanceof BaseLayer) || !(conf2 instanceof BaseLayer)) {
            return false;
        }
        return Objects.equals(((BaseLayer) conf1).getLearningRateSchedule(),
                ((BaseLayer) conf2).getLearningRateSchedule());
    }
}
