package org.deeplearning4j.arbiter.optimize.candidategenerator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.apache.commons.math3.random.RandomAdaptor;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Candidate generator that enumerates a discretized grid over the unique leaves of a
 * {@link ParameterSpace}.
 *
 * <p>Each unique leaf is assigned a number of grid points:
 * <ul>
 *   <li>{@link DiscreteParameterSpace}: all of its values ({@code numValues()});</li>
 *   <li>{@link IntegerParameterSpace}: {@code min(max - min + 1, discretizationCount)};</li>
 *   <li>any other leaf: {@code discretizationCount}.</li>
 * </ul>
 * The grid is the Cartesian product of these points. Candidates are emitted either in increasing
 * index order ({@link Mode#Sequential}) or in a shuffled order drawn from the inherited {@code rng}
 * ({@link Mode#RandomOrder}).
 *
 * <p>Grid point {@code k} of a leaf with {@code n > 1} points is mapped to the value
 * {@code k / (n - 1)} in [0, 1]; a leaf with a single point is mapped to 0. The resulting array is
 * passed to {@code ParameterSpace.getValue(double[])} to build the candidate.
 */
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng"})
public class GridSearchCandidateGenerator<T> extends BaseCandidateGenerator<T> {
    private final int discretizationCount;
    private final Mode mode;
    private int[] numValuesPerParam;
    private int totalNumCandidates;
    private Queue<Integer> order;

    /** Order in which grid points are emitted. */
    public enum Mode {
        /** Emit grid points in increasing flat-index order. */
        Sequential,
        /** Emit grid points in a random permutation (shuffled with the generator's rng). */
        RandomOrder
    }

    /**
     * Creates a grid search generator and immediately builds the candidate order.
     *
     * @param parameterSpace      space to search; its unique leaves are the grid dimensions
     * @param discretizationCount number of grid points for non-discrete, non-integer leaves, and the
     *                            upper bound on grid points for integer leaves
     * @param mode                emission order; must not be {@code null}
     * @throws NullPointerException  if {@code mode} is {@code null}
     * @throws IllegalStateException if the grid has {@code Integer.MAX_VALUE} or more candidates,
     *                               or a leaf would have a negative number of grid points
     */
    public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace<T> parameterSpace,
                                        @JsonProperty("discretizationCount") int discretizationCount,
                                        @JsonProperty("mode") Mode mode) {
        super(parameterSpace);
        this.discretizationCount = discretizationCount;
        this.mode = mode;
        initialize();
    }

    /**
     * Computes the number of grid points per unique leaf, the total candidate count, and the order
     * in which flat candidate indices will be handed out.
     *
     * <p>NOTE(review): this is invoked from the constructor, so subclasses overriding it run before
     * their own fields are initialized.
     */
    @Override
    public void initialize() {
        super.initialize();
        if (this.mode == null) {
            throw new NullPointerException("mode must not be null");
        }

        List<?> leaves = CollectionUtils.getUnique(this.parameterSpace.collectLeaves());
        int numLeaves = leaves.size();
        this.numValuesPerParam = new int[numLeaves];
        long total = 1L;
        for (int i = 0; i < numLeaves; i++) {
            long count = numGridPoints(leaves.get(i));
            if (count < 0) {
                throw new IllegalStateException("Invalid search: parameter " + i
                        + " has a negative number of grid values (" + count + ")");
            }
            // count <= max(numValues, discretizationCount), both ints, so the narrowing is safe.
            this.numValuesPerParam[i] = (int) count;
            total *= count;
            // Checked on every iteration: since each factor is <= Integer.MAX_VALUE and the running
            // product is < Integer.MAX_VALUE before multiplying, the long product can never overflow
            // (checking only after the loop would let a wrapped product slip through).
            if (total >= Integer.MAX_VALUE) {
                throw new IllegalStateException("Invalid search: cannot process search with at least " + total
                        + " candidates > Integer.MAX_VALUE");
            }
        }
        this.totalNumCandidates = (int) total;

        List<Integer> indices = new ArrayList<>(this.totalNumCandidates);
        for (int idx = 0; idx < this.totalNumCandidates; idx++) {
            indices.add(idx);
        }
        switch (this.mode) {
            case Sequential:
                break;
            case RandomOrder:
                Collections.shuffle(indices, new RandomAdaptor(this.rng));
                break;
            default:
                throw new IllegalStateException("Unknown grid search mode: " + this.mode);
        }
        this.order = new ConcurrentLinkedQueue<>(indices);
    }

    /**
     * Returns the number of grid points for one leaf parameter space (may be negative for a
     * malformed integer leaf whose max is below its min; the caller rejects that).
     */
    private long numGridPoints(Object leaf) {
        if (leaf instanceof DiscreteParameterSpace) {
            return ((DiscreteParameterSpace<?>) leaf).numValues();
        }
        if (leaf instanceof IntegerParameterSpace) {
            IntegerParameterSpace integerSpace = (IntegerParameterSpace) leaf;
            // Computed in long: max - min + 1 overflows int for very wide ranges.
            long rangeSize = (long) integerSpace.getMax() - integerSpace.getMin() + 1L;
            return Math.min(rangeSize, (long) this.discretizationCount);
        }
        return this.discretizationCount;
    }

    /** Returns {@code true} while there are grid points that have not been handed out yet. */
    @Override
    public boolean hasMoreCandidates() {
        return !this.order.isEmpty();
    }

    /**
     * Takes the next flat index from the order queue and builds its candidate.
     *
     * @throws java.util.NoSuchElementException if all grid points have already been handed out
     */
    @Override
    public Candidate<T> getCandidate() {
        int index = this.order.remove();
        double[] values = indexToValues(this.numValuesPerParam, index, this.totalNumCandidates);
        return new Candidate<>(this.parameterSpace.getValue(values), this.candidateCounter.getAndIncrement(), values);
    }

    /**
     * Decodes a flat candidate index into normalized per-parameter values.
     *
     * <p>The index is read as a mixed-radix number in which the LAST parameter is the most
     * significant digit. Digit {@code k} of parameter {@code p} becomes {@code k / (n_p - 1)},
     * a value in [0, 1]; parameters with {@code n_p <= 1} map to 0.
     *
     * @param numValuesPerParam number of grid points per parameter (entries must be non-zero)
     * @param candidateIdx      flat index in {@code [0, product)}
     * @param product           product of all entries of {@code numValuesPerParam}
     * @return one normalized value per parameter
     * @throws ArithmeticException if an entry of {@code numValuesPerParam} is zero
     */
    public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) {
        int stride = product;
        int remainder = candidateIdx;
        double[] values = new double[numValuesPerParam.length];
        for (int p = numValuesPerParam.length - 1; p >= 0; p--) {
            stride /= numValuesPerParam[p];
            int digit = remainder / stride;
            remainder %= stride;
            int n = numValuesPerParam[p];
            // Cast before dividing: integer division would collapse every interior grid point to 0.
            values[p] = n <= 1 ? 0.0d : ((double) digit) / (n - 1);
        }
        return values;
    }

    @Override
    public String toString() {
        return "GridSearchCandidateGenerator(mode=" + this.mode + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof GridSearchCandidateGenerator)) {
            return false;
        }
        GridSearchCandidateGenerator<?> other = (GridSearchCandidateGenerator<?>) obj;
        if (!other.canEqual(this) || this.discretizationCount != other.discretizationCount) {
            return false;
        }
        // Enum constants are singletons, so reference comparison is null-safe equality.
        if (this.mode != other.mode) {
            return false;
        }
        return Arrays.equals(this.numValuesPerParam, other.numValuesPerParam)
                && this.totalNumCandidates == other.totalNumCandidates;
    }

    /** Symmetry hook for subclasses overriding {@link #equals(Object)}. */
    protected boolean canEqual(Object obj) {
        return obj instanceof GridSearchCandidateGenerator;
    }

    @Override
    public int hashCode() {
        int result = 59 + this.discretizationCount;
        result = result * 59 + (this.mode == null ? 43 : this.mode.hashCode());
        result = result * 59 + Arrays.hashCode(this.numValuesPerParam);
        result = result * 59 + this.totalNumCandidates;
        return result;
    }
}
