/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.collector;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.query.HybridQueryScorer;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
import org.opensearch.neuralsearch.search.lucene.MultiLeafFieldComparator;

public abstract class HybridTopFieldDocSortCollector
implements HybridSearchCollector {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridTopFieldDocSortCollector.class);
    private final int numHits;
    private final HitsThresholdChecker hitsThresholdChecker;
    private final Sort sort;
    @Nullable
    private FieldDoc after;
    private FieldComparator<?> firstComparator;
    @VisibleForTesting
    private FieldValueHitQueue.Entry[] fieldValueLeafTrackers;
    private int totalHits;
    protected int docBase;
    protected LeafFieldComparator[] comparators;
    private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
    protected int reverseMul;
    protected FieldValueHitQueue<FieldValueHitQueue.Entry>[] compoundScores;
    protected boolean[] queueFull;
    protected float maxScore = 0.0f;
    protected int[] collectedHits;
    private boolean needsInitialization = true;
    private Boolean searchSortPartOfIndexSort = null;
    private static final TopFieldDocs EMPTY_TOP_FIELD_DOCS = new TopFieldDocs(new TotalHits(0L, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0], new SortField[0]);

    HybridTopFieldDocSortCollector(int numHits, HitsThresholdChecker hitsThresholdChecker, Sort sort, FieldDoc after) {
        this.numHits = numHits;
        this.hitsThresholdChecker = hitsThresholdChecker;
        this.sort = sort;
        this.after = after;
    }

    public List<TopFieldDocs> topDocs() {
        if (this.compoundScores == null) {
            return new ArrayList<TopFieldDocs>();
        }
        ArrayList<TopFieldDocs> topFieldDocs = new ArrayList<TopFieldDocs>();
        for (int subQueryNumber = 0; subQueryNumber < this.compoundScores.length; ++subQueryNumber) {
            topFieldDocs.add(this.topDocsPerQuery(0, Math.min(this.collectedHits[subQueryNumber], this.compoundScores[subQueryNumber].size()), (PriorityQueue<FieldValueHitQueue.Entry>)this.compoundScores[subQueryNumber], this.collectedHits[subQueryNumber], this.sort.getSort()));
        }
        return topFieldDocs;
    }

    public ScoreMode scoreMode() {
        return this.hitsThresholdChecker.scoreMode();
    }

    private TopFieldDocs topDocsPerQuery(int start, int howMany, PriorityQueue<FieldValueHitQueue.Entry> pq, int totalHits, SortField[] sortFields) {
        if (howMany < 0) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Number of hits requested must be greater than 0 but value was %d", howMany));
        }
        if (start < 0) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Expected value of starting position is between 0 and %d, got %d", howMany, start));
        }
        if (start >= howMany || howMany == 0) {
            return EMPTY_TOP_FIELD_DOCS;
        }
        int size = howMany - start;
        ScoreDoc[] results = new ScoreDoc[size];
        this.populateResults(results, size, pq);
        return new TopFieldDocs(new TotalHits((long)totalHits, this.totalHitsRelation), results, sortFields);
    }

    private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue<FieldValueHitQueue.Entry> pq) {
        FieldValueHitQueue queue = (FieldValueHitQueue)pq;
        for (int i = howMany - 1; i >= 0 && pq.size() > 0; --i) {
            if (i >= results.length) continue;
            FieldValueHitQueue.Entry entry = (FieldValueHitQueue.Entry)queue.pop();
            int n = queue.getComparators().length;
            Object[] fields = new Object[n];
            for (int j = 0; j < n; ++j) {
                fields[j] = queue.getComparators()[j].value(entry.slot);
            }
            results[i] = new FieldDoc(entry.doc, entry.score, fields);
        }
    }

    private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryNumber, float score) {
        FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, this.docBase + doc);
        bottomEntry.score = score;
        this.fieldValueLeafTrackers[subQueryNumber] = (FieldValueHitQueue.Entry)compoundScore.add((Object)bottomEntry);
        assert (slot < this.numHits);
        boolean isQueueFull = false;
        if (slot == this.numHits - 1) {
            isQueueFull = true;
        }
        this.queueFull[subQueryNumber] = isQueueFull;
    }

    private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryIndex) {
        this.fieldValueLeafTrackers[subQueryIndex].doc = this.docBase + doc;
        this.fieldValueLeafTrackers[subQueryIndex] = (FieldValueHitQueue.Entry)compoundScore.updateTop();
    }

    private boolean canEarlyTerminate(Sort searchSort, Sort indexSort) {
        return this.canEarlyTerminateOnDocId(searchSort) || this.canEarlyTerminateOnPrefix(searchSort, indexSort);
    }

    private boolean canEarlyTerminateOnDocId(Sort searchSort) {
        SortField[] fields1 = searchSort.getSort();
        return SortField.FIELD_DOC.equals((Object)fields1[0]);
    }

    private boolean canEarlyTerminateOnPrefix(Sort searchSort, Sort indexSort) {
        if (indexSort != null) {
            SortField[] indexSortField;
            SortField[] searchSortField = searchSort.getSort();
            if (searchSortField.length > (indexSortField = indexSort.getSort()).length) {
                return false;
            }
            for (int i = 0; i < searchSortField.length; ++i) {
                if (searchSortField[i].equals((Object)indexSortField[i])) continue;
                return false;
            }
            return true;
        }
        return false;
    }

    @Generated
    FieldValueHitQueue.Entry[] getFieldValueLeafTrackers() {
        return this.fieldValueLeafTrackers;
    }

    @Override
    @Generated
    public int getTotalHits() {
        return this.totalHits;
    }

    @Generated
    public TotalHits.Relation getTotalHitsRelation() {
        return this.totalHitsRelation;
    }

    @Generated
    public void setTotalHitsRelation(TotalHits.Relation totalHitsRelation) {
        this.totalHitsRelation = totalHitsRelation;
    }

    @Override
    @Generated
    public float getMaxScore() {
        return this.maxScore;
    }

    protected abstract class HybridTopDocSortLeafCollector
    implements LeafCollector {
        protected HybridQueryScorer compoundQueryScorer;
        private boolean collectedAllCompetitiveHits = false;
        private boolean initializeLeafComparatorsPerSegmentOnce = true;

        public void setScorer(Scorable scorer) throws IOException {
            if (scorer instanceof HybridQueryScorer) {
                log.debug("passed scorer is of type HybridQueryScorer, saving it for collecting documents and scores");
                this.compoundQueryScorer = (HybridQueryScorer)scorer;
            } else {
                this.compoundQueryScorer = this.getHybridQueryScorer(scorer);
                if (Objects.isNull((Object)this.compoundQueryScorer)) {
                    log.error(String.format(Locale.ROOT, "cannot find scorer of type HybridQueryScorer in a hierarchy of scorer %s", scorer));
                }
            }
        }

        private HybridQueryScorer getHybridQueryScorer(Scorable scorer) throws IOException {
            if (scorer == null) {
                return null;
            }
            if (scorer instanceof HybridQueryScorer) {
                return (HybridQueryScorer)scorer;
            }
            for (Scorable.ChildScorable childScorable : scorer.getChildren()) {
                HybridQueryScorer hybridQueryScorer = this.getHybridQueryScorer(childScorable.child());
                if (!Objects.nonNull((Object)hybridQueryScorer)) continue;
                log.debug(String.format(Locale.ROOT, "found hybrid query scorer, it's child of scorer %s", childScorable.child().getClass().getSimpleName()));
                return hybridQueryScorer;
            }
            return null;
        }

        protected void incrementTotalHitCount() throws IOException {
            ++HybridTopFieldDocSortCollector.this.totalHits;
            HybridTopFieldDocSortCollector.this.hitsThresholdChecker.incrementHitCount();
            if (!HybridTopFieldDocSortCollector.this.scoreMode().isExhaustive() && HybridTopFieldDocSortCollector.this.getTotalHitsRelation() == TotalHits.Relation.EQUAL_TO && HybridTopFieldDocSortCollector.this.hitsThresholdChecker.isThresholdReached()) {
                for (LeafFieldComparator comparator : HybridTopFieldDocSortCollector.this.comparators) {
                    comparator.setHitsThresholdReached();
                }
                HybridTopFieldDocSortCollector.this.setTotalHitsRelation(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
            }
        }

        protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float score) throws IOException {
            int slot = hitsCollected - 1;
            if (HybridTopFieldDocSortCollector.this.numHits > 0) {
                HybridTopFieldDocSortCollector.this.comparators[subQueryNumber].copy(slot, doc);
                HybridTopFieldDocSortCollector.this.add(slot, doc, HybridTopFieldDocSortCollector.this.compoundScores[subQueryNumber], subQueryNumber, score);
                if (HybridTopFieldDocSortCollector.this.queueFull[subQueryNumber]) {
                    HybridTopFieldDocSortCollector.this.comparators[subQueryNumber].setBottom(HybridTopFieldDocSortCollector.this.fieldValueLeafTrackers[subQueryNumber].slot);
                }
            } else {
                HybridTopFieldDocSortCollector.this.queueFull[subQueryNumber] = true;
            }
        }

        protected void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException {
            if (HybridTopFieldDocSortCollector.this.numHits > 0) {
                HybridTopFieldDocSortCollector.this.comparators[subQueryNumber].copy(HybridTopFieldDocSortCollector.this.fieldValueLeafTrackers[subQueryNumber].slot, doc);
                HybridTopFieldDocSortCollector.this.updateBottom(doc, HybridTopFieldDocSortCollector.this.compoundScores[subQueryNumber], subQueryNumber);
                HybridTopFieldDocSortCollector.this.comparators[subQueryNumber].setBottom(HybridTopFieldDocSortCollector.this.fieldValueLeafTrackers[subQueryNumber].slot);
            }
        }

        protected boolean thresholdCheck(int doc, int subQueryNumber) throws IOException {
            if (this.collectedAllCompetitiveHits || HybridTopFieldDocSortCollector.this.reverseMul * HybridTopFieldDocSortCollector.this.comparators[subQueryNumber].compareBottom(doc) <= 0) {
                if (HybridTopFieldDocSortCollector.this.searchSortPartOfIndexSort.booleanValue()) {
                    if (HybridTopFieldDocSortCollector.this.hitsThresholdChecker.isThresholdReached()) {
                        HybridTopFieldDocSortCollector.this.setTotalHitsRelation(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
                        log.info("Terminating collection as hits threshold is reached");
                        throw new CollectionTerminatedException();
                    }
                    this.collectedAllCompetitiveHits = true;
                }
                return true;
            }
            return false;
        }

        protected void initializePriorityQueuesWithComparators(LeafReaderContext context, int numberOfSubQueries) throws IOException {
            int i;
            if (HybridTopFieldDocSortCollector.this.needsInitialization) {
                HybridTopFieldDocSortCollector.this.compoundScores = new FieldValueHitQueue[numberOfSubQueries];
                HybridTopFieldDocSortCollector.this.comparators = new LeafFieldComparator[numberOfSubQueries];
                HybridTopFieldDocSortCollector.this.queueFull = new boolean[numberOfSubQueries];
                HybridTopFieldDocSortCollector.this.collectedHits = new int[numberOfSubQueries];
                for (i = 0; i < numberOfSubQueries; ++i) {
                    this.initializeLeafFieldComparators(context, i);
                }
                HybridTopFieldDocSortCollector.this.fieldValueLeafTrackers = new FieldValueHitQueue.Entry[numberOfSubQueries];
                HybridTopFieldDocSortCollector.this.needsInitialization = false;
            }
            if (this.initializeLeafComparatorsPerSegmentOnce) {
                for (i = 0; i < numberOfSubQueries; ++i) {
                    this.initializeComparators(context, i);
                }
                this.initializeLeafComparatorsPerSegmentOnce = false;
            }
        }

        private void initializeLeafFieldComparators(LeafReaderContext context, int subQueryNumber) throws IOException {
            HybridTopFieldDocSortCollector.this.compoundScores[subQueryNumber] = FieldValueHitQueue.create((SortField[])HybridTopFieldDocSortCollector.this.sort.getSort(), (int)HybridTopFieldDocSortCollector.this.numHits);
            HybridTopFieldDocSortCollector.this.firstComparator = HybridTopFieldDocSortCollector.this.compoundScores[subQueryNumber].getComparators()[0];
            if (HybridTopFieldDocSortCollector.this.compoundScores[subQueryNumber].getComparators().length == 1) {
                HybridTopFieldDocSortCollector.this.firstComparator.setSingleSort();
            }
            if (HybridTopFieldDocSortCollector.this.after != null) {
                this.setAfterFieldValueInFieldCompartor(subQueryNumber);
            }
        }

        private void initializeComparators(LeafReaderContext context, int subQueryNumber) throws IOException {
            if (HybridTopFieldDocSortCollector.this.searchSortPartOfIndexSort == null) {
                Sort indexSort = context.reader().getMetaData().sort();
                HybridTopFieldDocSortCollector.this.searchSortPartOfIndexSort = HybridTopFieldDocSortCollector.this.canEarlyTerminate(HybridTopFieldDocSortCollector.this.sort, indexSort);
                if (HybridTopFieldDocSortCollector.this.searchSortPartOfIndexSort.booleanValue()) {
                    HybridTopFieldDocSortCollector.this.firstComparator.disableSkipping();
                }
            }
            LeafFieldComparator[] leafFieldComparators = HybridTopFieldDocSortCollector.this.compoundScores[subQueryNumber].getComparators(context);
            int[] reverseMuls = HybridTopFieldDocSortCollector.this.compoundScores[subQueryNumber].getReverseMul();
            if (leafFieldComparators.length == 1) {
                HybridTopFieldDocSortCollector.this.reverseMul = reverseMuls[0];
                HybridTopFieldDocSortCollector.this.comparators[subQueryNumber] = leafFieldComparators[0];
            } else {
                HybridTopFieldDocSortCollector.this.reverseMul = 1;
                HybridTopFieldDocSortCollector.this.comparators[subQueryNumber] = new MultiLeafFieldComparator(leafFieldComparators, reverseMuls);
            }
            HybridTopFieldDocSortCollector.this.comparators[subQueryNumber].setScorer((Scorable)this.compoundQueryScorer);
        }

        private void setAfterFieldValueInFieldCompartor(int subQueryNumber) {
            FieldComparator[] fieldComparators = HybridTopFieldDocSortCollector.this.compoundScores[subQueryNumber].getComparators();
            for (int k = 0; k < fieldComparators.length; ++k) {
                FieldComparator fieldComparator = fieldComparators[k];
                fieldComparator.setTopValue(HybridTopFieldDocSortCollector.this.after.fields[k]);
            }
        }
    }
}

