/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.memory;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheStats;
import com.google.common.cache.RemovalCause;
import com.google.common.cache.RemovalNotification;
import java.io.Closeable;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
import lombok.Generated;
import org.apache.commons.lang.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.knn.common.exception.OutOfNativeMemoryException;
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.plugin.stats.StatNames;
import org.opensearch.threadpool.Scheduler;
import org.opensearch.threadpool.ThreadPool;

/**
 * Manages the cache of native memory allocations used by the k-NN plugin.
 *
 * <p>A single shared instance is obtained through {@link #getInstance()}. Allocations are stored in a
 * Guava {@link Cache} whose weight limit (in KB) and access-expiry are derived from cluster settings
 * whenever the cache is built. When {@link KNNFeatureFlags#isForceEvictCacheEnabled()} is true,
 * {@link #get} additionally maintains an access-recency queue and, before loading a new entry, evicts
 * the least recently accessed entries itself until the new entry fits under the weight limit.
 */
public class NativeMemoryCacheManager implements Closeable {
    /** Stats key under which {@link #getIndicesCacheStats()} reports the number of graphs per index. */
    public static final String GRAPH_COUNT = "graph_count";

    private static final Logger logger = LogManager.getLogger(NativeMemoryCacheManager.class);
    private static NativeMemoryCacheManager INSTANCE;
    private static ThreadPool threadPool;

    // These fields are reassigned by initialize(NativeMemoryCacheManagerDto), which rebuildCache() runs on
    // the background executor thread. They are volatile so that other threads observe the rebuilt state,
    // and so that writes to the 64-bit maxWeight are atomic (JLS 17.7).
    private volatile Cache<String, NativeMemoryAllocation> cache;
    private volatile Deque<String> accessRecencyQueue;
    private volatile long maxWeight = Long.MAX_VALUE;
    private volatile Scheduler.Cancellable maintenanceTask;

    // One lock per entry key, used to serialize NativeMemoryEntryContext#open() calls for the same key.
    private final ConcurrentHashMap<String, ReentrantLock> indexLocks = new ConcurrentHashMap<>();
    // Single-threaded, so queued cache rebuilds run one after another.
    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    private final AtomicBoolean cacheCapacityReached = new AtomicBoolean(false);

    NativeMemoryCacheManager() {
        this.initialize();
    }

