package cn.com.duiba.spring.boot.starter.dsp.model.service.impl;

import cn.com.duiba.nezha.alg.model.tf.LocalTFModel;
import cn.com.duiba.spring.boot.starter.dsp.model.service.AlgoTFModelFactory;
import cn.com.duiba.spring.boot.starter.dsp.model.service.AlgoTFModelProxy;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
@Component
public class AlgoTFModelFactoryImpl implements AlgoTFModelFactory {

    private final AtomicInteger atomicInteger = new AtomicInteger(0);

    // 使用了享元模式，防止大量同一对象的创建，消耗大量内存空间
    private final Map<String, AlgoTFModelProxy> proxyMap = new ConcurrentHashMap<>();

    @Override
    public LocalTFModel getTFModel(String tfKey) {

        if (proxyMap.containsKey(tfKey)) {
            AlgoTFModelProxy algoTFModelProxy = proxyMap.get(tfKey);
            return algoTFModelProxy.chooseTFModel();
        }

        if (atomicInteger.compareAndSet(0, 1)) {
            AlgoTFModelProxy proxy;
            try {
                proxy = new AlgoTFModelProxyImpl(tfKey);
                proxyMap.putIfAbsent(tfKey, proxy);
            } catch (Exception e) {
                log.warn("AlgoTFModelProxy init error", e);
                atomicInteger.set(0);
                return null;
            }
            atomicInteger.set(0);
            return proxy.chooseTFModel();
        }
        return null;
    }

    @Scheduled(fixedDelay = 2 * 60 * 1000)
    void updateTFModelTask() {

        if (MapUtils.isEmpty(proxyMap)) {
            return;
        }

        closeTFModels();
        if (isExistLoadingTFModels()) {
            return;
        }
        updateTFModels();
    }

    /**
     * 关闭TF模型：关闭超过2分钟没有被访问的tf模型
     */
    private void closeTFModels() {

        for (Map.Entry<String, AlgoTFModelProxy> entry : proxyMap.entrySet()) {
            AlgoTFModelProxy algoTFModelProxy = entry.getValue();
            if (Objects.isNull(algoTFModelProxy)) {
                continue;
            }
            algoTFModelProxy.closeTFModel();
            if (algoTFModelProxy.needFlush()) {
                proxyMap.remove(entry.getKey());
            }
        }
    }

    private boolean isExistLoadingTFModels() {
        for (AlgoTFModelProxy algoTFModelProxy : proxyMap.values()) {
            if (Objects.isNull(algoTFModelProxy)) {
                continue;
            }

            if (algoTFModelProxy.hasTwoRunningModel()){
                return true;
            }
        }
        return false;
    }

    /**
     * 更新tf模型
     */
    private void updateTFModels() {
        List<AlgoTFModelProxy> algoTFModelProxies = new ArrayList<>(proxyMap.values());
        algoTFModelProxies.sort(Comparator.comparingLong(x -> ((AlgoTFModelProxyImpl) x).getTfModelUpdateTime()));
        for (AlgoTFModelProxy algoTFModelProxy : algoTFModelProxies) {
            if (Objects.isNull(algoTFModelProxy)) {
                continue;
            }

            if (algoTFModelProxy.updateTFModel()) {
                return;
            }
        }


    }

}
