package cn.com.duiba.spring.boot.starter.dsp.rateLimiter;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Aspect that enforces {@link RateLimit} on annotated methods by running a Redis Lua script.
 *
 * <p>The total limit declared on a method ({@code RateLimit.count()}) is split across
 * {@link #REDIS_NODE_NUM} Redis keys that differ only in their {@code {...}} hash tag. Successive
 * calls are assigned to those keys round-robin via a per-method in-process counter, and each key
 * is checked against its share of the total limit.
 */
@Aspect
@Component
public class LimitAspect {

    private static final Logger logger = LoggerFactory.getLogger(LimitAspect.class);

    /**
     * Hash-tag values appended to the base key, indexed by shard number. Only the first
     * {@link #REDIS_NODE_NUM} entries are used.
     * NOTE(review): presumably chosen so each tag maps to a different Redis Cluster node — confirm
     * against the cluster's slot layout.
     */
    private static final int[] REDIS_NODE_HASH_TAGS = {
            14, 18, 296, 346, 29, 21, 32, 122, 69, 65, 76, 197, 50, 58, 223, 47
    };

    /** Number of shards the limit of each method is split across; must not exceed the tag table. */
    private static final int REDIS_NODE_NUM = 4;

    /**
     * Per-method call counters used for round-robin shard selection, keyed by the common Redis key.
     * Entries are created lazily so any annotated method works, not just pre-registered ones.
     */
    private static final Map<String, AtomicLong> callCounters = new ConcurrentHashMap<>();

    @Resource(name = "redis03StringRedisTemplate")
    private StringRedisTemplate stringRedisTemplate;

    @Autowired
    private DefaultRedisScript<Long> redisLuaScript;

    @Autowired
    private RateLimitProperties rateLimitProperties;

    @Pointcut(value = "@annotation(cn.com.duiba.spring.boot.starter.dsp.rateLimiter.RateLimit)")
    public void rateLimitPointcut() {
        // Matches every method annotated with @RateLimit.
    }

    /**
     * Checks the rate limit for the intercepted method and proceeds only when the call is allowed.
     *
     * <p>When the global switch is off, or the resolved method carries no {@link RateLimit}, the
     * call proceeds unchecked.
     *
     * @param joinPoint the intercepted invocation
     * @return the intercepted method's result
     * @throws RuntimeException when the Lua script reports the shard's limit is exhausted
     * @throws Throwable whatever the intercepted method throws
     */
    @Around(value = "rateLimitPointcut()")
    public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
        if (!rateLimitProperties.isAdxRateLimitSwitch()) {
            return joinPoint.proceed();
        }

        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        RateLimit rateLimit = method.getAnnotation(RateLimit.class);

        if (rateLimit == null) {
            return joinPoint.proceed();
        }

        String commonRedisKey = "rate.limit:" + targetClass.getName() + "-"
                + method.getName() + "-" + rateLimit.key();

        // computeIfAbsent: a plain get() returned null (and NPE'd) for any key not registered up front.
        long callNumber = callCounters
                .computeIfAbsent(commonRedisKey, k -> new AtomicLong())
                .getAndIncrement();
        // floorMod keeps the shard index non-negative even after the counter wraps past Long.MAX_VALUE.
        int index = (int) Math.floorMod(callNumber, (long) REDIS_NODE_NUM);

        String redisKey = commonRedisKey + "{" + REDIS_NODE_HASH_TAGS[index] + "}";
        logger.debug("限流啦, redis key{}", redisKey);

        List<String> keys = Collections.singletonList(redisKey);

        int limitCount = shardLimit(rateLimit.count(), index);

        Long number = stringRedisTemplate.execute(redisLuaScript, keys,
                String.valueOf(limitCount), String.valueOf(rateLimit.time()));

        // NOTE(review): assumes the script returns the current in-window count, or 0 when over
        // the limit — confirm against the Lua source.
        if (number != null && number != 0 && number <= limitCount) {
            logger.debug("限流时间段内访问第：{} 次", number);
            return joinPoint.proceed();
        }

        // No shared business exception type is configured here; replace with one if available.
        throw new RuntimeException("已经到设置限流次数");
    }

    /**
     * Returns the share of {@code totalLimitCount} assigned to shard {@code index}.
     *
     * <p>The remainder of the division is handed out one extra unit each to the lowest-numbered
     * shards, so the shares sum to the total. A shard whose share is 0 rejects every call routed
     * to it.
     *
     * @param totalLimitCount the limit declared on the annotation
     * @param index shard number in {@code [0, REDIS_NODE_NUM)}
     * @return the per-shard limit
     */
    private static int shardLimit(int totalLimitCount, int index) {
        int base = totalLimitCount / REDIS_NODE_NUM;
        return index < totalLimitCount % REDIS_NODE_NUM ? base + 1 : base;
    }

}
