/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.GetAction;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.MultiGetAction;
import org.opensearch.action.get.MultiGetResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.InferenceProcessor;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.processor.optimization.InferenceFilter;
import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
import org.opensearch.neuralsearch.util.TokenWeightUtil;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;
import org.opensearch.transport.client.OpenSearchClient;

/**
 * {@link InferenceProcessor} registered under the {@code "sparse_encoding"} type.
 *
 * <p>Inference texts are sent to the model identified by {@code modelId} through
 * {@link MLCommonsClientAccessor}; results are pruned with the configured
 * {@link PruneType} and prune ratio.
 *
 * <p>When {@code skipExisting} is {@code true}, the currently stored document is
 * fetched first (Get for single documents, MultiGet for batches) and handed to the
 * {@link TextEmbeddingInferenceFilter}, so that only the texts remaining in the
 * filtered process map are sent for inference.
 */
public final class SparseEncodingProcessor extends InferenceProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(SparseEncodingProcessor.class);
    public static final String TYPE = "sparse_encoding";
    public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
    /** Key under which the target index name is read from the document's source-and-metadata map. */
    private static final String INDEX_METADATA_KEY = "_index";
    /** Key under which the document id is read from the document's source-and-metadata map. */
    private static final String ID_METADATA_KEY = "_id";
    private final OpenSearchClient openSearchClient;
    private final boolean skipExisting;
    private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;
    private final PruneType pruneType;
    private final float pruneRatio;

    /**
     * Creates the processor.
     *
     * @param tag processor tag, forwarded to {@link InferenceProcessor}
     * @param description processor description, forwarded to {@link InferenceProcessor}
     * @param batchSize batch size, forwarded to {@link InferenceProcessor}
     * @param modelId id of the model used for inference
     * @param fieldMap field mapping, forwarded to {@link InferenceProcessor}
     * @param skipExisting whether to look up the stored document and filter the process map before inference
     * @param textEmbeddingInferenceFilter filter used when {@code skipExisting} is {@code true}
     * @param pruneType prune strategy applied to inference results
     * @param pruneRatio ratio passed to the prune strategy
     * @param openSearchClient client used for the Get / MultiGet lookups
     * @param clientAccessor accessor used to call the model
     * @param environment forwarded to {@link InferenceProcessor}
     * @param clusterService forwarded to {@link InferenceProcessor}
     */
    public SparseEncodingProcessor(String tag, String description, int batchSize, String modelId, Map<String, Object> fieldMap, boolean skipExisting, TextEmbeddingInferenceFilter textEmbeddingInferenceFilter, PruneType pruneType, float pruneRatio, OpenSearchClient openSearchClient, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
        super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
        this.pruneType = pruneType;
        this.pruneRatio = pruneRatio;
        this.skipExisting = skipExisting;
        this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter;
        this.openSearchClient = openSearchClient;
    }

    /**
     * Runs inference for a single document.
     *
     * <p>Inference is run on the full {@code inferenceList} when {@code skipExisting} is off,
     * or when the document carries no {@code _index} or {@code _id}. Otherwise the stored
     * document is fetched with a Get request and the result is passed to
     * {@link #reuseExistingOrInfer}.
     *
     * @param ingestDocument document being ingested
     * @param processMap map of fields to process
     * @param inferenceList texts to send for inference
     * @param handler completion callback; receives {@code (null, exception)} when the Get request fails
     */
    @Override
    public void doExecute(IngestDocument ingestDocument, Map<String, Object> processMap, List<String> inferenceList, BiConsumer<IngestDocument, Exception> handler) {
        if (!this.skipExisting) {
            this.generateAndSetMapInference(ingestDocument, processMap, inferenceList, this.pruneType, this.pruneRatio, handler);
            return;
        }
        Object index = ingestDocument.getSourceAndMetadata().get(INDEX_METADATA_KEY);
        Object id = ingestDocument.getSourceAndMetadata().get(ID_METADATA_KEY);
        if (Objects.isNull(index) || Objects.isNull(id)) {
            // Without both index and id there is no stored document to look up.
            this.generateAndSetMapInference(ingestDocument, processMap, inferenceList, this.pruneType, this.pruneRatio, handler);
            return;
        }
        this.openSearchClient.execute(
            GetAction.INSTANCE,
            new GetRequest(index.toString(), id.toString()),
            ActionListener.wrap(
                response -> this.reuseExistingOrInfer(response.getSourceAsMap(), ingestDocument, processMap, inferenceList, handler),
                e -> handler.accept(null, e)
            )
        );
    }

    /**
     * Filters the process map against the stored document and runs inference on what remains.
     *
     * @param existingDocument source of the stored document; {@code null} or empty means nothing is stored
     * @param ingestDocument document being ingested
     * @param processMap unfiltered map of fields to process
     * @param inferenceList unfiltered inference texts, used when there is no stored document
     * @param handler completion callback
     */
    private void reuseExistingOrInfer(Map<String, Object> existingDocument, IngestDocument ingestDocument, Map<String, Object> processMap, List<String> inferenceList, BiConsumer<IngestDocument, Exception> handler) {
        if (existingDocument == null || existingDocument.isEmpty()) {
            this.generateAndSetMapInference(ingestDocument, processMap, inferenceList, this.pruneType, this.pruneRatio, handler);
            return;
        }
        Map<String, Object> filteredProcessMap = this.textEmbeddingInferenceFilter.filterAndCopyExistingEmbeddings(existingDocument, ingestDocument.getSourceAndMetadata(), processMap);
        List<String> filteredInferenceList = this.createInferenceList(filteredProcessMap).stream().filter(Objects::nonNull).collect(Collectors.toList());
        if (filteredInferenceList.isEmpty()) {
            // Nothing left to infer: complete with the document as it stands after filtering.
            handler.accept(ingestDocument, null);
            return;
        }
        this.generateAndSetMapInference(ingestDocument, filteredProcessMap, filteredInferenceList, this.pruneType, this.pruneRatio, handler);
    }

    /**
     * Sends {@code inferenceList} to the model and passes the pruned results to {@code handler}.
     *
     * @param inferenceList texts to send for inference
     * @param handler receives the list of pruned results
     * @param onException receives any failure reported by the inference call
     */
    @Override
    public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
        // NOTE(review): un-cast builder chaining assumes a Lombok @SuperBuilder-style builder - confirm.
        TextInferenceRequest inferenceRequest = TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build();
        this.mlCommonsClientAccessor.inferenceSentencesWithMapResult(inferenceRequest, ActionListener.wrap(resultMaps -> {
            List<?> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
                .stream()
                .map(vector -> PruneUtils.pruneSparseVector(this.pruneType, this.pruneRatio, vector))
                .toList();
            handler.accept(sparseVectors);
        }, onException));
    }

    /**
     * Runs inference for a sub-batch of documents.
     *
     * <p>The wrappers are handed back unchanged when the batch is empty or yields no
     * inference texts. With {@code skipExisting} off the batch goes straight to
     * {@code doSubBatchExecute}; otherwise the stored documents are fetched with a MultiGet
     * request and passed to {@code reuseOrGenerateEmbedding}. Any exception thrown while
     * setting this up is reported through {@code updateWithExceptions}.
     *
     * @param ingestDocumentWrappers documents in this sub-batch
     * @param handler completion callback for the sub-batch
     */
    @Override
    public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
        try {
            if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
                handler.accept(ingestDocumentWrappers);
                return;
            }
            List<InferenceProcessor.DataForInference> dataForInferences = this.getDataForInference(ingestDocumentWrappers);
            List<String> inferenceList = this.constructInferenceTexts(dataForInferences);
            if (inferenceList.isEmpty()) {
                handler.accept(ingestDocumentWrappers);
                return;
            }
            if (!this.skipExisting) {
                this.doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
                return;
            }
            this.openSearchClient.execute(
                MultiGetAction.INSTANCE,
                this.buildMultiGetRequest(dataForInferences),
                ActionListener.wrap(
                    response -> this.reuseOrGenerateEmbedding(response, ingestDocumentWrappers, inferenceList, dataForInferences, handler, this.textEmbeddingInferenceFilter),
                    // A failed MultiGet is reported against the wrappers derived from dataForInferences.
                    e -> this.updateWithExceptions(this.getIngestDocumentWrappers(dataForInferences), handler, e)
                )
            );
        }
        catch (Exception e) {
            this.updateWithExceptions(ingestDocumentWrappers, handler, e);
        }
    }

    /** Returns the prune strategy applied to inference results. */
    @Generated
    public PruneType getPruneType() {
        return this.pruneType;
    }

    /** Returns the ratio passed to the prune strategy. */
    @Generated
    public float getPruneRatio() {
        return this.pruneRatio;
    }
}

