/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Query;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.neuralsearch.common.VectorUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.query.ModelInferenceQueryBuilder;

/**
 * Query builder for the {@code neural} query.
 *
 * <p>A neural query names a vector field and carries query text and/or a query image plus a model id. It is never
 * turned into a Lucene query directly ({@link #doToQuery(QueryShardContext)} always throws). Instead, during rewrite
 * it registers an asynchronous inference call on the {@link MLCommonsClientAccessor} configured through
 * {@link #initialize(MLCommonsClientAccessor)}. Once the resulting vector is available, it rewrites itself into a
 * {@link KNNQueryBuilder} carrying the same k / max_distance / min_score / filter / expand_nested /
 * method-parameter / rescore settings.
 */
public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder> implements ModelInferenceQueryBuilder {
    @Generated
    private static final Logger log = LogManager.getLogger(NeuralQueryBuilder.class);
    public static final String NAME = "neural";
    @VisibleForTesting
    static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
    public static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image");
    public static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
    @VisibleForTesting
    static final ParseField K_FIELD = new ParseField("k");
    /** Value given to k when none of k, max_distance or min_score is supplied. */
    private static final int DEFAULT_K = 10;
    /** Client used to run model inference in doRewrite; assigned by {@link #initialize(MLCommonsClientAccessor)}. */
    private static MLCommonsClientAccessor ML_CLIENT;
    private String fieldName;
    private String queryText;
    private String queryImage;
    private String modelId;
    // fromXContent and the Builder enforce that at most one of k / maxDistance / minScore is set.
    private Integer k = null;
    private Float maxDistance = null;
    private Float minScore = null;
    private Boolean expandNested;
    // Supplies the query vector once inference has produced it (set on the copy created by doRewrite, or via the
    // builder). It is not serialized by doWriteTo.
    @VisibleForTesting
    private Supplier<float[]> vectorSupplier;
    private QueryBuilder filter;
    private Map<String, ?> methodParameters;
    private RescoreContext rescoreContext;

    /**
     * Registers the ML Commons client shared by all neural queries. doRewrite dereferences it without a null check,
     * so this must be called before any neural query is rewritten.
     *
     * @param mlClient client used to run inference
     */
    public static void initialize(MLCommonsClientAccessor mlClient) {
        ML_CLIENT = mlClient;
    }

    /**
     * @return a new programmatic {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Deserializes a query written by {@link #doWriteTo(StreamOutput)}. The wire layout depends on the minimum
     * cluster version, so every version check here must mirror the one in doWriteTo exactly and in the same order.
     *
     * @param in stream to read from
     * @throws IOException if reading from the stream fails
     */
    public NeuralQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(QUERY_IMAGE_FIELD.getPreferredName())) {
            this.queryText = in.readOptionalString();
            this.queryImage = in.readOptionalString();
        } else {
            this.queryText = in.readString();
        }
        this.modelId = MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()
            ? in.readOptionalString()
            : in.readString();
        this.k = MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()
            ? in.readOptionalInt()
            : Integer.valueOf(in.readVInt());
        this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            this.maxDistance = in.readOptionalFloat();
            this.minScore = in.readOptionalFloat();
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.EXPAND_NESTED_FIELD.getPreferredName())) {
            this.expandNested = in.readOptionalBoolean();
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.METHOD_PARAMS_FIELD.getPreferredName())) {
            this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
        }
        this.rescoreContext = RescoreParser.streamInput(in);
    }

    /**
     * Serializes this query. The field order and version gates must stay in lock-step with
     * {@link #NeuralQueryBuilder(StreamInput)}. The vector supplier is intentionally not written.
     *
     * @param out stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(QUERY_IMAGE_FIELD.getPreferredName())) {
            out.writeOptionalString(this.queryText);
            out.writeOptionalString(this.queryImage);
        } else {
            // NOTE(review): the non-optional writeString presumably rejects null, so an image-only query cannot be
            // sent to clusters older than the image-support version — confirm.
            out.writeString(this.queryText);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            out.writeOptionalString(this.modelId);
        } else {
            out.writeString(this.modelId);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            out.writeOptionalInt(this.k);
        } else {
            out.writeVInt(this.k.intValue());
        }
        out.writeOptionalNamedWriteable(this.filter);
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch()) {
            out.writeOptionalFloat(this.maxDistance);
            out.writeOptionalFloat(this.minScore);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.EXPAND_NESTED_FIELD.getPreferredName())) {
            out.writeOptionalBoolean(this.expandNested);
        }
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion(KNNQueryBuilder.METHOD_PARAMS_FIELD.getPreferredName())) {
            MethodParametersParser.streamOutput(out, this.methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
        }
        RescoreParser.streamOutput(out, this.rescoreContext);
    }

    /**
     * Renders the query as {@code {"neural": {"<field>": {...params}}}}, emitting only the parameters that are set,
     * followed by boost and query name.
     */
    @Override
    protected void doXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        xContentBuilder.startObject(NAME);
        xContentBuilder.startObject(this.fieldName);
        if (Objects.nonNull(this.queryText)) {
            xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), this.queryText);
        }
        if (Objects.nonNull(this.queryImage)) {
            xContentBuilder.field(QUERY_IMAGE_FIELD.getPreferredName(), this.queryImage);
        }
        if (Objects.nonNull(this.modelId)) {
            xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), this.modelId);
        }
        if (Objects.nonNull(this.k)) {
            xContentBuilder.field(K_FIELD.getPreferredName(), this.k);
        }
        if (Objects.nonNull(this.filter)) {
            xContentBuilder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), this.filter);
        }
        if (Objects.nonNull(this.maxDistance)) {
            xContentBuilder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), this.maxDistance);
        }
        if (Objects.nonNull(this.minScore)) {
            xContentBuilder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), this.minScore);
        }
        if (Objects.nonNull(this.expandNested)) {
            xContentBuilder.field(KNNQueryBuilder.EXPAND_NESTED_FIELD.getPreferredName(), this.expandNested);
        }
        if (Objects.nonNull(this.methodParameters)) {
            MethodParametersParser.doXContent(xContentBuilder, this.methodParameters);
        }
        if (Objects.nonNull(this.rescoreContext)) {
            RescoreParser.doXContent(xContentBuilder, this.rescoreContext);
        }
        this.printBoostAndQueryName(xContentBuilder);
        xContentBuilder.endObject();
        xContentBuilder.endObject();
    }

    /**
     * Parses the body of a {@code neural} query: {@code {"<field>": {...params}}} with exactly one field.
     *
     * <p>Validation performed:
     * <ul>
     *   <li>query text or image must be non-blank;</li>
     *   <li>the field name must be present;</li>
     *   <li>the model id is required on clusters without default-model-id support;</li>
     *   <li>at most one of k / max_distance / min_score may be set.</li>
     * </ul>
     * If none of the three is set, k defaults to {@code DEFAULT_K}.
     *
     * @param parser parser positioned on the START_OBJECT of the neural query body
     * @return the parsed query builder
     * @throws IOException if reading from the parser fails
     * @throws ParsingException if the body is malformed or contains unsupported parameters
     */
    public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOException {
        NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
        if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), "Token must be START_OBJECT");
        }
        // Without these structural checks a malformed body (e.g. "neural": {} or "neural": {"f": "x"}) makes the
        // parameter loop consume tokens belonging to the enclosing request.
        if (parser.nextToken() != XContentParser.Token.FIELD_NAME) {
            throw new ParsingException(parser.getTokenLocation(), "[neural] query requires a field name");
        }
        neuralQueryBuilder.fieldName(parser.currentName());
        if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(
                parser.getTokenLocation(),
                "[neural] query expects an object of parameters for field [" + neuralQueryBuilder.fieldName() + "]"
            );
        }
        parseQueryParams(parser, neuralQueryBuilder);
        if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            throw new ParsingException(
                parser.getTokenLocation(),
                "[neural] query doesn't support multiple fields, found ["
                    + neuralQueryBuilder.fieldName()
                    + "] and ["
                    + parser.currentName()
                    + "]"
            );
        }
        validateQueryParameters(neuralQueryBuilder.fieldName(), neuralQueryBuilder.queryText(), neuralQueryBuilder.queryImage());
        if (!MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");
        }
        boolean queryTypeIsProvided = validateKNNQueryType(
            neuralQueryBuilder.k(),
            neuralQueryBuilder.maxDistance(),
            neuralQueryBuilder.minScore()
        );
        if (!queryTypeIsProvided) {
            neuralQueryBuilder.k(DEFAULT_K);
        }
        return neuralQueryBuilder;
    }

    /**
     * Reads the parameter object of a neural query up to (and including) its END_OBJECT, populating the builder.
     *
     * @throws ParsingException on an unsupported parameter (scalar or object) or an unexpected token
     */
    private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder neuralQueryBuilder) throws IOException {
        XContentParser.Token token;
        String currentFieldName = "";
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token.isValue()) {
                if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.queryText(parser.text());
                } else if (QUERY_IMAGE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.queryImage(parser.text());
                } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.modelId(parser.text());
                } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.k((Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false));
                } else if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.queryName(parser.text());
                } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.boost(parser.floatValue());
                } else if (KNNQueryBuilder.MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.maxDistance(parser.floatValue());
                } else if (KNNQueryBuilder.MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.minScore(parser.floatValue());
                } else if (KNNQueryBuilder.EXPAND_NESTED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.expandNested(parser.booleanValue());
                } else {
                    throw new ParsingException(parser.getTokenLocation(), "[neural] query does not support [" + currentFieldName + "]");
                }
            } else if (token == XContentParser.Token.START_OBJECT) {
                if (KNNQueryBuilder.FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
                } else if (KNNQueryBuilder.METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser));
                } else if (KNNQueryBuilder.RESCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.rescoreContext(RescoreParser.fromXContent(parser));
                } else {
                    // Previously ignored without skipping its children: the loop then walked into the unknown object,
                    // could apply its inner fields, and stopped at its END_OBJECT, leaving the parser out of sync.
                    throw new ParsingException(parser.getTokenLocation(), "[neural] query does not support [" + currentFieldName + "]");
                }
            } else {
                throw new ParsingException(parser.getTokenLocation(), "[neural] unknown token [" + token + "] after [" + currentFieldName + "]");
            }
        }
    }

    /**
     * Two-phase rewrite.
     *
     * <p>Without a vector supplier, it registers an async inference call for the text/image input and returns a copy
     * whose supplier reads the (initially empty) result. With a supplier, it returns {@code this} while the vector is
     * still null; otherwise it returns the equivalent {@link KNNQueryBuilder}.
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (this.vectorSupplier() != null) {
            float[] vector = this.vectorSupplier().get();
            if (vector == null) {
                // Inference has not produced a vector yet; leave the query unchanged.
                return this;
            }
            return KNNQueryBuilder.builder()
                .fieldName(this.fieldName())
                .vector(vector)
                .filter(this.filter())
                .maxDistance(this.maxDistance)
                .minScore(this.minScore)
                .expandNested(this.expandNested)
                .k(this.k)
                .methodParameters(this.methodParameters)
                .rescoreContext(this.rescoreContext)
                .build();
        }
        SetOnce<float[]> vectorSetOnce = new SetOnce<>();
        Map<String, String> inferenceInput = new HashMap<>();
        if (StringUtils.isNotBlank(this.queryText())) {
            inferenceInput.put("inputText", this.queryText());
        }
        if (StringUtils.isNotBlank(this.queryImage())) {
            inferenceInput.put("inputImage", this.queryImage());
        }
        queryRewriteContext.registerAsyncAction(
            (client, actionListener) -> ML_CLIENT.inferenceSentences(
                this.modelId(),
                inferenceInput,
                ActionListener.wrap(floatList -> {
                    vectorSetOnce.set(VectorUtil.vectorAsListToArray(floatList));
                    actionListener.onResponse(null);
                }, actionListener::onFailure)
            )
        );
        return new NeuralQueryBuilder(
            this.fieldName(),
            this.queryText(),
            this.queryImage(),
            this.modelId(),
            this.k(),
            this.maxDistance(),
            this.minScore(),
            this.expandNested(),
            vectorSetOnce::get,
            this.filter(),
            this.methodParameters(),
            this.rescoreContext()
        );
    }

    /**
     * Always throws: a neural query must be rewritten into a k-NN query before it reaches the shard.
     */
    @Override
    protected Query doToQuery(QueryShardContext queryShardContext) {
        throw new UnsupportedOperationException("Query cannot be created by NeuralQueryBuilder directly");
    }

    /**
     * Field-by-field equality. Vectors are compared by the arrays each supplier currently returns (null when there is
     * no supplier or it has no value yet).
     */
    @Override
    protected boolean doEquals(NeuralQueryBuilder obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        EqualsBuilder equalsBuilder = new EqualsBuilder();
        equalsBuilder.append(this.fieldName, obj.fieldName);
        equalsBuilder.append(this.queryText, obj.queryText);
        equalsBuilder.append(this.queryImage, obj.queryImage);
        equalsBuilder.append(this.modelId, obj.modelId);
        equalsBuilder.append(this.k, obj.k);
        equalsBuilder.append(this.maxDistance, obj.maxDistance);
        equalsBuilder.append(this.minScore, obj.minScore);
        equalsBuilder.append(this.expandNested, obj.expandNested);
        equalsBuilder.append(this.getVector(this.vectorSupplier), this.getVector(obj.vectorSupplier));
        equalsBuilder.append(this.filter, obj.filter);
        equalsBuilder.append(this.methodParameters, obj.methodParameters);
        equalsBuilder.append(this.rescoreContext, obj.rescoreContext);
        return equalsBuilder.isEquals();
    }

    /** Hash over the same fields as {@link #doEquals}, hashing the vector by content. */
    @Override
    protected int doHashCode() {
        return Objects.hash(
            this.fieldName,
            this.queryText,
            this.queryImage,
            this.modelId,
            this.k,
            this.maxDistance,
            this.minScore,
            this.expandNested,
            Arrays.hashCode(this.getVector(this.vectorSupplier)),
            this.filter,
            this.methodParameters,
            this.rescoreContext
        );
    }

    /** Returns the supplier's current vector, or null when the supplier itself is null. */
    private float[] getVector(Supplier<float[]> vectorSupplier) {
        return Objects.isNull(vectorSupplier) ? null : vectorSupplier.get();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Requires a non-blank query text or query image, and a field name (checked via requireValue).
     *
     * @throws IllegalArgumentException when both query text and image are blank
     */
    private static void validateQueryParameters(String fieldName, String queryText, String queryImage) {
        if (StringUtils.isBlank(queryText) && StringUtils.isBlank(queryImage)) {
            throw new IllegalArgumentException("Either query text or image text must be provided for neural query");
        }
        requireValue(fieldName, "Field name must be provided for neural query");
    }

    /**
     * Checks that at most one k-NN query type is supplied.
     *
     * @return true when exactly one of k / maxDistance / minScore is non-null, false when none is
     * @throws IllegalArgumentException when more than one is non-null
     */
    private static boolean validateKNNQueryType(Integer k, Float maxDistance, Float minScore) {
        int queryCount = 0;
        if (k != null) {
            ++queryCount;
        }
        if (maxDistance != null) {
            ++queryCount;
        }
        if (minScore != null) {
            ++queryCount;
        }
        if (queryCount > 1) {
            throw new IllegalArgumentException("Only one of k, max_distance, or min_score can be provided");
        }
        return queryCount == 1;
    }

    // ---- Fluent accessors: the no-arg form returns the field; the one-arg form sets it and returns this. ----

    @Override
    @Generated
    public String fieldName() {
        return this.fieldName;
    }

    @Generated
    public String queryText() {
        return this.queryText;
    }

    @Generated
    public String queryImage() {
        return this.queryImage;
    }

    @Override
    @Generated
    public String modelId() {
        return this.modelId;
    }

    @Generated
    public Integer k() {
        return this.k;
    }

    @Generated
    public Float maxDistance() {
        return this.maxDistance;
    }

    @Generated
    public Float minScore() {
        return this.minScore;
    }

    @Generated
    public Boolean expandNested() {
        return this.expandNested;
    }

    @Generated
    public QueryBuilder filter() {
        return this.filter;
    }

    @Generated
    public Map<String, ?> methodParameters() {
        return this.methodParameters;
    }

    @Generated
    public RescoreContext rescoreContext() {
        return this.rescoreContext;
    }

    @Generated
    public NeuralQueryBuilder fieldName(String fieldName) {
        this.fieldName = fieldName;
        return this;
    }

    @Generated
    public NeuralQueryBuilder queryText(String queryText) {
        this.queryText = queryText;
        return this;
    }

    @Generated
    public NeuralQueryBuilder queryImage(String queryImage) {
        this.queryImage = queryImage;
        return this;
    }

    @Override
    @Generated
    public NeuralQueryBuilder modelId(String modelId) {
        this.modelId = modelId;
        return this;
    }

    @Generated
    public NeuralQueryBuilder k(Integer k) {
        this.k = k;
        return this;
    }

    @Generated
    public NeuralQueryBuilder maxDistance(Float maxDistance) {
        this.maxDistance = maxDistance;
        return this;
    }

    @Generated
    public NeuralQueryBuilder minScore(Float minScore) {
        this.minScore = minScore;
        return this;
    }

    @Generated
    public NeuralQueryBuilder expandNested(Boolean expandNested) {
        this.expandNested = expandNested;
        return this;
    }

    @Generated
    public NeuralQueryBuilder filter(QueryBuilder filter) {
        this.filter = filter;
        return this;
    }

    @Generated
    public NeuralQueryBuilder methodParameters(Map<String, ?> methodParameters) {
        this.methodParameters = methodParameters;
        return this;
    }

    @Generated
    public NeuralQueryBuilder rescoreContext(RescoreContext rescoreContext) {
        this.rescoreContext = rescoreContext;
        return this;
    }

    /** Empty instance populated through the fluent setters by {@link #fromXContent(XContentParser)}. */
    @Generated
    private NeuralQueryBuilder() {
    }

    /** All-fields constructor used by doRewrite and {@link Builder#build()}; performs no validation. */
    @Generated
    private NeuralQueryBuilder(
        String fieldName,
        String queryText,
        String queryImage,
        String modelId,
        Integer k,
        Float maxDistance,
        Float minScore,
        Boolean expandNested,
        Supplier<float[]> vectorSupplier,
        QueryBuilder filter,
        Map<String, ?> methodParameters,
        RescoreContext rescoreContext
    ) {
        this.fieldName = fieldName;
        this.queryText = queryText;
        this.queryImage = queryImage;
        this.modelId = modelId;
        this.k = k;
        this.maxDistance = maxDistance;
        this.minScore = minScore;
        this.expandNested = expandNested;
        this.vectorSupplier = vectorSupplier;
        this.filter = filter;
        this.methodParameters = methodParameters;
        this.rescoreContext = rescoreContext;
    }

    @Generated
    Supplier<float[]> vectorSupplier() {
        return this.vectorSupplier;
    }

    @Generated
    NeuralQueryBuilder vectorSupplier(Supplier<float[]> vectorSupplier) {
        this.vectorSupplier = vectorSupplier;
        return this;
    }

    /**
     * Programmatic builder for {@link NeuralQueryBuilder}.
     *
     * <p>{@link #build()} validates query text/image and field name, and rejects more than one of k / max_distance /
     * min_score. It does not check the model id. When none of the three is set, the built query gets k =
     * {@code DEFAULT_K}. The builder's own state is left untouched, so it can be reused safely.
     */
    public static class Builder {
        private String fieldName;
        private String queryText;
        private String queryImage;
        private String modelId;
        private Integer k = null;
        private Float maxDistance = null;
        private Float minScore = null;
        private Boolean expandNested;
        private Supplier<float[]> vectorSupplier;
        private QueryBuilder filter;
        private Map<String, ?> methodParameters;
        private RescoreContext rescoreContext;
        private String queryName;
        private float boost = 1.0f;

        public Builder fieldName(String fieldName) {
            this.fieldName = fieldName;
            return this;
        }

        public Builder queryText(String queryText) {
            this.queryText = queryText;
            return this;
        }

        public Builder queryImage(String queryImage) {
            this.queryImage = queryImage;
            return this;
        }

        public Builder modelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        public Builder k(Integer k) {
            this.k = k;
            return this;
        }

        public Builder maxDistance(Float maxDistance) {
            this.maxDistance = maxDistance;
            return this;
        }

        public Builder minScore(Float minScore) {
            this.minScore = minScore;
            return this;
        }

        public Builder expandNested(Boolean expandNested) {
            this.expandNested = expandNested;
            return this;
        }

        public Builder vectorSupplier(Supplier<float[]> vectorSupplier) {
            this.vectorSupplier = vectorSupplier;
            return this;
        }

        public Builder filter(QueryBuilder filter) {
            this.filter = filter;
            return this;
        }

        public Builder methodParameters(Map<String, ?> methodParameters) {
            this.methodParameters = methodParameters;
            return this;
        }

        public Builder queryName(String queryName) {
            this.queryName = queryName;
            return this;
        }

        public Builder boost(float boost) {
            this.boost = boost;
            return this;
        }

        public Builder rescoreContext(RescoreContext rescoreContext) {
            this.rescoreContext = rescoreContext;
            return this;
        }

        /**
         * Validates the accumulated settings and creates the query.
         *
         * @return a new {@link NeuralQueryBuilder} with boost and query name applied
         * @throws IllegalArgumentException on missing text/image or more than one of k / max_distance / min_score
         */
        public NeuralQueryBuilder build() {
            validateQueryParameters(this.fieldName, this.queryText, this.queryImage);
            boolean queryTypeIsProvided = validateKNNQueryType(this.k, this.maxDistance, this.minScore);
            // Resolve the default into a local rather than writing it back into the builder. Otherwise a reused
            // builder would later fail the "only one of k, max_distance, min_score" check. An if-statement is used
            // instead of an Integer/int ternary, which would unbox a null k.
            Integer effectiveK = this.k;
            if (!queryTypeIsProvided) {
                effectiveK = DEFAULT_K;
            }
            return new NeuralQueryBuilder(
                this.fieldName,
                this.queryText,
                this.queryImage,
                this.modelId,
                effectiveK,
                this.maxDistance,
                this.minScore,
                this.expandNested,
                this.vectorSupplier,
                this.filter,
                this.methodParameters,
                this.rescoreContext
            ).boost(this.boost).queryName(this.queryName);
        }
    }
}

