package org.deeplearning4j.ui.module.histogram;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.stats.api.StatsReport;
import org.deeplearning4j.ui.stats.api.StatsType;
import org.deeplearning4j.ui.stats.api.SummaryType;
import org.deeplearning4j.ui.views.html.histogram.Histogram;
import org.deeplearning4j.ui.weights.beans.CompactModelAndGradient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import play.libs.Json;
import play.mvc.Result;
import play.mvc.Results;

/* loaded from: input_file:org/deeplearning4j/ui/module/histogram/HistogramModule.class */
/**
 * UI module that serves the weights/histogram page and its JSON endpoints.
 *
 * <p>Routes exposed (see {@link #getRoutes()}):
 * <ul>
 *   <li>{@code GET /weights} - renders the histogram view</li>
 *   <li>{@code GET /weights/listSessions} - JSON array of the known session IDs</li>
 *   <li>{@code GET /weights/updated/:sid} - JSON number, the current wall-clock time in ms</li>
 *   <li>{@code GET /weights/data/:sid} - JSON model/gradient summary for a session</li>
 * </ul>
 *
 * <p>Known sessions are tracked in an insertion-ordered map wrapped by
 * {@link Collections#synchronizedMap(Map)}, because storage callbacks and HTTP
 * requests both touch it.
 */
public class HistogramModule implements UIModule {
    private static final Logger log = LoggerFactory.getLogger(HistogramModule.class);

    /** Type ID of the stats records this module listens for and queries. */
    private static final String TYPE_ID = "StatsListener";

    /** Session ID -> storage holding that session's data, in first-seen order. */
    private final Map<String, StatsStorage> knownSessionIDs =
            Collections.synchronizedMap(new LinkedHashMap<String, StatsStorage>());

    /**
     * @return the single type ID ({@code "StatsListener"}) this module wants storage events for
     */
    @Override
    public List<String> getCallbackTypeIDs() {
        return Collections.singletonList(TYPE_ID);
    }

