/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.forecast.transport;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.routing.Preference;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.forecast.model.FilterBy;
import org.opensearch.forecast.model.ForecastResult;
import org.opensearch.forecast.model.ForecastResultBucket;
import org.opensearch.forecast.model.ForecastTask;
import org.opensearch.forecast.model.Forecaster;
import org.opensearch.forecast.model.Order;
import org.opensearch.forecast.model.Subaggregation;
import org.opensearch.forecast.transport.BuildInQuery;
import org.opensearch.forecast.transport.GetForecasterAction;
import org.opensearch.forecast.transport.RelationalOperation;
import org.opensearch.forecast.transport.SearchTopForecastResultAction;
import org.opensearch.forecast.transport.SearchTopForecastResultRequest;
import org.opensearch.forecast.transport.SearchTopForecastResultResponse;
import org.opensearch.forecast.transport.handler.ForecastSearchHandler;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.search.SearchHit;
import org.opensearch.search.aggregations.Aggregation;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.Aggregations;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.opensearch.search.aggregations.bucket.terms.Terms;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.timeseries.model.Entity;
import org.opensearch.timeseries.transport.GetConfigRequest;
import org.opensearch.timeseries.util.ParseUtils;
import org.opensearch.timeseries.util.QueryUtil;
import org.opensearch.timeseries.util.RestHandlerUtils;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

