/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.codec.KNN990Codec;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateWriter;
import org.opensearch.knn.index.codec.KNN990Codec.NativeEngineFieldVectorsWriter;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

public class NativeEngines990KnnVectorsWriter
extends KnnVectorsWriter {
    @Generated
    private static final Logger log = LogManager.getLogger(NativeEngines990KnnVectorsWriter.class);
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class);
    private final SegmentWriteState segmentWriteState;
    private final FlatVectorsWriter flatVectorsWriter;
    private KNN990QuantizationStateWriter quantizationStateWriter;
    private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList();
    private boolean finished;
    private final Integer approximateThreshold;
    private final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory;

    public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter, Integer approximateThreshold, NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory) {
        this.segmentWriteState = segmentWriteState;
        this.flatVectorsWriter = flatVectorsWriter;
        this.approximateThreshold = approximateThreshold;
        this.nativeIndexBuildStrategyFactory = nativeIndexBuildStrategyFactory;
    }

    public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
        NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(fieldInfo, this.flatVectorsWriter.addField(fieldInfo), this.segmentWriteState.infoStream);
        this.fields.add(newField);
        return newField;
    }

    public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
        this.flatVectorsWriter.flush(maxDoc, sortMap);
        for (NativeEngineFieldVectorsWriter<?> field : this.fields) {
            FieldInfo fieldInfo = field.getFieldInfo();
            VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo);
            int totalLiveDocs = field.getVectors().size();
            if (totalLiveDocs == 0) {
                log.debug("[Flush] No live docs for field {}", (Object)fieldInfo.getName());
                continue;
            }
            Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = KNNVectorValuesFactory.getVectorValuesSupplier(vectorDataType, field.getFlatFieldVectorsWriter().getDocsWithFieldSet(), field.getVectors());
            QuantizationState quantizationState = this.train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
            if (quantizationState == null && this.shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
                log.debug("Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush", (Object)fieldInfo.name, (Object)totalLiveDocs, (Object)this.approximateThreshold);
                continue;
            }
            NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, this.segmentWriteState, quantizationState, this.nativeIndexBuildStrategyFactory);
            StopWatch stopWatch = new StopWatch().start();
            writer.flushIndex(knnVectorValuesSupplier, totalLiveDocs);
            long time_in_millis = stopWatch.stop().totalTime().millis();
            KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
            log.debug("Flush took {} ms for vector field [{}]", (Object)time_in_millis, (Object)fieldInfo.getName());
        }
    }

    public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
        this.flatVectorsWriter.mergeOneField(fieldInfo, mergeState);
        VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo);
        Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge(vectorDataType, fieldInfo, mergeState);
        int totalLiveDocs = this.getLiveDocs(knnVectorValuesSupplier.get());
        if (totalLiveDocs == 0) {
            log.debug("[Merge] No live docs for field {}", (Object)fieldInfo.getName());
            return;
        }
        QuantizationState quantizationState = this.train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
        if (quantizationState == null && this.shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
            log.debug("Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge", (Object)fieldInfo.name, (Object)totalLiveDocs, (Object)this.approximateThreshold);
            return;
        }
        NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, this.segmentWriteState, quantizationState, this.nativeIndexBuildStrategyFactory);
        StopWatch stopWatch = new StopWatch().start();
        writer.mergeIndex(knnVectorValuesSupplier, totalLiveDocs);
        long time_in_millis = stopWatch.stop().totalTime().millis();
        KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
        log.debug("Merge took {} ms for vector field [{}]", (Object)time_in_millis, (Object)fieldInfo.getName());
    }

    public void finish() throws IOException {
        if (this.finished) {
            throw new IllegalStateException("NativeEnginesKNNVectorsWriter is already finished");
        }
        this.finished = true;
        if (this.quantizationStateWriter != null) {
            this.quantizationStateWriter.writeFooter();
        }
        this.flatVectorsWriter.finish();
    }

    public void close() throws IOException {
        if (this.quantizationStateWriter != null) {
            this.quantizationStateWriter.closeOutput();
        }
        IOUtils.close((Closeable[])new Closeable[]{this.flatVectorsWriter});
    }

    public long ramBytesUsed() {
        return SHALLOW_SIZE + this.flatVectorsWriter.ramBytesUsed() + this.fields.stream().mapToLong(NativeEngineFieldVectorsWriter::ramBytesUsed).sum();
    }

    private QuantizationState train(FieldInfo fieldInfo, Supplier<KNNVectorValues<?>> knnVectorValuesSupplier, int totalLiveDocs) throws IOException {
        QuantizationService quantizationService = QuantizationService.getInstance();
        QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
        QuantizationState quantizationState = null;
        if (quantizationParams != null && totalLiveDocs > 0) {
            this.initQuantizationStateWriterIfNecessary();
            KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
            quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
            this.quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
        }
        return quantizationState;
    }

    private int getLiveDocs(KNNVectorValues<?> vectorValues) throws IOException {
        int liveDocs = 0;
        while (vectorValues.nextDoc() != Integer.MAX_VALUE) {
            ++liveDocs;
        }
        return liveDocs;
    }

    private void initQuantizationStateWriterIfNecessary() throws IOException {
        if (this.quantizationStateWriter == null) {
            this.quantizationStateWriter = new KNN990QuantizationStateWriter(this.segmentWriteState);
            this.quantizationStateWriter.writeHeader(this.segmentWriteState);
        }
    }

    private boolean shouldSkipBuildingVectorDataStructure(long docCount) {
        if (this.approximateThreshold < 0) {
            return true;
        }
        return docCount < (long)this.approximateThreshold.intValue();
    }
}

