package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.class */
/**
 * Runtime implementation of a 2D convolution layer.
 *
 * <p>The built-in forward pass lowers the input with im2col and then performs a single matrix
 * multiply; the built-in backward pass reuses the im2col matrix for the weight gradient and uses
 * col2im for the input gradient. If the optional {@code CudnnConvolutionHelper} class can be loaded
 * reflectively and reports itself as supported, forward, backward and activation are first
 * delegated to it, falling back to the built-in code whenever the helper returns {@code null}.
 *
 * <p>Array layouts assumed below (implied by the indexing and by the validation messages):
 * input/epsilon {@code [miniBatch, inDepth, inH, inW]}, weights {@code [outDepth, inDepth, kH, kW]}.
 */
public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.ConvolutionLayer> {
    protected static final Logger log = LoggerFactory.getLogger(ConvolutionLayer.class);

    /** im2col matrix from the last forward pass; cached when a cache workspace is active. */
    protected INDArray i2d;
    /** Optional accelerated implementation; {@code null} when unavailable or unsupported. */
    protected ConvolutionHelper helper = null;
    /** Convolution mode read from the layer configuration. */
    protected ConvolutionMode convolutionMode;

    /**
     * Creates the layer from its configuration.
     *
     * @param configuration network configuration whose layer is a convolution layer configuration
     */
    public ConvolutionLayer(NeuralNetConfiguration configuration) {
        super(configuration);
        initializeHelper();
        this.convolutionMode = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf().getLayer()).getConvolutionMode();
    }

    /**
     * Creates the layer from its configuration with an initial input array.
     *
     * @param configuration network configuration whose layer is a convolution layer configuration
     * @param input         initial input for the layer
     */
    public ConvolutionLayer(NeuralNetConfiguration configuration, INDArray input) {
        super(configuration, input);
        initializeHelper();
        // Must be set here as well: every forward/backward path reads convolutionMode, and leaving
        // it null silently disabled Same-mode handling and passed a null mode to ConvolutionUtils.
        this.convolutionMode = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf().getLayer()).getConvolutionMode();
    }

    /**
     * Tries to load the cuDNN helper reflectively. Leaves {@code helper} null when the class is
     * absent (silently), when it reports itself unsupported, or when anything during loading or
     * the support check throws (logged as a warning).
     */
    void initializeHelper() {
        try {
            helper = Class.forName("org.deeplearning4j.nn.layers.convolution.CudnnConvolutionHelper")
                    .asSubclass(ConvolutionHelper.class).newInstance();
            if (helper.checkSupported()) {
                log.debug("CudnnConvolutionHelper successfully initialized");
            } else {
                helper = null;
            }
        } catch (Throwable t) {
            // Never keep a helper whose initialization or support check failed part-way.
            helper = null;
            // A missing class just means the cuDNN backend is not on the classpath.
            if (!(t instanceof ClassNotFoundException)) {
                log.warn("Could not initialize CudnnConvolutionHelper", t);
            }
        }
    }

    /**
     * Sum over all parameters of {@code 0.5 * l2 * ||param||_2^2}, for parameters with a positive
     * L2 coefficient.
     *
     * @return the L2 penalty, or 0 when regularization is disabled
     */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        if (!conf.isUseRegularization()) {
            return 0.0;
        }
        double l2Sum = 0.0;
        for (Map.Entry<String, INDArray> entry : paramTable().entrySet()) {
            double l2 = conf.getL2ByParam(entry.getKey());
            if (l2 > 0.0) {
                double norm2 = getParam(entry.getKey()).norm2Number().doubleValue();
                l2Sum += 0.5 * l2 * norm2 * norm2;
            }
        }
        return l2Sum;
    }

    /**
     * Sum over all parameters of {@code l1 * ||param||_1}, for parameters with a positive L1
     * coefficient.
     *
     * @return the L1 penalty, or 0 when regularization is disabled
     */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        if (!conf.isUseRegularization()) {
            return 0.0;
        }
        double l1Sum = 0.0;
        for (Map.Entry<String, INDArray> entry : paramTable().entrySet()) {
            double l1 = conf.getL1ByParam(entry.getKey());
            if (l1 > 0.0) {
                l1Sum += l1 * getParam(entry.getKey()).norm1Number().doubleValue();
            }
        }
        return l1Sum;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.CONVOLUTIONAL;
    }

    /**
     * Backpropagates {@code epsilon} through the activation function and the convolution.
     *
     * <p>The weight and bias gradients are written into the layer's gradient views ("W", "b").
     * The helper is tried first; on fallback, the weight gradient is computed as a GEMM between the
     * im2col matrix and the flattened delta, and the input gradient via {@code W^T * delta}
     * followed by col2im.
     *
     * @param epsilon gradient of the loss w.r.t. this layer's activations
     * @return the gradient and the epsilon for the previous layer
     *         ({@code [miniBatch, inDepth, inH, inW]})
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray weights = getParam("W");

        int miniBatch = input.size(0);
        int inH = input.size(2);
        int inW = input.size(3);

        int outDepth = weights.size(0);
        int inDepth = weights.size(1);
        int kH = weights.size(2);
        int kW = weights.size(3);

        int[] kernel = layerConf().getKernelSize();
        int[] strides = layerConf().getStride();
        Pair<int[], int[]> sizeAndPadding = outputSizeAndPadding(kernel, strides);
        int[] outSize = sizeAndPadding.getFirst();
        int[] pad = sizeAndPadding.getSecond();
        int outH = outSize[0];
        int outW = outSize[1];

        INDArray biasGradView = gradientViews.get("b");
        INDArray weightGradView = gradientViews.get("W");
        // No-copy 2d view of the weight gradient; the GEMM below writes into it, which fills
        // weightGradView in place.
        INDArray weightGradView2df = Shape.newShapeNoCopy(weightGradView,
                new int[] {outDepth, inDepth * kH * kW}, false).transpose();

        IActivation afn = layerConf().getActivationFn();
        Pair<INDArray, INDArray> preOut = preOutput4d(true, true);
        INDArray delta = (INDArray) afn.backprop(preOut.getFirst(), epsilon).getFirst();

        if (helper != null) {
            Pair<Gradient, INDArray> helperResult = helper.backpropGradient(input, weights, delta, kernel,
                    strides, pad, biasGradView, weightGradView, afn, layerConf().getCudnnAlgoMode(),
                    layerConf().getCudnnBwdFilterAlgo(), layerConf().getCudnnBwdDataAlgo(), convolutionMode);
            if (helperResult != null) {
                return helperResult;
            }
        }

        // [miniBatch, outDepth, outH, outW] -> [outDepth, miniBatch * outH * outW]
        INDArray delta2d = delta.permute(new int[] {1, 0, 2, 3})
                .reshape('c', new int[] {outDepth, miniBatch * outH * outW});

        // Reuse the im2col matrix from the forward pass when available; otherwise rebuild it.
        INDArray im2col2d = preOut.getSecond();
        if (im2col2d == null) {
            INDArray col = Nd4j.createUninitialized(new int[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c');
            Convolution.im2col(input, kH, kW, strides[0], strides[1], pad[0], pad[1],
                    convolutionMode == ConvolutionMode.Same, col.permute(new int[] {0, 3, 4, 5, 1, 2}));
            im2col2d = col.reshape('c', miniBatch * outH * outW, inDepth * kH * kW);
        }

        // Weight gradient: alpha = 1, beta = 0 (overwrite the view).
        Nd4j.gemm(im2col2d, delta2d, weightGradView2df, true, true, 1.0, 0.0);

        // Input gradient in column form: W^T * delta, reshaped to 6d and permuted for col2im.
        INDArray w2d = weights.permute(new int[] {3, 2, 1, 0}).reshape('f', inDepth * kH * kW, outDepth);
        INDArray epsNext2d = w2d.mmul(delta2d);
        INDArray eps6d = Shape.newShapeNoCopy(epsNext2d,
                new int[] {kW, kH, inDepth, outW, outH, miniBatch}, true)
                .permute(new int[] {5, 2, 1, 0, 4, 3});

        // The returned epsilon is allocated inside a borrowed external-workspace scope when one
        // applies (presumably so it lives in the caller-visible workspace — confirm).
        INDArray epsNextOrig;
        try (MemoryWorkspace scope = borrowExternalWorkspaceIfNeeded()) {
            epsNextOrig = Nd4j.create(new int[] {inDepth, miniBatch, inH, inW}, 'c');
        }
        INDArray epsNext = epsNextOrig.permute(new int[] {1, 0, 2, 3});
        Convolution.col2im(eps6d, epsNext, strides[0], strides[1], pad[0], pad[1], inH, inW);

        DefaultGradient gradient = new DefaultGradient();
        delta2d.sum(biasGradView, new int[] {1});
        gradient.setGradientFor("b", biasGradView);
        gradient.setGradientFor("W", weightGradView, 'c');
        return new Pair<>(gradient, epsNext);
    }

    /**
     * Returns the 4d pre-activations and, when requested, the im2col matrix. Delegates to
     * {@link #preOutput(boolean, boolean)}; presumably an override hook for subclasses whose
     * regular pre-output is not 4d — confirm.
     */
    protected Pair<INDArray, INDArray> preOutput4d(boolean training, boolean forBackprop) {
        return preOutput(training, forBackprop);
    }

    @Override
    public INDArray preOutput(boolean training) {
        return preOutput(training, false).getFirst();
    }

    /**
     * Computes the pre-activation output {@code W * x + b}.
     *
     * @param training    whether this is a training pass (enables drop-connect when configured)
     * @param forBackprop whether the im2col matrix should be returned (and cached values may be
     *                    reused)
     * @return pre-activations {@code [miniBatch, outDepth, outH, outW]} and, if {@code forBackprop}
     *         and the built-in path was used, the 2d im2col matrix; otherwise {@code null} second
     * @throws DL4JInvalidInputException if the input is not rank 4 or its depth does not match the
     *                                   weights
     */
    public Pair<INDArray, INDArray> preOutput(boolean training, boolean forBackprop) {
        INDArray weights = getParam("W");
        INDArray bias = getParam("b");

        if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0.0) {
            weights = Dropout.applyDropConnect(this, "W");
        }

        if (input.rank() != 4) {
            throw new DL4JInvalidInputException("Got rank " + input.rank()
                    + " array as input to ConvolutionLayer (layer name = " + layerNameOrDefault()
                    + ", layer index = " + index + ") with shape " + Arrays.toString(input.shape())
                    + ". Expected rank 4 array with shape [minibatchSize, layerInputDepth, inputHeight, inputWidth]."
                    + (input.rank() == 2
                            ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
                            : "")
                    + " " + layerId());
        }

        int miniBatch = input.size(0);
        int outDepth = weights.size(0);
        int inDepth = weights.size(1);

        if (input.size(1) != inDepth) {
            throw new DL4JInvalidInputException("Cannot do forward pass in Convolution layer (layer name = "
                    + layerNameOrDefault() + ", layer index = " + index
                    + "): input array depth does not match CNN layer configuration (data input depth = "
                    + input.size(1) + ", [minibatch,inputDepth,height,width]=" + Arrays.toString(input.shape())
                    + "; expected input depth = " + inDepth + ") " + layerId());
        }

        int kH = weights.size(2);
        int kW = weights.size(3);

        int[] kernel = layerConf().getKernelSize();
        int[] strides = layerConf().getStride();
        Pair<int[], int[]> sizeAndPadding = outputSizeAndPadding(kernel, strides);
        int[] outSize = sizeAndPadding.getFirst();
        int[] pad = sizeAndPadding.getSecond();
        int outH = outSize[0];
        int outW = outSize[1];

        if (helper != null) {
            if (this.preOutput != null && forBackprop) {
                return new Pair<>(this.preOutput, null);
            }
            INDArray helperOut = helper.preOutput(input, weights, bias, kernel, strides, pad,
                    layerConf().getCudnnAlgoMode(), layerConf().getCudnnFwdAlgo(), convolutionMode);
            if (helperOut != null) {
                return new Pair<>(helperOut, null);
            }
        }

        if (this.preOutput != null && i2d != null && forBackprop) {
            return new Pair<>(this.preOutput, i2d);
        }

        // im2col into a 6d buffer, then view it as [miniBatch*outH*outW, inDepth*kH*kW].
        INDArray col = Nd4j.createUninitialized(new int[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c');
        Convolution.im2col(input, kH, kW, strides[0], strides[1], pad[0], pad[1],
                convolutionMode == ConvolutionMode.Same, col.permute(new int[] {0, 3, 4, 5, 1, 2}));
        INDArray im2col2d = Shape.newShapeNoCopy(col,
                new int[] {miniBatch * outH * outW, inDepth * kH * kW}, false);

        INDArray w2d = weights.permute(new int[] {3, 2, 1, 0}).reshape('f', kW * kH * inDepth, outDepth);

        INDArray z;
        try (MemoryWorkspace scope = borrowExternalWorkspaceIfNeeded()) {
            z = im2col2d.mmul(w2d);
        }
        z.addiRowVector(bias);

        INDArray z4d = Shape.newShapeNoCopy(z, new int[] {outW, outH, miniBatch, outDepth}, true)
                .permute(new int[] {2, 3, 1, 0});

        if (cacheMode != CacheMode.NONE
                && Nd4j.getWorkspaceManager().checkIfWorkspaceExists(ComputationGraph.workspaceCache)) {
            try (MemoryWorkspace scope = Nd4j.getWorkspaceManager()
                    .getWorkspaceForCurrentThread(ComputationGraph.workspaceCache).notifyScopeBorrowed()) {
                i2d = im2col2d.unsafeDuplication();
            }
        }

        return new Pair<>(z4d, forBackprop ? im2col2d : null);
    }

    /**
     * Runs the forward pass: pre-output followed by the activation function (via the helper when
     * it returns a result). During training with caching enabled, a duplicate of the pre-output is
     * stored in the cache workspace.
     *
     * @param training whether this is a training pass
     * @return the layer activations
     * @throws IllegalArgumentException if no input has been set
     */
    @Override
    public INDArray activate(boolean training) {
        if (input == null) {
            throw new IllegalArgumentException("Cannot perform forward pass with null input " + layerId());
        }
        if (cacheMode == null) {
            cacheMode = CacheMode.NONE;
        }

        applyDropOutIfNecessary(training);

        INDArray z = preOutput(training);

        if (training && cacheMode != CacheMode.NONE
                && Nd4j.getWorkspaceManager().checkIfWorkspaceExists(ComputationGraph.workspaceCache)) {
            try (MemoryWorkspace scope = Nd4j.getWorkspaceManager()
                    .getWorkspaceForCurrentThread(ComputationGraph.workspaceCache).notifyScopeBorrowed()) {
                preOutput = z.unsafeDuplication();
            }
        }

        IActivation afn = layerConf().getActivationFn();
        if (helper != null) {
            INDArray helperActivation = helper.activate(z, afn);
            if (helperActivation != null) {
                return helperActivation;
            }
        }
        return afn.getActivation(z, training);
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not supported - " + layerId());
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        throw new UnsupportedOperationException("Not supported " + layerId());
    }

    /** Intentionally a no-op: this layer is not a pretrain layer (see {@link #isPretrainLayer()}). */
    @Override
    public void fit(INDArray input) {
    }

    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException(layerId());
    }

    /** Returns all parameters flattened in 'c' order. */
    @Override
    public INDArray params() {
        return Nd4j.toFlattened('c', params.values());
    }

    /** Sets the parameters from a flattened array, interpreted in 'c' order. */
    @Override
    public void setParams(INDArray params) {
        setParams(params, 'c');
    }

    /**
     * Computes the spatial output size and the padding for the current input. In Same mode the
     * output size is computed first and the top/left padding derived from it; otherwise the
     * configured padding is used and the output size follows from it.
     *
     * @return pair of (output size, padding); index 0/1 correspond to input dimensions 2/3
     */
    private Pair<int[], int[]> outputSizeAndPadding(int[] kernel, int[] strides) {
        int[] outSize;
        int[] pad;
        if (convolutionMode == ConvolutionMode.Same) {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode);
            pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize,
                    new int[] {input.size(2), input.size(3)}, kernel, strides);
        } else {
            pad = layerConf().getPadding();
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode);
        }
        return new Pair<>(outSize, pad);
    }

    /**
     * Opens a borrowed scope of the external workspace when that workspace exists and is not
     * already the current one; returns {@code null} otherwise. Intended as a try-with-resources
     * resource (a {@code null} resource is simply not closed).
     */
    private MemoryWorkspace borrowExternalWorkspaceIfNeeded() {
        if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(ComputationGraph.workspaceExternal)) {
            return null;
        }
        MemoryWorkspace external =
                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceExternal);
        if (Nd4j.getMemoryManager().getCurrentWorkspace() == external) {
            return null;
        }
        return external.notifyScopeBorrowed();
    }

    /** Returns the configured layer name, or {@code "(not named)"} when none is set. */
    private String layerNameOrDefault() {
        String layerName = conf.getLayer().getLayerName();
        return layerName == null ? "(not named)" : layerName;
    }
}
