package org.nd4j.parameterserver.client;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.Unirest;
import io.aeron.Aeron;
import java.beans.ConstructorProperties;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.nd4j.aeron.ipc.AeronConnectionInformation;
import org.nd4j.aeron.ipc.AeronNDArrayPublisher;
import org.nd4j.aeron.ipc.AeronNDArraySubscriber;
import org.nd4j.aeron.ipc.AeronUtil;
import org.nd4j.aeron.ipc.NDArrayCallback;
import org.nd4j.aeron.ipc.NDArrayMessage;
import org.nd4j.aeron.ipc.response.HostPortPublisher;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.model.MasterStatus;
import org.nd4j.parameterserver.model.ServerTypeJson;
import org.nd4j.parameterserver.model.SubscriberState;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client for a parameter server reached over Aeron, plus HTTP status queries against a
 * master-status endpoint.
 *
 * <p>Arrays are pushed to {@code ndarraySendUrl} and requested from {@code ndarrayRetrieveUrl}.
 * Both are parsed as {@code host:port:streamId}. Arrays sent back to this client arrive through
 * the {@link NDArrayCallback} methods. They are delivered via an {@link AeronNDArraySubscriber},
 * which this client starts lazily on {@code subscriberHost:subscriberPort} /
 * {@code subscriberStream} when none was supplied.
 *
 * <p>NOTE(review): fields are plain (non-volatile) and are read/written both by the public API
 * and by the callback methods. Only the array hand-off goes through an {@link AtomicReference}.
 * Thread-safety beyond that is not established here.
 */
public class ParameterServerClient implements NDArrayCallback {
    private static final Logger log = LoggerFactory.getLogger(ParameterServerClient.class);

    /** Poll interval used while waiting for readiness or for an array to arrive. */
    private static final long POLL_INTERVAL_MS = 1000L;

    /** Target of {@link #pushNDArrayMessage}, in {@code host:port:streamId} form. */
    private String ndarraySendUrl;
    /** Endpoint that {@link #getArray()} sends this client's connection URL to ({@code host:port:streamId}). */
    private String ndarrayRetrieveUrl;
    /** Subscriber delivering arrays to this callback; started lazily when {@code null}. */
    private AeronNDArraySubscriber subscriber;
    /** Host the lazily started subscriber listens on. */
    private String subscriberHost;
    /** Port the lazily started subscriber listens on. */
    private int subscriberPort;
    /** Aeron stream id the lazily started subscriber listens on. */
    private int subscriberStream;
    /** Holder for the most recently received array; reset to {@link #none} by {@link #getArray()}. */
    private AtomicReference<INDArray> arr;
    /** Identity sentinel meaning "no array received yet"; may be {@code null}. */
    private INDArray none;
    /** Running flag handed to the subscriber when this client starts it. */
    private AtomicBoolean running;
    /** Host of the HTTP master-status endpoint. */
    private String masterStatusHost;
    /** Port of the HTTP master-status endpoint. */
    private int masterStatusPort;
    /** JSON mapper for status responses; created lazily when {@code null}. */
    private ObjectMapper objectMapper;
    /** Aeron instance used for publishing and subscribing. */
    private Aeron aeron;
    /** Whether arrays published by {@link #pushNDArrayMessage} are compressed. */
    private boolean compressArray;

    /** Fluent builder; every value not set explicitly is passed to the constructor as null/0/false. */
    public static class ParameterServerClientBuilder {
        private String ndarraySendUrl;
        private String ndarrayRetrieveUrl;
        private AeronNDArraySubscriber subscriber;
        private String subscriberHost;
        private int subscriberPort;
        private int subscriberStream;
        private AtomicReference<INDArray> arr;
        private INDArray none;
        private AtomicBoolean running;
        private String masterStatusHost;
        private int masterStatusPort;
        private ObjectMapper objectMapper;
        private Aeron aeron;
        private boolean compressArray;

        ParameterServerClientBuilder() {
        }

        public ParameterServerClientBuilder ndarraySendUrl(String ndarraySendUrl) {
            this.ndarraySendUrl = ndarraySendUrl;
            return this;
        }

        public ParameterServerClientBuilder ndarrayRetrieveUrl(String ndarrayRetrieveUrl) {
            this.ndarrayRetrieveUrl = ndarrayRetrieveUrl;
            return this;
        }

        public ParameterServerClientBuilder subscriber(AeronNDArraySubscriber subscriber) {
            this.subscriber = subscriber;
            return this;
        }

        public ParameterServerClientBuilder subscriberHost(String subscriberHost) {
            this.subscriberHost = subscriberHost;
            return this;
        }

