package org.deeplearning4j.nn.gradient;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/gradient/DefaultGradient.class */
/**
 * Default {@link Gradient} implementation backed by an insertion-ordered map of per-variable
 * gradient arrays, with optional per-variable flattening orders and an optional flattened
 * gradient array.
 *
 * <p>The flattened gradient is either supplied up front (via the constructor or
 * {@link #setFlattenedGradient(INDArray)}) or computed lazily on the first call to
 * {@link #gradient()} and then cached. The cache is NOT invalidated by
 * {@link #setGradientFor(String, INDArray)} or {@link #clear()}.
 */
public class DefaultGradient implements Gradient {
    /**
     * Flattening order applied to variables without an explicit order, and used for the final
     * concatenation of all per-variable pieces.
     */
    public static final char DEFAULT_FLATTENING_ORDER = 'f';

    /** Per-variable gradients; iteration follows insertion order. */
    private final Map<String, INDArray> gradients = new LinkedHashMap<>();
    /** Per-variable flattening orders; created lazily and {@code null} until a non-null order is set. */
    private Map<String, Character> flatteningOrders;
    /** Supplied or cached flattened gradient; {@code null} until set or first computed. */
    private INDArray flattenedGradient;

    /** Creates an empty gradient with no flattened gradient. */
    public DefaultGradient() {
    }

    /**
     * Creates an empty gradient whose flattened gradient is the given array.
     *
     * @param flattenedGradient array returned by {@link #gradient()}; if {@code null}, it is
     *     computed from the per-variable gradients on first use
     */
    public DefaultGradient(INDArray flattenedGradient) {
        this.flattenedGradient = flattenedGradient;
    }

    /**
     * Returns the live, mutable map of per-variable gradients (not a copy).
     *
     * @return the internal variable-name to gradient map
     */
    @Override
    public Map<String, INDArray> gradientForVariable() {
        return gradients;
    }

    /**
     * Flattens the gradients of the named variables into a single array.
     *
     * <p>Pieces are concatenated in the order given by {@code variables}, and names without a
     * stored gradient are skipped. A variable whose flattening order was set to something other
     * than {@link #DEFAULT_FLATTENING_ORDER} is first flattened in its own order. The pieces are
     * then concatenated in {@link #DEFAULT_FLATTENING_ORDER} order. Unlike {@link #gradient()},
     * the result is not cached.
     *
     * @param variables variable names to include, in output order
     * @return a newly flattened array
     */
    @Override
    public INDArray gradient(List<String> variables) {
        List<INDArray> toFlatten = new ArrayList<>();
        for (String variable : variables) {
            if (gradients.containsKey(variable)) {
                toFlatten.add(inFlatteningOrder(variable, gradients.get(variable)));
            }
        }
        return Nd4j.toFlattened(DEFAULT_FLATTENING_ORDER, toFlatten);
    }

    /**
     * Prepares one variable's gradient for concatenation. The array is returned unchanged when no
     * explicit order is recorded or the order equals the default. Otherwise it is flattened in the
     * variable's own order.
     */
    private INDArray inFlatteningOrder(String variable, INDArray array) {
        if (flatteningOrders == null) {
            return array;
        }
        Character order = flatteningOrders.get(variable);
        if (order == null || order.charValue() == DEFAULT_FLATTENING_ORDER) {
            return array;
        }
        return Nd4j.toFlattened(order.charValue(), new INDArray[] {array});
    }

    /** Flattens every stored gradient, in map insertion order, and caches the result. */
    private void flattenGradient() {
        List<INDArray> toFlatten = new ArrayList<>(gradients.size());
        for (Map.Entry<String, INDArray> entry : gradients.entrySet()) {
            toFlatten.add(inFlatteningOrder(entry.getKey(), entry.getValue()));
        }
        flattenedGradient = Nd4j.toFlattened(DEFAULT_FLATTENING_ORDER, toFlatten);
    }

    /**
     * Returns the flattened gradient. This is the supplied or previously cached array if present;
     * otherwise it is computed from all stored gradients and cached.
     *
     * @return the flattened gradient
     */
    @Override
    public INDArray gradient() {
        if (flattenedGradient == null) {
            flattenGradient();
        }
        return flattenedGradient;
    }

    /**
     * Removes all per-variable gradients.
     *
     * <p>NOTE(review): flattening orders and any supplied/cached flattened gradient are left
     * untouched, so {@link #gradient()} keeps returning the previous flattened array if one
     * exists. Confirm callers rely on this before changing it.
     */
    @Override
    public void clear() {
        gradients.clear();
    }

    /**
     * Returns the gradient stored for a variable.
     *
     * @param variable variable name
     * @return the gradient, or {@code null} if none is stored
     */
    @Override
    public INDArray getGradientFor(String variable) {
        return gradients.get(variable);
    }

    /**
     * Stores (or replaces) the gradient for a variable. Any recorded flattening order for the
     * variable is kept, and the flattened-gradient cache is not invalidated.
     *
     * @param variable variable name
     * @param gradient gradient array
     * @return the previously stored gradient, or {@code null} if there was none
     */
    @Override
    public INDArray setGradientFor(String variable, INDArray gradient) {
        return gradients.put(variable, gradient);
    }

    /**
     * Stores the gradient for a variable and, if {@code flatteningOrder} is non-null, records it
     * as that variable's flattening order. A {@code null} order leaves any previously recorded
     * order unchanged.
     *
     * @param variable variable name
     * @param gradient gradient array
     * @param flatteningOrder flattening order for this variable, or {@code null}
     * @return the previously stored gradient, or {@code null} if there was none
     */
    @Override
    public INDArray setGradientFor(String variable, INDArray gradient, Character flatteningOrder) {
        INDArray previous = setGradientFor(variable, gradient);
        if (flatteningOrder != null) {
            if (flatteningOrders == null) {
                flatteningOrders = new LinkedHashMap<>();
            }
            flatteningOrders.put(variable, flatteningOrder);
        }
        return previous;
    }

    /**
     * Returns the explicitly recorded flattening order for a variable.
     *
     * @param variable variable name
     * @return the recorded order, or {@code null} if none was set
     */
    @Override
    public Character flatteningOrderForVariable(String variable) {
        return flatteningOrders == null ? null : flatteningOrders.get(variable);
    }

    @Override
    public String toString() {
        return "DefaultGradient{gradients=" + gradients + (flatteningOrders != null ? flatteningOrders : "") + '}';
    }

    /**
     * Replaces the flattened gradient returned by {@link #gradient()}. Passing {@code null} makes
     * the next {@link #gradient()} call recompute it from the stored gradients.
     *
     * @param flattenedGradient new flattened gradient, or {@code null}
     */
    public void setFlattenedGradient(INDArray flattenedGradient) {
        this.flattenedGradient = flattenedGradient;
    }
}