public class SearchTopForecastResultTransportAction
extends HandledTransportAction<SearchTopForecastResultRequest, SearchTopForecastResultResponse> {
    private static final Logger logger = LogManager.getLogger(SearchTopForecastResultTransportAction.class);
    /** Executes the searches built by this action; assigned once in the constructor. */
    private final ForecastSearchHandler searchHandler;
    /** Index pattern searched when the forecaster does not supply a custom result index pattern. */
    private static final String defaultIndex = "opensearch-forecast-results*";
    /** Number of buckets returned when the request does not specify a size. */
    private static final int DEFAULT_SIZE = 5;
    /** Largest number of buckets a request may ask for. */
    private static final int MAX_SIZE = 50;
    /** Name of the terms aggregation that the search response is read back from. */
    protected static final String AGG_NAME_TERM = "term_agg";
    /** Client used to fetch the forecaster configuration. */
    private final Client client;
    /** Registry used when parsing result documents and custom sub-aggregations; assigned once in the constructor. */
    private final NamedXContentRegistry xContent;

    /**
     * Creates the transport action and registers it under {@code SearchTopForecastResultAction.NAME}.
     *
     * @param transportService transport service handed to the parent action
     * @param actionFilters    action filters handed to the parent action
     * @param searchHandler    handler used to run searches against the result indices
     * @param client           client used to look up the forecaster configuration
     * @param xContent         registry used for parsing result documents and aggregations
     */
    @Inject
    public SearchTopForecastResultTransportAction(TransportService transportService, ActionFilters actionFilters, ForecastSearchHandler searchHandler, Client client, NamedXContentRegistry xContent) {
        super(SearchTopForecastResultAction.NAME, transportService, actionFilters, SearchTopForecastResultRequest::new);
        this.xContent = xContent;
        this.client = client;
        this.searchHandler = searchHandler;
    }

    /**
     * Looks up the forecaster named in the request, validates and normalizes the request against it,
     * and then runs the top-forecast search.
     *
     * <p>The request is mutated in place: {@code splitBy} defaults to all category fields of the
     * forecaster, {@code taskId} defaults to the latest run-once task when {@code runOnce} is set,
     * and {@code size} defaults to {@link #DEFAULT_SIZE}.
     *
     * <p>Validation errors are raised as exceptions inside the response callback; they reach the
     * caller through the failure callback passed to {@code ActionListener.wrap}.
     *
     * @param task     the task this execution belongs to (not used here)
     * @param request  the top-forecast request; updated with defaults as described above
     * @param listener receives the top buckets, or the failure
     */
    @Override
    protected void doExecute(Task task, SearchTopForecastResultRequest request, ActionListener<SearchTopForecastResultResponse> listener) {
        // NOTE(review): -3L looks like a "match any version" sentinel; confirm against GetConfigRequest.
        GetConfigRequest getForecasterRequest = new GetConfigRequest(request.getForecasterId(), -3L, false, true, "", "", false, null);
        this.client.execute(GetForecasterAction.INSTANCE, getForecasterRequest, ActionListener.wrap(getForecasterResponse -> {
            Forecaster forecaster = getForecasterResponse.getForecaster();
            if (forecaster == null) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "No forecaster found with ID %s", request.getForecasterId()));
            }
            List<String> categoryFields = forecaster.getCategoryFields();
            if (categoryFields == null || categoryFields.isEmpty()) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "No category fields found for forecaster ID %s", request.getForecasterId()));
            }
            if (request.getSplitBy() == null || request.getSplitBy().isEmpty()) {
                // No explicit split: group by every category field of the forecaster.
                request.setSplitBy(categoryFields);
            } else {
                for (String categoryField : request.getSplitBy()) {
                    if (!categoryFields.contains(categoryField)) {
                        throw new IllegalArgumentException(String.format(Locale.ROOT, "Category field %s doesn't exist for forecaster ID %s", categoryField, request.getForecasterId()));
                    }
                }
            }
            if (request.isRunOnce() && Strings.isNullOrEmpty(request.getTaskId())) {
                // Run-once request without an explicit task: fall back to the latest run-once task.
                ForecastTask runOnceTask = getForecasterResponse.getRunOnceTask();
                if (runOnceTask == null) {
                    throw new ResourceNotFoundException(String.format(Locale.ROOT, "No latest run once tasks found for forecaster ID %s", request.getForecasterId()));
                }
                request.setTaskId(runOnceTask.getTaskId());
            }
            if (request.getSize() == null) {
                request.setSize(DEFAULT_SIZE);
            } else {
                if (request.getSize() > MAX_SIZE) {
                    throw new IllegalArgumentException("Size cannot exceed " + MAX_SIZE);
                }
                if (request.getSize() <= 0) {
                    throw new IllegalArgumentException("Size must be a positive integer");
                }
            }
            SearchRequest searchRequest = this.generateQuery(request, forecaster);
            String customResultIndexPattern = forecaster.getCustomResultIndexPattern();
            if (!Strings.isNullOrEmpty(customResultIndexPattern)) {
                // A custom result index pattern replaces the default result index pattern.
                searchRequest.indices(customResultIndexPattern);
            }
            this.searchHandler.search(searchRequest, this.onSearchResponse(request, categoryFields, forecaster, listener));
        }, exception -> {
            logger.error("Failed to get top forecast results", exception);
            listener.onFailure(exception);
        }));
    }

    /**
     * Builds the listener that turns the terms-aggregation search response into the final response.
     *
     * <p>A response without aggregations, or without the {@link #AGG_NAME_TERM} aggregation, yields
     * an empty result. An aggregation with no buckets is reported as a
     * {@link ResourceNotFoundException}. Otherwise every bucket is converted via
     * {@link #createForecastResultBucket}, and the converted buckets are returned sorted by bucket
     * index once all of them have reported back.
     *
     * @param request        the (already normalized) top-forecast request
     * @param categoryFields category fields of the forecaster
     * @param forecaster     the forecaster the results belong to
     * @param listener       receives the final response or the failure
     * @return a listener to pass to the aggregation search
     */
    private ActionListener<SearchResponse> onSearchResponse(SearchTopForecastResultRequest request, List<String> categoryFields, Forecaster forecaster, ActionListener<SearchTopForecastResultResponse> listener) {
        return ActionListener.wrap(response -> {
            Aggregations aggs = response.getAggregations();
            if (aggs == null) {
                listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<ForecastResultBucket>()));
                return;
            }
            Aggregation aggResults = aggs.get(AGG_NAME_TERM);
            if (aggResults == null) {
                listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<ForecastResultBucket>()));
                return;
            }
            List<? extends MultiBucketsAggregation.Bucket> buckets = ((MultiBucketsAggregation) aggResults).getBuckets();
            if (buckets == null || buckets.isEmpty()) {
                listener.onFailure(new ResourceNotFoundException("No forecast value found. forecast_from timestamp or other parameters might be incorrect."));
                return;
            }
            // One callback per bucket; the grouped listener fires after all of them have answered.
            GroupedActionListener<ForecastResultBucket> groupListener = new GroupedActionListener<ForecastResultBucket>(ActionListener.wrap(results -> {
                // Per-bucket conversions complete in arbitrary order; restore the aggregation's ranking.
                List<ForecastResultBucket> sortedList = results
                    .stream()
                    .sorted((a, b) -> Integer.compare(a.getBucketIndex(), b.getBucketIndex()))
                    .collect(Collectors.toList());
                listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<ForecastResultBucket>(sortedList)));
            }, exception -> {
                logger.warn("Failed to find valid aggregation result", exception);
                // Keep the root cause attached so the failure stays diagnosable downstream.
                listener.onFailure(new OpenSearchStatusException("Failed to find valid aggregation result", RestStatus.INTERNAL_SERVER_ERROR, exception));
            }), buckets.size());
            for (int i = 0; i < buckets.size(); ++i) {
                this.createForecastResultBucket(buckets.get(i), i, request, categoryFields, forecaster, groupListener);
            }
        }, listener::onFailure);
    }

    /**
     * Converts one bucket of the terms aggregation into a {@link ForecastResultBucket} and hands it
     * to {@code listener}.
     *
     * <p>Every sub-aggregation of the bucket must be a single-value numeric metric and the bucket
     * itself must be a terms bucket; otherwise the listener is failed with an
     * {@link IllegalArgumentException}. The listener is notified exactly once per validation
     * failure, and processing stops at the first one.
     *
     * @param bucket         the aggregation bucket to convert
     * @param bucketIndex    position of the bucket in the aggregation response
     * @param request        the top-forecast request
     * @param categoryFields category fields of the forecaster
     * @param forecaster     the forecaster the results belong to
     * @param listener       receives the converted bucket or the failure
     */
    public void createForecastResultBucket(MultiBucketsAggregation.Bucket bucket, int bucketIndex, SearchTopForecastResultRequest request, List<String> categoryFields, Forecaster forecaster, ActionListener<ForecastResultBucket> listener) {
        Map<String, Double> aggregationsMap = new HashMap<String, Double>();
        for (Aggregation aggregation : bucket.getAggregations()) {
            if (!(aggregation instanceof NumericMetricsAggregation.SingleValue)) {
                listener.onFailure(new IllegalArgumentException(String.format(Locale.ROOT, "A single value aggregation is required; received [%s]", aggregation)));
                // Stop here: continuing would hit the cast below and notify the listener a second time.
                return;
            }
            NumericMetricsAggregation.SingleValue singleValueAggregation = (NumericMetricsAggregation.SingleValue) aggregation;
            aggregationsMap.put(aggregation.getName(), singleValueAggregation.value());
        }
        if (!(bucket instanceof Terms.Bucket)) {
            listener.onFailure(new IllegalArgumentException(String.format(Locale.ROOT, "We only use terms aggregation in top, but got %s", bucket)));
            return;
        }
        // Saturate instead of silently wrapping around if the count ever exceeds the int range.
        int docCount = (int) Math.min((long) Integer.MAX_VALUE, bucket.getDocCount());
        // getKeyAsString() avoids a ClassCastException when the key is not a String.
        this.convertToCategoricalFieldValuePair(bucket.getKeyAsString(), bucketIndex, docCount, aggregationsMap, request, categoryFields, forecaster, listener);
    }

    /**
     * Resolves the category field/value pairs for a bucket key and reports the finished bucket.
     *
     * <p>When the request splits by as many fields as the forecaster has category fields (or has no
     * split at all), the key is passed to {@link #findMatchingCategoricalFieldValuePair} as an
     * entity id. Otherwise the key is reported as the value of the first split-by field.
     */
    private void convertToCategoricalFieldValuePair(String keyInSearchResponse, int bucketIndex, int docCount, Map<String, Double> aggregations, SearchTopForecastResultRequest request, List<String> categoryFields, Forecaster forecaster, ActionListener<ForecastResultBucket> listener) {
        List<String> splitFields = request.getSplitBy();
        boolean splitByEveryCategory = splitFields == null || splitFields.size() == categoryFields.size();
        if (!splitByEveryCategory) {
            // The bucket key is already the value of the single split-by field.
            Map<String, Object> fieldToValue = new HashMap<String, Object>();
            fieldToValue.put(splitFields.get(0), keyInSearchResponse);
            listener.onResponse(new ForecastResultBucket(fieldToValue, docCount, aggregations, bucketIndex));
            return;
        }
        this.findMatchingCategoricalFieldValuePair(keyInSearchResponse, docCount, aggregations, bucketIndex, forecaster, listener);
    }

    /**
     * Looks up one result document with the given entity id and uses the entity attributes stored
     * in it as the keys of the reported bucket.
     *
     * <p>Any problem (search failure, no hit, unparsable document, document without an entity) is
     * reported to {@code listener} as an {@link IllegalArgumentException} naming the entity id.
     *
     * @param entityId     value matched against the {@code entity_id} field
     * @param docCount     document count of the originating bucket
     * @param aggregations sub-aggregation values of the originating bucket
     * @param bucketIndex  position of the originating bucket in the aggregation response
     * @param forecaster   forecaster whose result index is searched
     * @param listener     receives the resolved bucket or the failure
     */
    private void findMatchingCategoricalFieldValuePair(String entityId, int docCount, Map<String, Double> aggregations, int bucketIndex, Forecaster forecaster, ActionListener<ForecastResultBucket> listener) {
        BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("entity_id", entityId));
        // One document is enough: only its entity attributes are read.
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery).size(1);
        // Test the same value that is used (as doExecute does), so a null pattern can never be passed as the index.
        String customResultIndexPattern = forecaster.getCustomResultIndexPattern();
        String resultIndex = Strings.isNullOrEmpty(customResultIndexPattern) ? defaultIndex : customResultIndexPattern;
        // NOTE(review): toString() yields the enum constant name unless Preference overrides it;
        // confirm this is the intended preference string.
        SearchRequest searchRequest = new SearchRequest().indices(resultIndex).source(searchSourceBuilder).preference(Preference.LOCAL.toString());
        String failure = String.format(Locale.ROOT, "Cannot find a result matching entity id %s", entityId);
        ActionListener<SearchResponse> searchResponseListener = ActionListener.wrap(searchResponse -> {
            try {
                SearchHit[] hits = searchResponse.getHits().getHits();
                if (hits.length == 0) {
                    listener.onFailure(new IllegalArgumentException(failure));
                    return;
                }
                try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(this.xContent, hits[0].getSourceRef())) {
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    Optional<Entity> entity = ForecastResult.parse(parser).getEntity();
                    if (entity.isEmpty()) {
                        listener.onFailure(new IllegalArgumentException(failure));
                        return;
                    }
                    listener.onResponse(new ForecastResultBucket(this.convertMap(entity.get().getAttributes()), docCount, aggregations, bucketIndex));
                }
            } catch (Exception e) {
                listener.onFailure(new IllegalArgumentException(failure, e));
            }
        }, e -> listener.onFailure(new IllegalArgumentException(failure, e)));
        this.searchHandler.search(searchRequest, searchResponseListener);
    }

    /**
     * Copies a string-valued map into a new map typed with {@code Object} values.
     *
     * @param stringMap the map to copy
     * @return a new {@link HashMap} holding the same entries
     */
    private Map<String, Object> convertMap(Map<String, String> stringMap) {
        return new HashMap<String, Object>(stringMap);
    }

    /**
     * Assembles the aggregation-only search request for the top-forecast query.
     *
     * <p>The query filters on the forecast time window, requires a {@code forecast_value} field,
     * applies the custom or built-in filter chosen by {@code filterBy}, and scopes to a task or a
     * forecaster. No hits are returned (size 0, total hits not tracked); results come from the
     * terms aggregation only. The request targets the default result index pattern.
     *
     * @throws IllegalArgumentException if {@code filterBy} is neither custom nor built-in
     */
    private SearchRequest generateQuery(SearchTopForecastResultRequest request, Forecaster forecaster) {
        BoolQueryBuilder filters = new BoolQueryBuilder();
        filters = filters.filter(this.generateDateFilter(request, forecaster));
        filters.filter(new ExistsQueryBuilder("forecast_value"));

        FilterBy filterBy = request.getFilterBy();
        switch (filterBy) {
            case CUSTOM_QUERY: {
                if (request.getFilterQuery() != null) {
                    filters = filters.filter(request.getFilterQuery());
                }
                break;
            }
            case BUILD_IN_QUERY: {
                QueryBuilder builtInFilter = this.generateBuildInSubFilter(request, forecaster);
                if (builtInFilter != null) {
                    filters = filters.filter(builtInFilter);
                }
                break;
            }
            default: {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected filter by %s", request.getFilterBy()));
            }
        }

        filters = this.generateTaskIdFilter(request, filters);
        SearchSourceBuilder source = new SearchSourceBuilder().query(filters).trackTotalHits(false).size(0);
        TermsAggregationBuilder topTerms = this.generateTermsAggregation(request, forecaster);
        if (topTerms != null) {
            source = source.aggregation(topTerms);
        }
        return new SearchRequest().indices(defaultIndex).source(source);
    }

    /**
     * Returns the extra filter implied by the built-in query type, if any.
     *
     * <p>The confidence-interval-width queries match {@code horizon_index} against the forecaster's
     * horizon. The distance-to-threshold query puts a range on {@code forecast_value} according to
     * the requested relation. All other query types add no filter.
     *
     * @return the additional filter, or {@code null} when the query type needs none
     */
    private QueryBuilder generateBuildInSubFilter(SearchTopForecastResultRequest request, Forecaster forecaster) {
        BuildInQuery queryType = request.getBuildInQuery();
        switch (queryType) {
            case MIN_CONFIDENCE_INTERVAL_WIDTH:
            case MAX_CONFIDENCE_INTERVAL_WIDTH:
                return QueryBuilders.termQuery("horizon_index", (Object) forecaster.getHorizon());
            case DISTANCE_TO_THRESHOLD_VALUE: {
                RangeQueryBuilder valueRange = QueryBuilders.rangeQuery("forecast_value");
                Float threshold = request.getThreshold();
                switch (request.getRelationToThreshold()) {
                    case GREATER_THAN:
                        return valueRange.gt(threshold);
                    case GREATER_THAN_OR_EQUAL_TO:
                        return valueRange.gte(threshold);
                    case LESS_THAN:
                        return valueRange.lt(threshold);
                    case LESS_THAN_OR_EQUAL_TO:
                        return valueRange.lte(threshold);
                    default:
                        // Any other relation leaves the range unbounded.
                        return valueRange;
                }
            }
            default:
                return null;
        }
    }

    /**
     * Builds the range filter on {@code data_end_time}: from {@code forecastFrom} (inclusive) up to
     * {@code forecastFrom} plus one forecaster interval (exclusive).
     */
    private RangeQueryBuilder generateDateFilter(SearchTopForecastResultRequest request, Forecaster forecaster) {
        long windowStart = request.getForecastFrom().toEpochMilli();
        long windowEnd = windowStart + forecaster.getIntervalInMilliseconds();
        RangeQueryBuilder dataEndTimeRange = QueryBuilders.rangeQuery("data_end_time");
        return dataEndTimeRange.gte(windowStart).lt(windowEnd);
    }

    /**
     * Scopes the query either to a task or to a forecaster.
     *
     * <p>With a task id in the request, results are filtered on that {@code task_id}. Without one,
     * results are filtered on the {@code forecaster_id} and must not carry a {@code task_id}.
     *
     * @return the query that was passed in, with the filter added
     */
    private BoolQueryBuilder generateTaskIdFilter(SearchTopForecastResultRequest request, BoolQueryBuilder query) {
        String taskId = request.getTaskId();
        if (Strings.isNullOrEmpty(taskId)) {
            TermQueryBuilder byForecaster = QueryBuilders.termQuery("forecaster_id", request.getForecasterId());
            ExistsQueryBuilder hasTaskId = QueryBuilders.existsQuery("task_id");
            query.filter(byForecaster).mustNot(hasTaskId);
            return query;
        }
        query.filter(QueryBuilders.termQuery("task_id", taskId));
        return query;
    }

    /**
     * Builds the terms aggregation that ranks the buckets.
     *
     * <p>Bucketing: when {@code splitBy} has as many entries as the forecaster has category fields,
     * buckets are formed on {@code entity_id}; when it has exactly one entry, buckets are formed by
     * a script for that category field. Any other shape is rejected, because a terms aggregation
     * with neither a field nor a script has nothing to bucket on.
     *
     * <p>Ordering: a built-in query contributes one sub-aggregation and its order; a custom query
     * contributes one sub-aggregation and order per requested sub-aggregation.
     *
     * @param request    the normalized request (size and splitBy already set)
     * @param forecaster the forecaster the results belong to
     * @return the configured terms aggregation
     * @throws IllegalArgumentException if splitBy has an unsupported shape, a sub-aggregation cannot
     *                                  be parsed, the filter type is unknown, or no order results
     */
    private TermsAggregationBuilder generateTermsAggregation(SearchTopForecastResultRequest request, Forecaster forecaster) {
        TermsAggregationBuilder termsAgg = AggregationBuilders.terms(AGG_NAME_TERM).size(request.getSize().intValue());
        List<String> splitBy = request.getSplitBy();
        if (splitBy.size() == forecaster.getCategoryFields().size()) {
            termsAgg = termsAgg.field("entity_id");
        } else if (splitBy.size() == 1) {
            termsAgg = termsAgg.script(QueryUtil.getScriptForCategoryField(splitBy.get(0)));
        } else {
            // Reachable e.g. through duplicated split-by entries; fail loudly instead of
            // sending an aggregation without a values source.
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Split by must contain either exactly one category field or all of them, but got %s", splitBy)
            );
        }
        List<BucketOrder> orders = new ArrayList<BucketOrder>();
        FilterBy filterBy = request.getFilterBy();
        switch (filterBy) {
            case BUILD_IN_QUERY: {
                Pair<AggregationBuilder, BucketOrder> aggregationOrderPair = this.generateBuildInSubAggregation(request);
                termsAgg.subAggregation(aggregationOrderPair.getLeft());
                orders.add(aggregationOrderPair.getRight());
                break;
            }
            case CUSTOM_QUERY: {
                List<Subaggregation> subaggregations = request.getSubaggregations();
                if (subaggregations == null) {
                    // Treated like an empty list; rejected by the empty-order check below.
                    break;
                }
                for (Subaggregation subaggregation : subaggregations) {
                    try {
                        AggregatorFactories.Builder internalAgg = ParseUtils.parseAggregators(subaggregation.getAggregation().toString(), this.xContent, null);
                        if (internalAgg.getAggregatorFactories().isEmpty()) {
                            throw new IllegalArgumentException(String.format(Locale.ROOT, "No aggregation could be parsed from %s", subaggregation));
                        }
                        AggregationBuilder aggregation = internalAgg.getAggregatorFactories().iterator().next();
                        termsAgg.subAggregation(aggregation);
                        orders.add(BucketOrder.aggregation(aggregation.getName(), subaggregation.getOrder() == Order.ASC));
                    } catch (IOException e) {
                        throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected IOException when parsing %s", subaggregation), e);
                    }
                }
                break;
            }
            default: {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected filter by %s", filterBy));
            }
        }
        if (orders.isEmpty()) {
            throw new IllegalArgumentException("Cannot have empty order list");
        }
        termsAgg.order(orders);
        return termsAgg;
    }

    /**
     * Maps the built-in query type to the metric sub-aggregation and the bucket order based on it.
     *
     * <p>Every supported type is either a {@code min} aggregation ordered ascending or a
     * {@code max} aggregation ordered descending, named after the query type. The
     * confidence-interval-width types aggregate {@code confidence_interval_width}; all other types
     * aggregate {@code forecast_value}.
     *
     * @throws IllegalArgumentException for an unsupported query type or relation to threshold
     */
    private Pair<AggregationBuilder, BucketOrder> generateBuildInSubAggregation(SearchTopForecastResultRequest request) {
        BuildInQuery queryType = request.getBuildInQuery();
        String metricField;
        boolean smallestFirst;
        switch (queryType) {
            case MIN_CONFIDENCE_INTERVAL_WIDTH:
                metricField = "confidence_interval_width";
                smallestFirst = true;
                break;
            case MAX_CONFIDENCE_INTERVAL_WIDTH:
                metricField = "confidence_interval_width";
                smallestFirst = false;
                break;
            case MIN_VALUE_WITHIN_THE_HORIZON:
                metricField = "forecast_value";
                smallestFirst = true;
                break;
            case MAX_VALUE_WITHIN_THE_HORIZON:
                metricField = "forecast_value";
                smallestFirst = false;
                break;
            case DISTANCE_TO_THRESHOLD_VALUE: {
                metricField = "forecast_value";
                RelationalOperation relation = request.getRelationToThreshold();
                switch (relation) {
                    case GREATER_THAN:
                    case GREATER_THAN_OR_EQUAL_TO:
                        smallestFirst = false;
                        break;
                    case LESS_THAN:
                    case LESS_THAN_OR_EQUAL_TO:
                        smallestFirst = true;
                        break;
                    default:
                        throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected relation to threshold %s", relation));
                }
                break;
            }
            default:
                throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected build in query type %s", queryType));
        }

        // Ascending order always pairs with a min metric, descending with a max metric.
        String aggregationName = queryType.name();
        AggregationBuilder metric;
        if (smallestFirst) {
            metric = AggregationBuilders.min(aggregationName).field(metricField);
        } else {
            metric = AggregationBuilders.max(aggregationName).field(metricField);
        }
        BucketOrder ranking = BucketOrder.aggregation(aggregationName, smallestFirst);
        return Pair.of(metric, ranking);
    }
}