        public ParameterServerClientBuilder subscriberPort(int subscriberPort) {
            this.subscriberPort = subscriberPort;
            return this;
        }

        public ParameterServerClientBuilder subscriberStream(int subscriberStream) {
            this.subscriberStream = subscriberStream;
            return this;
        }

        public ParameterServerClientBuilder arr(AtomicReference<INDArray> arr) {
            this.arr = arr;
            return this;
        }

        public ParameterServerClientBuilder none(INDArray none) {
            this.none = none;
            return this;
        }

        public ParameterServerClientBuilder running(AtomicBoolean running) {
            this.running = running;
            return this;
        }

        public ParameterServerClientBuilder masterStatusHost(String masterStatusHost) {
            this.masterStatusHost = masterStatusHost;
            return this;
        }

        public ParameterServerClientBuilder masterStatusPort(int masterStatusPort) {
            this.masterStatusPort = masterStatusPort;
            return this;
        }

        public ParameterServerClientBuilder objectMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
            return this;
        }

        public ParameterServerClientBuilder aeron(Aeron aeron) {
            this.aeron = aeron;
            return this;
        }

        public ParameterServerClientBuilder compressArray(boolean compressArray) {
            this.compressArray = compressArray;
            return this;
        }

        /** Creates a client from the values collected so far. */
        public ParameterServerClient build() {
            return new ParameterServerClient(ndarraySendUrl, ndarrayRetrieveUrl, subscriber,
                    subscriberHost, subscriberPort, subscriberStream, arr, none, running,
                    masterStatusHost, masterStatusPort, objectMapper, aeron, compressArray);
        }