    /**
     * Returns the shared manager, creating it (and its cache) on first use.
     *
     * @return the singleton instance
     */
    public static synchronized NativeMemoryCacheManager getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new NativeMemoryCacheManager();
        }
        return INSTANCE;
    }

    /** Builds the initial cache from the current cluster settings. */
    private void initialize() {
        this.initialize(
            NativeMemoryCacheManagerDto.builder()
                .isWeightLimited((Boolean) KNNSettings.state().getSettingValue("knn.memory.circuit_breaker.enabled"))
                .maxWeight(KNNSettings.getClusterCbLimit().getKb())
                .isExpirationLimited((Boolean) KNNSettings.state().getSettingValue("knn.cache.item.expiry.enabled"))
                .expiryTimeInMin(((TimeValue) KNNSettings.state().getSettingValue("knn.cache.item.expiry.minutes")).getMinutes())
                .build()
        );
    }

    /**
     * (Re)creates the cache, the recency queue and the capacity flag from the given configuration, and
     * (re)schedules periodic cache maintenance when a thread pool has been set.
     *
     * @param nativeMemoryCacheDTO weight-limit and expiry configuration for the new cache
     */
    private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) {
        CacheBuilder<String, NativeMemoryAllocation> cacheBuilder = CacheBuilder.newBuilder()
            .recordStats()
            .concurrencyLevel(1)
            .removalListener(this::onRemoval);
        if (nativeMemoryCacheDTO.isWeightLimited()) {
            this.maxWeight = nativeMemoryCacheDTO.getMaxWeight();
            cacheBuilder.maximumWeight(this.maxWeight).weigher((k, v) -> v.getSizeInKB());
        } else {
            // Reset explicitly: otherwise a rebuild that disables the weight limit would keep enforcing the
            // previous limit in get() and keep reporting it from getMaxCacheSizeInKilobytes().
            this.maxWeight = Long.MAX_VALUE;
        }
        if (nativeMemoryCacheDTO.isExpirationLimited()) {
            cacheBuilder.expireAfterAccess(nativeMemoryCacheDTO.getExpiryTimeInMin(), TimeUnit.MINUTES);
        }
        this.cacheCapacityReached.set(false);
        this.accessRecencyQueue = new ConcurrentLinkedDeque<>();
        Cache<String, NativeMemoryAllocation> newCache = cacheBuilder.build();
        this.cache = newCache;
        if (threadPool != null) {
            this.startMaintenance(newCache);
        } else {
            logger.warn("ThreadPool is null during NativeMemoryCacheManager initialization. Maintenance will not start.");
        }
    }

    /**
     * Rebuilds the cache from the current cluster settings. The rebuild runs asynchronously, see
     * {@link #rebuildCache(NativeMemoryCacheManagerDto)}.
     */
    public synchronized void rebuildCache() {
        // NOTE(review): the limit here comes from KNNSettings.state().getCircuitBreakerLimit(), whereas
        // initialize() uses KNNSettings.getClusterCbLimit(); confirm the difference is intended.
        this.rebuildCache(
            NativeMemoryCacheManagerDto.builder()
                .isWeightLimited((Boolean) KNNSettings.state().getSettingValue("knn.memory.circuit_breaker.enabled"))
                .maxWeight(KNNSettings.state().getCircuitBreakerLimit().getKb())
                .isExpirationLimited((Boolean) KNNSettings.state().getSettingValue("knn.cache.item.expiry.enabled"))
                .expiryTimeInMin(((TimeValue) KNNSettings.state().getSettingValue("knn.cache.item.expiry.minutes")).getMinutes())
                .build()
        );
    }

    /**
     * Schedules, on the single background executor, invalidation of every current entry followed by
     * creation of a new cache from the given configuration. Returns without waiting for the rebuild.
     *
     * @param nativeMemoryCacheDTO configuration for the rebuilt cache
     */
    public synchronized void rebuildCache(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) {
        logger.info("KNN Cache rebuilding.");
        this.executor.execute(() -> {
            this.cache.invalidateAll();
            this.initialize(nativeMemoryCacheDTO);
        });
    }

    /** Shuts down the rebuild executor and cancels the periodic maintenance task, if any. */
    @Override
    public void close() {
        this.executor.shutdown();
        Scheduler.Cancellable task = this.maintenanceTask;
        if (task != null) {
            task.cancel();
        }
    }

    /** @return the summed size in KB of every allocation currently in the cache */
    public long getCacheSizeInKilobytes() {
        return this.cache.asMap().values().stream().mapToLong(NativeMemoryAllocation::getSizeInKB).sum();
    }

    /** @return the cache size as a percentage of the circuit breaker limit (0 when the limit is 0) */
    public Float getCacheSizeAsPercentage() {
        return this.getSizeAsPercentage(this.getCacheSizeInKilobytes());
    }

    /** @return the summed size in KB of all cached {@link NativeMemoryAllocation.IndexAllocation}s */
    public long getIndicesSizeInKilobytes() {
        return this.cache.asMap()
            .values()
            .stream()
            .filter(allocation -> allocation instanceof NativeMemoryAllocation.IndexAllocation)
            .mapToLong(NativeMemoryAllocation::getSizeInKB)
            .sum();
    }

    /** @return the index allocations' size as a percentage of the circuit breaker limit */
    public Float getIndicesSizeAsPercentage() {
        return this.getSizeAsPercentage(this.getIndicesSizeInKilobytes());
    }

    /**
     * Returns the summed size of the cached index allocations that belong to one OpenSearch index.
     *
     * @param indexName OpenSearch index name; must not be null
     * @return size in KB (0 when nothing for the index is cached)
     * @throws IllegalArgumentException if {@code indexName} is null
     */
    public Long getIndexSizeInKilobytes(String indexName) {
        Validate.notNull(indexName, "Index name cannot be null");
        return this.cache.asMap()
            .values()
            .stream()
            .filter(allocation -> belongsToIndex(allocation, indexName))
            .mapToLong(NativeMemoryAllocation::getSizeInKB)
            .sum();
    }

    /**
     * @param indexName OpenSearch index name; must not be null
     * @return the index's cached size as a percentage of the circuit breaker limit
     * @throws IllegalArgumentException if {@code indexName} is null
     */
    public Float getIndexSizeAsPercentage(String indexName) {
        Validate.notNull(indexName, "Index name cannot be null");
        return this.getSizeAsPercentage(this.getIndexSizeInKilobytes(indexName));
    }

    /** @return the summed size in KB of cached training-data and anonymous allocations */
    public long getTrainingSizeInKilobytes() {
        return this.cache.asMap()
            .values()
            .stream()
            .filter(
                allocation -> allocation instanceof NativeMemoryAllocation.TrainingDataAllocation
                    || allocation instanceof NativeMemoryAllocation.AnonymousAllocation
            )
            .mapToLong(NativeMemoryAllocation::getSizeInKB)
            .sum();
    }

    /** @return the training allocations' size as a percentage of the circuit breaker limit */
    public Float getTrainingSizeAsPercentage() {
        return this.getSizeAsPercentage(this.getTrainingSizeInKilobytes());
    }

    /** @return the cache weight limit in KB, or {@link Long#MAX_VALUE} when the cache is not weight-limited */
    public long getMaxCacheSizeInKilobytes() {
        return this.maxWeight;
    }

    /**
     * @param indexName OpenSearch index name; must not be null
     * @return number of cached index allocations that belong to the index
     * @throws IllegalArgumentException if {@code indexName} is null
     */
    public int getIndexGraphCount(String indexName) {
        Validate.notNull(indexName, "Index name cannot be null");
        return (int) this.cache.asMap().values().stream().filter(allocation -> belongsToIndex(allocation, indexName)).count();
    }

    /** @return the Guava statistics recorded by the current cache */
    public CacheStats getCacheStats() {
        return this.cache.stats();
    }

    /**
     * Opens the entry context while holding the per-key lock, so that concurrent callers do not open
     * the same key at the same time. The lock is dropped from {@link #indexLocks} once no thread is
     * queued on it.
     */
    private void open(String key, NativeMemoryEntryContext<?> nativeMemoryEntryContext) {
        ReentrantLock indexFileLock = this.indexLocks.computeIfAbsent(key, k -> new ReentrantLock());
        // Acquire before the try block: if lock() itself failed, unlock() must not run.
        indexFileLock.lock();
        try {
            nativeMemoryEntryContext.open();
        } finally {
            indexFileLock.unlock();
            if (!indexFileLock.hasQueuedThreads()) {
                this.indexLocks.remove(key, indexFileLock);
            }
        }
    }

    /** Returns the cached allocation for {@code key} (or null) and marks it most recently accessed on a hit. */
    private NativeMemoryAllocation getFromCacheAndUpdateRecency(String key) {
        NativeMemoryAllocation result = this.cache.getIfPresent(key);
        if (result != null) {
            this.updateAccessRecency(key);
        }
        return result;
    }

    /** Moves {@code key} to the most-recently-accessed end of the recency queue. */
    private void updateAccessRecency(String key) {
        this.accessRecencyQueue.remove(key);
        this.accessRecencyQueue.addLast(key);
    }

    /**
     * Returns the allocation for the given entry context, loading it into the cache on a miss.
     *
     * <p>When {@code isAbleToTriggerEviction} is false and the entry is not already cached, the call fails
     * fast if loading the entry would bring the cache to or beyond its maximum weight. With the
     * force-evict feature flag enabled, least recently accessed entries are evicted synchronously until
     * the entry fits; otherwise eviction is left to the cache's own weight policy. On the loading paths
     * the entry context is closed once loading has finished.
     *
     * @param nativeMemoryEntryContext context identifying (and able to load) the entry
     * @param isAbleToTriggerEviction whether loading this entry may evict other entries
     * @return the cached or newly loaded allocation
     * @throws OutOfNativeMemoryException if {@code isAbleToTriggerEviction} is false and the entry would not fit
     * @throws ExecutionException if loading the entry through the cache fails
     */
    public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryContext, boolean isAbleToTriggerEviction)
        throws ExecutionException {
        if (!isAbleToTriggerEviction) {
            this.ensureEntryFitsWithoutEviction(nativeMemoryEntryContext);
        }
        if (KNNFeatureFlags.isForceEvictCacheEnabled()) {
            return this.getWithForceEviction(nativeMemoryEntryContext);
        }
        try (nativeMemoryEntryContext) {
            String key = nativeMemoryEntryContext.getKey();
            this.open(key, nativeMemoryEntryContext);
            return this.cache.get(key, nativeMemoryEntryContext::load);
        }
    }

    /** Throws if the (not yet cached) entry cannot be added without exceeding the weight limit. */
    private void ensureEntryFitsWithoutEviction(NativeMemoryEntryContext<?> nativeMemoryEntryContext) {
        // Cheap membership test first: a cached entry never needs room, so skip the O(n) size sum.
        if (this.cache.asMap().containsKey(nativeMemoryEntryContext.getKey())) {
            return;
        }
        long entrySizeKB = nativeMemoryEntryContext.calculateSizeInKB().intValue();
        long currentSizeKB = this.getCacheSizeInKilobytes();
        long limitKB = this.maxWeight;
        if (limitKB - currentSizeKB - entrySizeKB <= 0L) {
            throw new OutOfNativeMemoryException(
                "Entry cannot be loaded into cache because it would not fit. Entry size: "
                    + entrySizeKB
                    + " KB Current Cache Size: "
                    + currentSizeKB
                    + " KB Max Cache Size: "
                    + limitKB
            );
        }
    }

    /** Force-evict variant of {@link #get}: evicts least recently accessed entries before loading on a miss. */
    private NativeMemoryAllocation getWithForceEviction(NativeMemoryEntryContext<?> nativeMemoryEntryContext)
        throws ExecutionException {
        String key = nativeMemoryEntryContext.getKey();
        NativeMemoryAllocation result = this.getFromCacheAndUpdateRecency(key);
        if (result != null) {
            return result;
        }
        // Close the context once done, as the non-force path does; this also covers the case where
        // another thread loaded the entry while this one was opening it.
        try (nativeMemoryEntryContext) {
            this.open(key, nativeMemoryEntryContext);
            synchronized (this) {
                // Re-check: another thread may have loaded the entry while this one was opening it.
                result = this.getFromCacheAndUpdateRecency(key);
                if (result != null) {
                    return result;
                }
                this.evictUntilFits(nativeMemoryEntryContext.calculateSizeInKB().intValue());
                result = this.cache.get(key, nativeMemoryEntryContext::load);
                this.accessRecencyQueue.addLast(key);
                return result;
            }
        }
    }

    /**
     * Walks the recency queue from least to most recently accessed, closing and invalidating entries
     * until {@code entrySizeKB} more KB stays below the weight limit or the queue is exhausted.
     * Callers hold the monitor on {@code this}.
     */
    private void evictUntilFits(long entrySizeKB) {
        Iterator<String> lruIterator = this.accessRecencyQueue.iterator();
        while (lruIterator.hasNext() && this.getCacheSizeInKilobytes() + entrySizeKB >= this.maxWeight) {
            String keyToRemove = lruIterator.next();
            NativeMemoryAllocation allocationToRemove = this.cache.getIfPresent(keyToRemove);
            if (allocationToRemove != null) {
                allocationToRemove.close();
                this.cache.invalidate(keyToRemove);
            }
            lruIterator.remove();
        }
    }

    /**
     * @param indexName OpenSearch index name; must not be null
     * @return the first cached index allocation found for the index, if any
     * @throws IllegalArgumentException if {@code indexName} is null
     */
    public Optional<NativeMemoryAllocation> getIndexMemoryAllocation(String indexName) {
        Validate.notNull(indexName, "Index name cannot be null");
        return this.cache.asMap().values().stream().filter(allocation -> belongsToIndex(allocation, indexName)).findFirst();
    }

    /** Invalidates one entry; the removal listener closes its allocation. */
    public void invalidate(String key) {
        this.cache.invalidate(key);
    }

    /** Invalidates every entry; the removal listener closes each allocation. */
    public void invalidateAll() {
        this.cache.invalidateAll();
    }

    /** @return whether a size-based eviction has occurred since the flag was last reset */
    public Boolean isCacheCapacityReached() {
        return this.cacheCapacityReached.get();
    }

    /** @param value new value for the capacity-reached flag; must not be null */
    public void setCacheCapacityReached(Boolean value) {
        this.cacheCapacityReached.set(value);
    }

    /**
     * Builds per-index statistics for every index that has at least one cached index allocation.
     *
     * @return map from index name to {graph count, memory usage in KB, memory usage percentage}
     */
    public Map<String, Map<String, Object>> getIndicesCacheStats() {
        Map<String, Map<String, Object>> statValues = new HashMap<>();
        for (NativeMemoryAllocation allocation : this.cache.asMap().values()) {
            if (!(allocation instanceof NativeMemoryAllocation.IndexAllocation)) {
                continue;
            }
            String indexName = ((NativeMemoryAllocation.IndexAllocation) allocation).getOpenSearchIndexName();
            Map<String, Object> indexMap = statValues.computeIfAbsent(indexName, name -> new HashMap<>());
            indexMap.computeIfAbsent(GRAPH_COUNT, key -> this.getIndexGraphCount(indexName));
            indexMap.computeIfAbsent(StatNames.GRAPH_MEMORY_USAGE.getName(), key -> this.getIndexSizeInKilobytes(indexName));
            indexMap.computeIfAbsent(
                StatNames.GRAPH_MEMORY_USAGE_PERCENTAGE.getName(),
                key -> this.getIndexSizeAsPercentage(indexName)
            );
        }
        return statValues;
    }

    /**
     * Removal listener: closes the removed allocation and, for size-based evictions, updates the
     * circuit breaker settings and raises the capacity-reached flag.
     */
    private void onRemoval(RemovalNotification<String, NativeMemoryAllocation> removalNotification) {
        NativeMemoryAllocation nativeMemoryAllocation = removalNotification.getValue();
        nativeMemoryAllocation.close();
        if (RemovalCause.SIZE == removalNotification.getCause()) {
            KNNSettings.state().updateCircuitBreakerSettings(true);
            this.setCacheCapacityReached(true);
        }
        logger.debug("[KNN] Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), removalNotification.getCause());
    }

    /** Expresses {@code size} (KB) as a percentage of the circuit breaker limit; 0 when the limit is 0. */
    private Float getSizeAsPercentage(long size) {
        long cbLimit = KNNSettings.state().getCircuitBreakerLimit().getKb();
        if (cbLimit == 0L) {
            return 0.0f;
        }
        return (float) (100L * size) / (float) cbLimit;
    }

    /** True when {@code allocation} is an index allocation belonging to the OpenSearch index {@code indexName}. */
    private static boolean belongsToIndex(NativeMemoryAllocation allocation, String indexName) {
        return allocation instanceof NativeMemoryAllocation.IndexAllocation
            && indexName.equals(((NativeMemoryAllocation.IndexAllocation) allocation).getOpenSearchIndexName());
    }

    /**
     * Cancels any previous maintenance task and schedules periodic {@link Cache#cleanUp()} of the given
     * cache on the "management" thread pool. Errors during cleanup are logged, not propagated.
     */
    private void startMaintenance(Cache<String, NativeMemoryAllocation> cacheInstance) {
        Scheduler.Cancellable previousTask = this.maintenanceTask;
        if (previousTask != null) {
            previousTask.cancel();
        }
        Runnable cleanUp = () -> {
            try {
                cacheInstance.cleanUp();
            } catch (Exception e) {
                logger.error("Error cleaning up cache", e);
            }
        };
        // NOTE(review): the cleanup interval reuses the cache item expiry setting.
        TimeValue interval = (TimeValue) KNNSettings.state().getSettingValue("knn.cache.item.expiry.minutes");
        this.maintenanceTask = threadPool.scheduleWithFixedDelay(cleanUp, interval, "management");
    }

    /** Sets the thread pool used to schedule cache maintenance for caches built afterwards. */
    @Generated
    public static void setThreadPool(ThreadPool threadPool) {
        NativeMemoryCacheManager.threadPool = threadPool;
    }

    /** @return the currently scheduled maintenance task, or null if none was scheduled */
    @Generated
    public Scheduler.Cancellable getMaintenanceTask() {
        return this.maintenanceTask;
    }
}

