/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.rest;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.neuralsearch.settings.NeuralSearchSettingsAccessor;
import org.opensearch.neuralsearch.stats.NeuralStatsInput;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.info.InfoStatName;
import org.opensearch.neuralsearch.transport.NeuralStatsAction;
import org.opensearch.neuralsearch.transport.NeuralStatsRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.action.RestActions;
import org.opensearch.transport.client.node.NodeClient;

public class RestNeuralStatsAction
extends BaseRestHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(RestNeuralStatsAction.class);
    public static final String FLATTEN_PARAM = "flat_stat_paths";
    public static final String INCLUDE_METADATA_PARAM = "include_metadata";
    public static final String PARAM_REGEX = "^[A-Za-z0-9-_]+$";
    public static final int MAX_PARAM_LENGTH = 255;
    private static final String NAME = "neural_stats_action";
    private static final Set<String> EVENT_STAT_NAMES = EnumSet.allOf(EventStatName.class).stream().map(EventStatName::getNameString).map(str -> str.toLowerCase(Locale.ROOT)).collect(Collectors.toSet());
    private static final Set<String> STATE_STAT_NAMES = EnumSet.allOf(InfoStatName.class).stream().map(InfoStatName::getNameString).map(str -> str.toLowerCase(Locale.ROOT)).collect(Collectors.toSet());
    private static final List<RestHandler.Route> ROUTES = ImmutableList.of((Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_neural/{nodeId}/stats/"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_neural/{nodeId}/stats/{stat}"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_neural/stats/"), (Object)new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_neural/stats/{stat}"));
    private static final Set<String> RESPONSE_PARAMS = ImmutableSet.of((Object)"nodeId", (Object)"stat", (Object)"include_metadata", (Object)"flat_stat_paths");
    private NeuralSearchSettingsAccessor settingsAccessor;

    public static boolean isValidParamString(String param) {
        return param.matches(PARAM_REGEX) && param.length() < 255;
    }

    public String getName() {
        return NAME;
    }

    public List<RestHandler.Route> routes() {
        return ROUTES;
    }

    protected Set<String> responseParams() {
        return RESPONSE_PARAMS;
    }

    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
        if (!this.settingsAccessor.isStatsEnabled()) {
            return channel -> channel.sendResponse((RestResponse)new BytesRestResponse(RestStatus.FORBIDDEN, "Stats endpoint is disabled"));
        }
        NeuralStatsRequest neuralStatsRequest = this.createNeuralStatsRequest(request);
        return channel -> client.execute((ActionType)NeuralStatsAction.INSTANCE, (ActionRequest)neuralStatsRequest, (ActionListener)new RestActions.NodesResponseRestListener(channel));
    }

    private NeuralStatsRequest createNeuralStatsRequest(RestRequest request) {
        NeuralStatsInput neuralStatsInput = this.createNeuralStatsInputFromRequestParams(request);
        String[] nodeIdsArr = neuralStatsInput.getNodeIds().toArray(new String[0]);
        NeuralStatsRequest neuralStatsRequest = new NeuralStatsRequest(nodeIdsArr, neuralStatsInput);
        neuralStatsRequest.timeout(request.param("timeout"));
        return neuralStatsRequest;
    }

    private NeuralStatsInput createNeuralStatsInputFromRequestParams(RestRequest request) {
        NeuralStatsInput neuralStatsInput = new NeuralStatsInput();
        Optional<String[]> nodeIds = this.splitCommaSeparatedParam(request, "nodeId");
        if (nodeIds.isPresent()) {
            List<String> validFormatNodeIds = Arrays.stream(nodeIds.get()).filter(this::isValidNodeId).toList();
            neuralStatsInput.getNodeIds().addAll(validFormatNodeIds);
        }
        boolean flatten = request.paramAsBoolean(FLATTEN_PARAM, false);
        neuralStatsInput.setFlatten(flatten);
        boolean includeMetadata = request.paramAsBoolean(INCLUDE_METADATA_PARAM, false);
        neuralStatsInput.setIncludeMetadata(includeMetadata);
        Optional<String[]> stats = this.splitCommaSeparatedParam(request, "stat");
        if (!stats.isPresent()) {
            this.addAllStats(neuralStatsInput);
            return neuralStatsInput;
        }
        boolean anyStatAdded = this.processRequestedStats(stats.get(), neuralStatsInput);
        if (!anyStatAdded) {
            this.addAllStats(neuralStatsInput);
        }
        return neuralStatsInput;
    }

    private boolean processRequestedStats(String[] stats, NeuralStatsInput neuralStatsInput) {
        boolean statAdded = false;
        for (String stat : stats) {
            String normalizedStat = stat.toLowerCase(Locale.ROOT);
            if (!RestNeuralStatsAction.isValidParamString(normalizedStat)) {
                log.info("Invalid stat name parameter format: {}", (Object)normalizedStat);
                continue;
            }
            if (EVENT_STAT_NAMES.contains(normalizedStat)) {
                neuralStatsInput.getEventStatNames().add(EventStatName.from(normalizedStat));
                statAdded = true;
            } else if (STATE_STAT_NAMES.contains(normalizedStat)) {
                neuralStatsInput.getInfoStatNames().add(InfoStatName.from(normalizedStat));
                statAdded = true;
            }
            log.info("Non-existent stat name parsed: {}", (Object)normalizedStat);
        }
        return statAdded;
    }

    private void addAllStats(NeuralStatsInput neuralStatsInput) {
        neuralStatsInput.getEventStatNames().addAll(EnumSet.allOf(EventStatName.class));
        neuralStatsInput.getInfoStatNames().addAll(EnumSet.allOf(InfoStatName.class));
    }

    private Optional<String[]> splitCommaSeparatedParam(RestRequest request, String paramName) {
        return Optional.ofNullable(request.param(paramName)).map(s -> s.split(","));
    }

    private boolean isValidNodeId(String nodeId) {
        return RestNeuralStatsAction.isValidParamString(nodeId) && nodeId.length() == 22;
    }

    @Generated
    public RestNeuralStatsAction(NeuralSearchSettingsAccessor settingsAccessor) {
        this.settingsAccessor = settingsAccessor;
    }
}