        @Override
        public String toString() {
            return "ParameterServerClient.ParameterServerClientBuilder(ndarraySendUrl=" + this.ndarraySendUrl
                    + ", ndarrayRetrieveUrl=" + this.ndarrayRetrieveUrl
                    + ", subscriber=" + this.subscriber
                    + ", subscriberHost=" + this.subscriberHost
                    + ", subscriberPort=" + this.subscriberPort
                    + ", subscriberStream=" + this.subscriberStream
                    + ", arr=" + this.arr
                    + ", none=" + this.none
                    + ", running=" + this.running
                    + ", masterStatusHost=" + this.masterStatusHost
                    + ", masterStatusPort=" + this.masterStatusPort
                    + ", objectMapper=" + this.objectMapper
                    + ", aeron=" + this.aeron
                    + ", compressArray=" + this.compressArray + ")";
        }
    }

    /**
     * Queries the master-status endpoint for the responder count.
     *
     * <p>Verifies via {@code /type} that the endpoint reports type {@code "master"}, then reads
     * {@code /started}.
     *
     * @return {@link MasterStatus#getResponderN()} from {@code /started}, or {@code 0} if any
     *     request, parse or type check fails (the failure is logged)
     */
    public int arraysSentToResponder() {
        try {
            return fetchMasterStatus().getResponderN();
        } catch (Exception e) {
            log.error("Unable to query master status at {}:{}", masterStatusHost, masterStatusPort, e);
            return 0;
        }
    }

    /**
     * Blocks, polling {@link #isReadyForNext()} once per second, until it returns {@code true}.
     *
     * <p>If the calling thread is interrupted, the interrupt flag is restored and the method
     * returns early. This avoids spinning on an always-failing sleep.
     */
    public void blockTillReady() {
        while (!isReadyForNext()) {
            try {
                Thread.sleep(POLL_INTERVAL_MS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /**
     * Asks the master-status endpoint whether the subscriber for this client's send stream is ready.
     *
     * <p>The stream id is the third {@code :}-separated component of {@code ndarraySendUrl}; the
     * request goes to {@code /state/<streamId>}.
     *
     * @return {@link SubscriberState#isReady()}, or {@code false} on any failure (logged)
     */
    public boolean isReadyForNext() {
        try {
            int streamId = Integer.parseInt(ndarraySendUrl.split(":")[2]);
            SubscriberState state = fetchJson(masterStatusUrl("state/" + streamId), SubscriberState.class);
            return state.isReady();
        } catch (Exception e) {
            log.error("Unable to query subscriber state at {}:{}", masterStatusHost, masterStatusPort, e);
            return false;
        }
    }

    /**
     * Reports whether the master has started.
     *
     * <p>Verifies via {@code /type} that the endpoint reports type {@code "master"}, then reads
     * {@code /started}.
     *
     * @return {@link MasterStatus#started()}, or {@code false} on any failure (logged)
     */
    public boolean masterStarted() {
        try {
            return fetchMasterStatus().started();
        } catch (Exception e) {
            log.error("Unable to query master status at {}:{}", masterStatusHost, masterStatusPort, e);
            return false;
        }
    }

    /**
     * Publishes a message to {@code ndarraySendUrl}, starting the local subscriber first if needed.
     *
     * @param message message to publish; compressed when {@code compressArray} is set
     * @throws NumberFormatException if the port or stream id in {@code ndarraySendUrl} is not numeric
     * @throws RuntimeException wrapping any failure while building, using or closing the publisher
     */
    public void pushNDArrayMessage(NDArrayMessage message) {
        ensureSubscriberStarted();
        String[] parts = ndarraySendUrl.split(":");
        int port = Integer.parseInt(parts[1]);
        int streamId = Integer.parseInt(parts[2]);
        String channel = AeronUtil.aeronChannel(parts[0], port);
        log.debug("Parameter server client publishing to {}", ndarraySendUrl);
        // try-with-resources guarantees the publisher is closed even when publish() throws.
        try (AeronNDArrayPublisher publisher = AeronNDArrayPublisher.builder()
                .streamId(streamId)
                .compress(isCompressArray())
                .aeron(aeron)
                .channel(channel)
                .build()) {
            publisher.publish(message);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Publishes {@code array} as a whole-array update; see {@link #pushNDArrayMessage}. */
    public void pushNDArray(INDArray array) {
        pushNDArrayMessage(NDArrayMessage.wholeArrayUpdate(array));
    }

    /** Returns the Aeron connection string of this client's subscriber (host, port, stream). */
    public String connectionUrl() {
        return AeronConnectionInformation.of(subscriberHost, subscriberPort, subscriberStream).toString();
    }

    /**
     * Requests an array from the parameter server and waits for it.
     *
     * <p>Sends this client's connection URL to {@code ndarrayRetrieveUrl}, then polls once per
     * second until the held array differs (by identity) from {@link #none}. There is no timeout:
     * this blocks until an array arrives, an error occurs, or the thread is interrupted. After
     * reading, the held array is reset to {@link #none}.
     *
     * @return the received array, or the current held value (possibly {@link #none}) if sending
     *     failed or the wait was interrupted; failures are logged, not thrown
     */
    public INDArray getArray() {
        ensureSubscriberStarted();
        log.debug("Parameter server client retrieving url from {}", ndarrayRetrieveUrl);
        String[] parts = ndarrayRetrieveUrl.split(":");
        try (HostPortPublisher publisher = HostPortPublisher.builder()
                .channel(AeronUtil.aeronChannel(parts[0], Integer.parseInt(parts[1])))
                .aeron(aeron)
                .streamId(Integer.parseInt(parts[2]))
                .uriToSend(connectionUrl())
                .build()) {
            publisher.send();
            log.debug("Sent subscriber information {}", connectionUrl());
            while (arr.get() == none) {
                Thread.sleep(POLL_INTERVAL_MS);
                log.info("Waiting on array to be updated.");
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers instead of silently clearing it.
            Thread.currentThread().interrupt();
            log.warn("Interrupted while waiting for array from {}", ndarrayRetrieveUrl);
        } catch (Exception e) {
            log.error("Error with publishing", e);
        }
        INDArray received = arr.get();
        arr.set(none);
        return received;
    }

    /**
     * Dispatches an incoming message. Dimensions of exactly {@code [-1]} mark a whole-array
     * update; anything else is treated as a partial update.
     */
    public void onNDArrayMessage(NDArrayMessage message) {
        INDArray payload = message.getArr();
        int[] dimensions = message.getDimensions();
        boolean wholeArray = dimensions.length == 1 && dimensions[0] == -1;
        if (wholeArray) {
            onNDArray(payload);
        } else {
            onNDArrayPartial(payload, message.getIndex(), dimensions);
        }
    }

    /**
     * Assigns {@code iNDArray} into tensor {@code j} along {@code iArr} of the currently held array.
     *
     * <p>NOTE(review): this writes into whatever {@code arr} holds, including the {@link #none}
     * sentinel if no whole array has arrived yet; the index is narrowed to {@code int}.
     */
    public void onNDArrayPartial(INDArray iNDArray, long j, int... iArr) {
        arr.get().tensorAlongDimension((int) j, iArr).assign(iNDArray);
    }

    /** Replaces the held array with {@code iNDArray}. */
    public void onNDArray(INDArray iNDArray) {
        log.info("Received array");
        arr.set(iNDArray);
    }

    public static ParameterServerClientBuilder builder() {
        return new ParameterServerClientBuilder();
    }

    /**
     * Ensures the array holder exists, then starts the subscriber if none is set.
     *
     * <p>{@code arr} is created first so that callbacks from a freshly started subscriber never
     * see a {@code null} holder.
     */
    private void ensureSubscriberStarted() {
        if (arr == null) {
            arr = new AtomicReference<>(none);
        }
        if (subscriber == null) {
            running = new AtomicBoolean(true);
            subscriber = AeronNDArraySubscriber.startSubscriber(
                    aeron, subscriberHost, subscriberPort, this, subscriberStream, running);
            log.debug("Started parameter server client on {}", subscriber.connectionUrl());
        }
    }

    /** Returns the object mapper, creating it on first use. */
    private ObjectMapper mapper() {
        if (objectMapper == null) {
            objectMapper = new ObjectMapper();
        }
        return objectMapper;
    }

    /** Builds {@code http://<masterStatusHost>:<masterStatusPort>/<path>}. */
    private String masterStatusUrl(String path) {
        return String.format("http://%s:%d/%s", masterStatusHost, masterStatusPort, path);
    }

    /** Performs a GET on {@code url} and binds the JSON response body to {@code type}. */
    private <T> T fetchJson(String url, Class<T> type) throws Exception {
        JsonNode body = Unirest.get(url).asJson().getBody();
        return mapper().readValue(body.toString(), type);
    }

    /**
     * Checks that the status endpoint is a master and returns its {@code /started} payload.
     *
     * @throws IllegalStateException if {@code /type} reports anything other than {@code "master"}
     */
    private MasterStatus fetchMasterStatus() throws Exception {
        String type = fetchJson(masterStatusUrl("type"), ServerTypeJson.class).getType();
        if (!"master".equals(type)) {
            throw new IllegalStateException("Wrong type " + type);
        }
        // A single request: the previous implementation fetched /started twice and discarded one.
        return fetchJson(masterStatusUrl("started"), MasterStatus.class);
    }

    public String getNdarraySendUrl() {
        return this.ndarraySendUrl;
    }

    public String getNdarrayRetrieveUrl() {
        return this.ndarrayRetrieveUrl;
    }

    public AeronNDArraySubscriber getSubscriber() {
        return this.subscriber;
    }

    public String getSubscriberHost() {
        return this.subscriberHost;
    }

    public int getSubscriberPort() {
        return this.subscriberPort;
    }

    public int getSubscriberStream() {
        return this.subscriberStream;
    }

    public AtomicReference<INDArray> getArr() {
        return this.arr;
    }

    public INDArray getNone() {
        return this.none;
    }

    public AtomicBoolean getRunning() {
        return this.running;
    }

    public String getMasterStatusHost() {
        return this.masterStatusHost;
    }

    public int getMasterStatusPort() {
        return this.masterStatusPort;
    }

    public ObjectMapper getObjectMapper() {
        return this.objectMapper;
    }

    public Aeron getAeron() {
        return this.aeron;
    }

    public boolean isCompressArray() {
        return this.compressArray;
    }

    public void setNdarraySendUrl(String ndarraySendUrl) {
        this.ndarraySendUrl = ndarraySendUrl;
    }

    public void setNdarrayRetrieveUrl(String ndarrayRetrieveUrl) {
        this.ndarrayRetrieveUrl = ndarrayRetrieveUrl;
    }

    public void setSubscriber(AeronNDArraySubscriber subscriber) {
        this.subscriber = subscriber;
    }

    public void setSubscriberHost(String subscriberHost) {
        this.subscriberHost = subscriberHost;
    }

    public void setSubscriberPort(int subscriberPort) {
        this.subscriberPort = subscriberPort;
    }

    public void setSubscriberStream(int subscriberStream) {
        this.subscriberStream = subscriberStream;
    }

    public void setArr(AtomicReference<INDArray> arr) {
        this.arr = arr;
    }

    public void setNone(INDArray none) {
        this.none = none;
    }

    public void setRunning(AtomicBoolean running) {
        this.running = running;
    }

    public void setMasterStatusHost(String masterStatusHost) {
        this.masterStatusHost = masterStatusHost;
    }

    public void setMasterStatusPort(int masterStatusPort) {
        this.masterStatusPort = masterStatusPort;
    }

    public void setObjectMapper(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
    }

    public void setAeron(Aeron aeron) {
        this.aeron = aeron;
    }

    public void setCompressArray(boolean compressArray) {
        this.compressArray = compressArray;
    }

    /** Field-by-field equality over every field, honouring {@link #canEqual} for subclasses. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParameterServerClient)) {
            return false;
        }
        ParameterServerClient other = (ParameterServerClient) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getNdarraySendUrl(), other.getNdarraySendUrl())
                && Objects.equals(getNdarrayRetrieveUrl(), other.getNdarrayRetrieveUrl())
                && Objects.equals(getSubscriber(), other.getSubscriber())
                && Objects.equals(getSubscriberHost(), other.getSubscriberHost())
                && getSubscriberPort() == other.getSubscriberPort()
                && getSubscriberStream() == other.getSubscriberStream()
                && Objects.equals(getArr(), other.getArr())
                && Objects.equals(getNone(), other.getNone())
                && Objects.equals(getRunning(), other.getRunning())
                && Objects.equals(getMasterStatusHost(), other.getMasterStatusHost())
                && getMasterStatusPort() == other.getMasterStatusPort()
                && Objects.equals(getObjectMapper(), other.getObjectMapper())
                && Objects.equals(getAeron(), other.getAeron())
                && isCompressArray() == other.isCompressArray();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof ParameterServerClient;
    }

    /** Hash over the same fields as {@link #equals}; the formula and resulting values are unchanged. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrSentinel(getNdarraySendUrl());
        result = result * prime + hashOrSentinel(getNdarrayRetrieveUrl());
        result = result * prime + hashOrSentinel(getSubscriber());
        result = result * prime + hashOrSentinel(getSubscriberHost());
        result = result * prime + getSubscriberPort();
        result = result * prime + getSubscriberStream();
        result = result * prime + hashOrSentinel(getArr());
        result = result * prime + hashOrSentinel(getNone());
        result = result * prime + hashOrSentinel(getRunning());
        result = result * prime + hashOrSentinel(getMasterStatusHost());
        result = result * prime + getMasterStatusPort();
        result = result * prime + hashOrSentinel(getObjectMapper());
        result = result * prime + hashOrSentinel(getAeron());
        result = result * prime + (isCompressArray() ? 79 : 97);
        return result;
    }

    /** Hash contribution of a possibly-null field ({@code 43} for null). */
    private static int hashOrSentinel(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Override
    public String toString() {
        return "ParameterServerClient(ndarraySendUrl=" + getNdarraySendUrl()
                + ", ndarrayRetrieveUrl=" + getNdarrayRetrieveUrl()
                + ", subscriber=" + getSubscriber()
                + ", subscriberHost=" + getSubscriberHost()
                + ", subscriberPort=" + getSubscriberPort()
                + ", subscriberStream=" + getSubscriberStream()
                + ", arr=" + getArr()
                + ", none=" + getNone()
                + ", running=" + getRunning()
                + ", masterStatusHost=" + getMasterStatusHost()
                + ", masterStatusPort=" + getMasterStatusPort()
                + ", objectMapper=" + getObjectMapper()
                + ", aeron=" + getAeron()
                + ", compressArray=" + isCompressArray() + ")";
    }

    /**
     * All-arguments constructor (also used by the builder).
     *
     * <p>Every field is taken verbatim from the arguments; no defaults are substituted. Values
     * omitted from the builder are therefore null/0/false. A null {@code objectMapper} is created
     * lazily on the first HTTP query, and a null {@code arr} when the subscriber is first needed.
     */
    @ConstructorProperties({"ndarraySendUrl", "ndarrayRetrieveUrl", "subscriber", "subscriberHost", "subscriberPort", "subscriberStream", "arr", "none", "running", "masterStatusHost", "masterStatusPort", "objectMapper", "aeron", "compressArray"})
    public ParameterServerClient(String ndarraySendUrl, String ndarrayRetrieveUrl,
            AeronNDArraySubscriber subscriber, String subscriberHost, int subscriberPort,
            int subscriberStream, AtomicReference<INDArray> arr, INDArray none,
            AtomicBoolean running, String masterStatusHost, int masterStatusPort,
            ObjectMapper objectMapper, Aeron aeron, boolean compressArray) {
        this.ndarraySendUrl = ndarraySendUrl;
        this.ndarrayRetrieveUrl = ndarrayRetrieveUrl;
        this.subscriber = subscriber;
        this.subscriberHost = subscriberHost;
        this.subscriberPort = subscriberPort;
        this.subscriberStream = subscriberStream;
        this.arr = arr;
        this.none = none;
        this.running = running;
        this.masterStatusHost = masterStatusHost;
        this.masterStatusPort = masterStatusPort;
        this.objectMapper = objectMapper;
        this.aeron = aeron;
        this.compressArray = compressArray;
    }
}
