redis使用lua操作json redis lua java_限流

  1. 定义一个注解,设置限流属性(时间窗,次数,限流类型,存入redis的前缀)
  2. lua脚本,每访问一次,redis中key的值 incr,,在第一次的时候给key设置过期时间,最后将key的值返回给java判断

问题:

  • 获取request
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest()
  • 获取包名
method.getDeclaringClass().getName()
  • Collections.singletonList被限定只被分配一个内存空间,也就是只能存放一个元素的内容。
  • 写自己的redisTemplate
/*
    * jdk序列化方案,存的时候会 一堆前缀
    * redis提供的 对象的redisTemplate,序列化,会有很多前缀,,
    * 自己写一个redisTemplate,更改他的序列化方式
    * */
    @Bean
    RedisTemplate<Object,Object> redisTemplate(RedisConnectionFactory redisConnectionFactory){
        RedisTemplate<Object, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(redisConnectionFactory);
        // 设置序列化器   : 使用jackson的序列化器
        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        template.setKeySerializer(serializer);
        template.setHashKeySerializer(serializer);
        template.setValueSerializer(serializer);
        template.setHashValueSerializer(serializer);
        return template;
    }
  • redis定义脚本
@Bean
    DefaultRedisScript<Long>  limitScript(){
        DefaultRedisScript<Long> script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        // 设置脚本位置
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        return script;
    }
  • 执行redis脚本
Long number = (Long) redisTemplate.execute(redisScript, Collections.singletonList(combineKey), time, count);

代码:
lua脚本

local key = KEYS[1]
local time = tonumber(ARGV[1])
-- 限流次数
local count = tonumber(ARGV[2])

-- 获取redis中的key ,,, 如果接口没有调用过,,有可能没有值,
local current = redis.call('get',key)

-- current超过限流了
if current and tonumber(current)>count then
    return tonumber(current)
end

-- 第一次访问  ,,自增1  。。 并发的时候,另外一个线程可能也执行了自增1
current =  redis.call("incr",key)

-- 没有其他线程
if tonumber(current)==1 then
    redis.call("expire",key,time)
end

-- 不等于1,,其他线程已经设置过期时间
return tonumber(current)


-- 返回current,,在java中判断是限流还是放行

自定义异常:

public class RateLimitException extends Exception {
    public RateLimitException(String message) {
        super(message);
    }
}
@RestControllerAdvice
public class GlobalException {

    @ExceptionHandler(RateLimitException.class)
    public Map<String,Object> rateLimitException(RateLimitException e){
        HashMap<String, Object> map = new HashMap<>();
        map.put("status",500);
        map.put("message",e.getMessage());
        return map;
    }
}

限流类型:

public enum LimitType {
    /**
     * 默认限流策略 , 针对某个接口进行限流,,某个接口在一定时间只能访问多少次
     */
    DEFAULT,
    /**
     * 根据IP限流,,某个IP在一定时间只能访问多少次
     */
    IP
}

限流注解:

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {
    /**
     * 限流 在redis中的 前缀
     */
    String key() default "rate_limit";


    /**
     * 限流时间窗
     */
    int time() default 60;

    /**
     * 在时间窗内的 限流次数
     */
    int count() default 100;

    /**
     * 限流类型
     */
    LimitType limitType() default LimitType.DEFAULT;

}

redis配置文件

@Configuration
public class RedisConfig {

    /*
    * jdk序列化方案,存的时候会 一堆前缀
    * redis提供的 对象的redisTemplate,序列化,会有很多前缀,,
    * 自己写一个redisTemplate,更改他的序列化方式
    * */
    @Bean
    RedisTemplate<Object,Object> redisTemplate(RedisConnectionFactory redisConnectionFactory){
        RedisTemplate<Object, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(redisConnectionFactory);
        // 设置序列化器   : 使用jackson的序列化器
        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        template.setKeySerializer(serializer);
        template.setHashKeySerializer(serializer);
        template.setValueSerializer(serializer);
        template.setHashValueSerializer(serializer);
        return template;
    }

    @Bean
    DefaultRedisScript<Long>  limitScript(){
        DefaultRedisScript<Long> script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        // 设置脚本位置
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        return script;
    }
}

ip工具类: 获取访问的ip

/**
 * 获取用户访问ip地址
 */
public class IpUtils {
    public static String getIpAddress(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }


        // 获取到多个ip时取第一个作为客户端真实ip
        if(ip !=null && ip.length() != 0 && ip.contains(",")){
            String[] ipArray = ip.split(",");
            ip = ipArray[0];
        }

        return ip;
    }
    
}

切面

/**
     * 拦截 rateLimiter注解   限流的话,还没进入方法就拦截
     * @param jp
     * @param rateLimiter
     */
    @Before("@annotation(rateLimiter)")
    public void before(JoinPoint jp, RateLimiter rateLimiter) throws RateLimitException {

        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter,jp);

        try {
            // Collections.singletonList被限定只被分配一个内存空间,也就是只能存放一个元素的内容。
            Long number = (Long) redisTemplate.execute(redisScript, Collections.singletonList(combineKey), time, count);
            // number 什么情况下为 null
            if(number == null || number.intValue() > count){
                logger.info("当前接口已达到最大限流次数");
                throw  new RateLimitException("访问频繁,请稍后访问");
            }
            logger.info("一个时间窗内请求次数:{},当前请求次数:{} ,缓存key为:{}",count,number,combineKey);
        } catch (Exception e) {

           throw  e;
        }

    }

    private String getCombineKey(RateLimiter rateLimiter, JoinPoint jp) {
        StringBuilder sb = new StringBuilder(rateLimiter.key());

        LimitType limitType = rateLimiter.limitType();
        // 拼接 ip
        if(limitType == LimitType.IP){
            // 获取ip
            HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
            sb.append(IpUtils.getIpAddress(request)).append("-");
        }

        MethodSignature signature = (MethodSignature) jp.getSignature();
        Method method = signature.getMethod();
        sb.append(method.getDeclaringClass().getName()).append("-").append(method.getName());
        return sb.toString();
    }

测试

@RestController
public class HelloController {
    /**
     * 10s 内,访问 3 次
     * @return
     */
    @RateLimiter(time = 10,count = 3)
    @GetMapping("/hello")
    public String hello(HttpServletRequest request){
        String ip = IpUtils.getIpAddress(request);
        return "hello  "+ip;
    }
}