/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.memoryoptsearch.faiss;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.opensearch.knn.memoryoptsearch.VectorSearcher;
import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSW;
import org.opensearch.knn.memoryoptsearch.faiss.FaissHnswGraph;
import org.opensearch.knn.memoryoptsearch.faiss.FaissIdMapIndex;
import org.opensearch.knn.memoryoptsearch.faiss.FaissIndex;

public class FaissMemoryOptimizedSearcher
implements VectorSearcher {
    private static final FlatVectorsScorer VECTOR_SCORER = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
    private final IndexInput indexInput;
    private final FaissIndex faissIndex;
    private final FaissHNSW hnsw;

    public FaissMemoryOptimizedSearcher(IndexInput indexInput) throws IOException {
        this.indexInput = indexInput;
        this.faissIndex = FaissIndex.load(indexInput);
        this.hnsw = FaissMemoryOptimizedSearcher.extractFaissHnsw(this.faissIndex);
    }

    private static FaissHNSW extractFaissHnsw(FaissIndex faissIndex) {
        if (faissIndex instanceof FaissIdMapIndex) {
            FaissIdMapIndex idMapIndex = (FaissIdMapIndex)faissIndex;
            return idMapIndex.getNestedIndex().getHnsw();
        }
        throw new IllegalArgumentException("Faiss index [" + faissIndex.getIndexType() + "] does not have HNSW as an index.");
    }

    @Override
    public void search(float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        this.search(VectorEncoding.FLOAT32, (IOSupplier<RandomVectorScorer>)((IOSupplier)() -> VECTOR_SCORER.getRandomVectorScorer(this.faissIndex.getVectorSimilarityFunction().getVectorSimilarityFunction(), (KnnVectorValues)this.faissIndex.getFloatValues(this.indexInput), target)), knnCollector, acceptDocs);
    }

    @Override
    public void search(byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        this.search(VectorEncoding.BYTE, (IOSupplier<RandomVectorScorer>)((IOSupplier)() -> VECTOR_SCORER.getRandomVectorScorer(this.faissIndex.getVectorSimilarityFunction().getVectorSimilarityFunction(), (KnnVectorValues)this.faissIndex.getByteValues(this.indexInput), target)), knnCollector, acceptDocs);
    }

    @Override
    public void close() throws IOException {
        this.indexInput.close();
    }

    private void search(VectorEncoding vectorEncoding, IOSupplier<RandomVectorScorer> scorerSupplier, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        if (this.faissIndex.getTotalNumberOfVectors() == 0 || knnCollector.k() == 0) {
            return;
        }
        if (this.faissIndex.getVectorEncoding() != vectorEncoding) {
            throw new IllegalArgumentException("Search for vector encoding [" + String.valueOf(vectorEncoding) + "] is not supported in an index vector whose encoding is [" + String.valueOf(this.faissIndex.getVectorEncoding()) + "]");
        }
        RandomVectorScorer scorer = (RandomVectorScorer)scorerSupplier.get();
        OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, arg_0 -> ((RandomVectorScorer)scorer).ordToDoc(arg_0));
        Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
        if (knnCollector.k() < scorer.maxOrd()) {
            HnswGraphSearcher.search((RandomVectorScorer)scorer, (KnnCollector)collector, (HnswGraph)new FaissHnswGraph(this.hnsw, this.indexInput), (Bits)acceptedOrds);
        } else {
            for (int i = 0; i < scorer.maxOrd(); ++i) {
                if (acceptedOrds != null && !acceptedOrds.get(i)) continue;
                if (knnCollector.earlyTerminated()) break;
                knnCollector.incVisitedCount(1);
                knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
            }
        }
    }
}