    /**
     * Builds the HTTP routes served by this module.
     *
     * @return the four routes listed in the class documentation
     */
    @Override
    public List<Route> getRoutes() {
        Route page = new Route("/weights", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) () -> Results.ok(Histogram.apply()));
        Route sessions = new Route("/weights/listSessions", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) this::listSessions);
        Route updated = new Route("/weights/updated/:sid", HttpMethod.GET, FunctionType.Function,
                (Function<String, Result>) this::getLastUpdateTime);
        Route data = new Route("/weights/data/:sid", HttpMethod.GET, FunctionType.Function,
                (Function<String, Result>) this::processRequest);
        return Arrays.asList(page, sessions, updated, data);
    }

    /**
     * Registers the session of every received event, keeping the first storage
     * seen for a given session ID.
     *
     * @param collection storage events to process
     */
    @Override
    public void reportStorageEvents(Collection<StatsStorageEvent> collection) {
        log.trace("Received events: {}", collection);
        for (StatsStorageEvent event : collection) {
            // Single atomic call instead of containsKey + put, so concurrent
            // callers cannot interleave between the check and the insert.
            knownSessionIDs.putIfAbsent(event.getSessionID(), event.getStatsStorage());
        }
    }

    /**
     * Registers every session in the attached storage that lists the
     * {@code "StatsListener"} type ID.
     *
     * @param statsStorage storage being attached
     */
    @Override
    public void onAttach(StatsStorage statsStorage) {
        for (String sessionId : statsStorage.listSessionIDs()) {
            if (statsStorage.listTypeIDsForSession(sessionId).contains(TYPE_ID)) {
                knownSessionIDs.put(sessionId, statsStorage);
            }
        }
    }

    /**
     * Forgets every session ID listed by the detached storage.
     *
     * @param statsStorage storage being detached
     */
    @Override
    public void onDetach(StatsStorage statsStorage) {
        for (String sessionId : statsStorage.listSessionIDs()) {
            knownSessionIDs.remove(sessionId);
        }
    }

    /**
     * Serves the list of known session IDs as JSON.
     */
    private Result listSessions() {
        List<String> sessionIds;
        // Views of a synchronizedMap must be traversed while holding the map's
        // lock; copy under the lock and serialize the snapshot afterwards.
        synchronized (knownSessionIDs) {
            sessionIds = new ArrayList<>(knownSessionIDs.keySet());
        }
        return Results.ok(Json.toJson(sessionIds));
    }

    /**
     * Serves the current time in milliseconds as JSON. The session ID is not
     * consulted.
     *
     * @param str session ID from the route (unused)
     */
    private Result getLastUpdateTime(String str) {
        return Results.ok(Json.toJson(Long.valueOf(System.currentTimeMillis())));
    }

    /**
     * Builds the JSON payload for one session: score history, per-layer mean
     * magnitudes of parameters and updates, and the histograms of the most
     * recent report.
     *
     * @param str session ID from the route
     * @return 404 for an unknown session; an empty JSON object when the session
     *         has no worker or no static info yet; otherwise the serialized
     *         {@link CompactModelAndGradient}
     */
    private Result processRequest(String str) {
        StatsStorage statsStorage = this.knownSessionIDs.get(str);
        if (statsStorage == null) {
            return Results.notFound("Unknown session ID: " + str);
        }

        List<String> workerIds = statsStorage.listWorkerIDsForSession(str);
        if (workerIds == null || workerIds.isEmpty()) {
            // Nothing has been reported for this session yet; answer the same
            // way as the "no static info" case below instead of failing on get(0).
            return Results.ok(Json.toJson(Collections.emptyMap()));
        }
        String workerId = workerIds.get(0);

        StatsInitializationReport staticInfo = statsStorage.getStaticInfo(str, TYPE_ID, workerId);
        if (staticInfo == null) {
            return Results.ok(Json.toJson(Collections.emptyMap()));
        }

        List<String> layerNames = distinctLayerNames(staticInfo.getModelParamNames());

        List<Persistable> fetched = statsStorage.getAllUpdatesAfter(str, TYPE_ID, workerId, 0L);
        List<Persistable> updates = fetched == null ? new ArrayList<Persistable>() : fetched;
        Collections.sort(updates, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));

        List<Double> scores = new ArrayList<>(updates.size());
        List<Map<String, List<Double>>> paramMagnitudes = new ArrayList<>();
        List<Map<String, List<Double>>> updateMagnitudes = new ArrayList<>();
        for (int i = 0; i < layerNames.size(); i++) {
            paramMagnitudes.add(new HashMap<String, List<Double>>());
            updateMagnitudes.add(new HashMap<String, List<Double>>());
        }

        // Updates are in ascending timestamp order, so the last StatsReport
        // seen in the loop is the most recent one.
        StatsReport latest = null;
        for (Persistable update : updates) {
            if (!(update instanceof StatsReport)) {
                log.debug("Encountered unexpected type: {}", update);
                continue;
            }
            StatsReport report = (StatsReport) update;
            scores.add(Double.valueOf(report.getScore()));
            if (report.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
                updateMeanMagnitudeMaps(report.getMeanMagnitudes(StatsType.Parameters), layerNames, paramMagnitudes);
            }
            if (report.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
                updateMeanMagnitudeMaps(report.getMeanMagnitudes(StatsType.Updates), layerNames, updateMagnitudes);
            }
            latest = report;
        }

        double lastScore = scores.isEmpty()
                ? TrainModule.NAN_REPLACEMENT_VALUE
                : scores.get(scores.size() - 1).doubleValue();

        CompactModelAndGradient response = new CompactModelAndGradient();
        // With no report yet, getHistogram(null) yields empty histograms.
        response.setGradients(getHistogram(latest == null ? null : latest.getHistograms(StatsType.Updates)));
        response.setParameters(getHistogram(latest == null ? null : latest.getHistograms(StatsType.Parameters)));
        response.setScore(lastScore);
        response.setScores(scores);
        response.setUpdateMagnitudes(updateMagnitudes);
        response.setParamMagnitudes(paramMagnitudes);
        if (latest != null) {
            // Without a report the bean keeps whatever default it was built with.
            response.setLastUpdateTime(latest.getTimeStamp());
        }
        return Results.ok(Json.toJson(response));
    }

    /**
     * Returns the distinct layer prefixes of the given parameter names, in
     * first-seen order. A {@code null} array is treated as empty.
     */
    private static List<String> distinctLayerNames(String[] paramNames) {
        LinkedHashSet<String> distinct = new LinkedHashSet<>();
        if (paramNames != null) {
            for (String paramName : paramNames) {
                distinct.add(layerNameOf(paramName));
            }
        }
        return new ArrayList<>(distinct);
    }

    /**
     * Returns the part of a parameter name before its first underscore, or the
     * whole name when there is none. Unlike {@code split("_")[0]} this cannot
     * fail on names made only of underscores.
     */
    private static String layerNameOf(String paramName) {
        int underscore = paramName.indexOf('_');
        return underscore < 0 ? paramName : paramName.substring(0, underscore);
    }

    /**
     * Appends each mean magnitude to the history list of its parameter, inside
     * the map that belongs to the parameter's layer.
     *
     * @param map   parameter name -> mean magnitude for one report; may be {@code null}
     * @param list  layer names; a layer's position selects the map in {@code list2}
     * @param list2 one map per layer: parameter name -> magnitude history
     */
    private void updateMeanMagnitudeMaps(Map<String, Double> map, List<String> list, List<Map<String, List<Double>>> list2) {
        if (map == null) {
            return;
        }
        for (Map.Entry<String, Double> entry : map.entrySet()) {
            String paramName = entry.getKey();
            int layerIndex = list.indexOf(layerNameOf(paramName));
            if (layerIndex < 0 || layerIndex >= list2.size()) {
                // The layer was not announced in the static info; skip it rather
                // than index the per-layer list with -1.
                log.debug("No layer entry for parameter: {}", paramName);
                continue;
            }
            list2.get(layerIndex)
                    .computeIfAbsent(paramName, k -> new ArrayList<Double>())
                    .add(entry.getValue());
        }
    }

    /**
     * Converts stats histograms into the shape the view consumes: for each
     * parameter, an insertion-ordered map from bin centre to bin count. Keys
     * starting with a digit are prefixed with {@code "param_"}.
     *
     * @param map parameter name -> histogram; may be {@code null}
     * @return parameter key -> (bin centre -> count); empty when {@code map} is {@code null}
     */
    private Map<String, Map> getHistogram(Map<String, org.deeplearning4j.ui.stats.api.Histogram> map) {
        Map<String, Map> result = new LinkedHashMap<>();
        if (map == null) {
            return result;
        }
        for (Map.Entry<String, org.deeplearning4j.ui.stats.api.Histogram> entry : map.entrySet()) {
            String paramName = entry.getKey();
            org.deeplearning4j.ui.stats.api.Histogram histogram = entry.getValue();
            if (histogram == null) {
                continue;
            }
            boolean startsWithDigit = !paramName.isEmpty() && Character.isDigit(paramName.charAt(0));
            String key = startsWithDigit ? "param_" + paramName : paramName;

            double min = histogram.getMin();
            double max = histogram.getMax();
            int nBins = histogram.getNBins();
            double binWidth = (max - min) / nBins;
            int[] binCounts = histogram.getBinCounts();
            // Never read past the counts array if it disagrees with nBins.
            int usableBins = binCounts == null ? 0 : Math.min(nBins, binCounts.length);

            Map<Double, Integer> bins = new LinkedHashMap<>();
            for (int i = 0; i < usableBins; i++) {
                bins.put(Double.valueOf(min + (i * binWidth) + (binWidth / 2.0d)), Integer.valueOf(binCounts[i]));
            }
            result.put(key, bins);
        }
        return result;
    }
}
