/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.training;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang.ArrayUtils;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import org.opensearch.knn.training.TrainingDataConsumer;
import org.opensearch.search.SearchHit;

public class FloatTrainingDataConsumer
extends TrainingDataConsumer {
    private final QuantizationConfig quantizationConfig;

    public FloatTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) {
        super(trainingDataAllocation);
        this.quantizationConfig = trainingDataAllocation.getQuantizationConfig();
    }

    @Override
    public void accept(List<?> floats) {
        if (this.isValidFloatsAndQuantizationConfig(floats)) {
            try {
                List<byte[]> byteVectors = this.quantizeVectors(floats);
                long memoryAddress = this.trainingDataAllocation.getMemoryAddress();
                memoryAddress = JNICommons.storeBinaryVectorData(memoryAddress, (byte[][])byteVectors.toArray((T[])new byte[0][0]), byteVectors.size());
                this.trainingDataAllocation.setMemoryAddress(memoryAddress);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        } else {
            this.trainingDataAllocation.setMemoryAddress(JNICommons.storeVectorData(this.trainingDataAllocation.getMemoryAddress(), (float[][])floats.stream().map(v -> ArrayUtils.toPrimitive((Float[])((Float[])v))).toArray(x$0 -> new float[x$0][]), floats.size()));
        }
    }

    @Override
    public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
        SearchHit[] hits = searchResponse.getHits().getHits();
        ArrayList<Float[]> vectors = new ArrayList<Float[]>();
        String[] fieldPath = fieldName.split("\\.");
        for (int vector = 0; vector < vectorsToAdd; ++vector) {
            Object fieldValue = this.extractFieldValue(hits[vector], fieldPath);
            if (!(fieldValue instanceof List)) continue;
            List fieldList = (List)fieldValue;
            vectors.add((Float[])fieldList.stream().map(Number::floatValue).toArray(Float[]::new));
        }
        this.setTotalVectorsCountAdded(this.getTotalVectorsCountAdded() + vectors.size());
        this.accept(vectors);
    }

    private List<byte[]> quantizeVectors(final List<?> vectors) throws IOException {
        ArrayList<byte[]> bytes = new ArrayList<byte[]>();
        ScalarQuantizationParams quantizationParams = new ScalarQuantizationParams(this.quantizationConfig.getQuantizationType());
        Quantizer<float[], byte[]> quantizer = QuantizerFactory.getQuantizer(quantizationParams);
        TrainingRequest<float[]> trainingRequest = new TrainingRequest<float[]>(this, vectors.size()){

            @Override
            public float[] getVectorAtThePosition(int position) {
                return ArrayUtils.toPrimitive((Float[])((Float[])vectors.get(position)));
            }
        };
        QuantizationState quantizationState = quantizer.train(trainingRequest);
        BinaryQuantizationOutput binaryQuantizationOutput = new BinaryQuantizationOutput(this.quantizationConfig.getQuantizationType().getId());
        for (int i = 0; i < vectors.size(); ++i) {
            quantizer.quantize(ArrayUtils.toPrimitive((Float[])((Float[])vectors.get(i))), quantizationState, binaryQuantizationOutput);
            bytes.add(binaryQuantizationOutput.getQuantizedVectorCopy());
        }
        return bytes;
    }

    private boolean isValidFloatsAndQuantizationConfig(List<?> floats) {
        return floats != null && !floats.isEmpty() && this.quantizationConfig != null && this.quantizationConfig != QuantizationConfig.EMPTY;
    }
}

