一、概述

限流既可以基于客户端的访问 IP 进行,也可以针对接口的访问频率进行。主要的实现思路是:在 Redis 中为每个限流 key 记录该接口在一个时间窗口内的访问次数,并为该 key 设置过期时间(即时间窗口的长度);若访问次数超过设定的上限,则不再允许访问该接口,此时可以对页面进行重定向或返回提示信息。本文只简单地实现其中的核心逻辑。

  项目结构如下图所示:

【图:项目结构(springboot_ide)】

二、实现过程

2.1 项目依赖
<properties>
    <java.version>1.8</java.version>
    <hutool.version>5.3.4</hutool.version>
    <!-- fastjson releases before 1.2.83 have published autoType bypass vulnerabilities
         (e.g. CVE-2022-25845); this matters here because RedisConfig enables autoType support globally. -->
    <fastjson.version>1.2.83</fastjson.version>
    <commons-pool2.version>2.5.0</commons-pool2.version>
    <commons.codec>1.11</commons.codec>
</properties>

<dependencies>
    <!-- Spring Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>

    <!-- Spring AOP (the @Limit annotation is enforced by an @Aspect) -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-aop</artifactId>
    </dependency>

    <!-- Apache Commons utilities (StringUtils; DigestUtils for the cache key hash) -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
    </dependency>
    <dependency>
        <groupId>commons-codec</groupId>
        <artifactId>commons-codec</artifactId>
        <version>${commons.codec}</version>
    </dependency>

    <!-- Lombok (@Slf4j) -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>

    <!-- Hutool utility library -->
    <dependency>
        <groupId>cn.hutool</groupId>
        <artifactId>hutool-all</artifactId>
        <version>${hutool.version}</version>
    </dependency>

    <!-- commons-pool2, needed by the Spring Boot Redis integration -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-pool2</artifactId>
        <version>${commons-pool2.version}</version>
    </dependency>

    <!-- fastjson (Redis value serializer and cache key generation) -->
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>fastjson</artifactId>
        <version>${fastjson.version}</version>
    </dependency>

    <!-- Spring Data Redis -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>

    <!-- Guava (ImmutableList) -->
    <dependency>
        <groupId>com.google.guava</groupId>
        <artifactId>guava</artifactId>
        <version>29.0-jre</version>
    </dependency>
</dependencies>
2.2 限流枚举
/**
 * Dimension along which a {@code @Limit} counter is kept.
 */
public enum LimitType {
    /** One counter per resource key (the annotation default). */
    CUSTOMER,
    /** The counter key is additionally suffixed with the client IP address. */
    IP
}
2.3 限流注解
/**
 * Declares a rate limit on a method: at most {@link #count()} calls are admitted
 * per {@link #period()} seconds for each generated Redis counter key.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface Limit {

    /** Resource name, describing what the endpoint does. */
    String name() default "";

    /** Resource key used when building the Redis counter key. */
    String key() default "";

    /** Prefix prepended to the Redis counter key. */
    String prefix() default "";

    /** Length of the counting window, in seconds. */
    int period() default 30;

    /** Maximum number of calls allowed within {@link #period()} seconds. */
    int count() default 5;

    /** How the counter is partitioned. */
    LimitType limitType() default LimitType.CUSTOMER;
}
2.4 通用的请求响应实体
/**
 * Generic JSON response body backed by a {@link HashMap}.
 * {@link #put(String, Object)} returns this instance so entries can be chained.
 */
public class ResultDto extends HashMap<String, Object> {

    private static final long serialVersionUID = -4228527911405974961L;

    public ResultDto() {
    }

    /**
     * Builds an error response with code 500.
     *
     * @param msg error message stored under "msg"
     * @return the response map
     */
    public static ResultDto error(String msg) {
        final int internalError = 500;
        return error(internalError, msg);
    }

    /**
     * Builds an error response.
     *
     * @param code status code stored under "code"
     * @param msg  error message stored under "msg"
     * @return the response map
     */
    public static ResultDto error(int code, String msg) {
        return new ResultDto().put("code", code).put("msg", msg);
    }

    /**
     * Builds a success response.
     *
     * @param code  status code stored under "code"
     * @param value payload stored under "data"
     * @return the response map
     */
    public static ResultDto success(int code, Object value) {
        return new ResultDto().put("code", code).put("data", value);
    }

    /**
     * Stores the entry and returns this instance (not the previous value, unlike {@code Map.put}).
     */
    @Override
    public ResultDto put(String key, Object value) {
        super.put(key, value);
        return this;
    }
}
2.5 Redis 配置文件
@Slf4j
@Configuration
@EnableCaching
@ConditionalOnClass(RedisOperations.class)
@EnableConfigurationProperties(RedisProperties.class)
public class RedisConfig extends CachingConfigurerSupport {

    /**
     * Default configuration for Spring Cache ({@code @Cacheable}) entries stored in Redis:
     * values are serialized with fastjson and expire after 6 hours.
     */
    @Bean
    public RedisCacheConfiguration redisCacheConfiguration() {

        FastJsonRedisSerializer<Object> fastJsonRedisSerializer = new FastJsonRedisSerializer<>(Object.class);

        // Start from Spring Data Redis' default cache configuration
        RedisCacheConfiguration configuration = RedisCacheConfiguration.defaultCacheConfig();

        configuration = configuration
                .serializeValuesWith(RedisSerializationContext.SerializationPair.fromSerializer(fastJsonRedisSerializer)) // value serializer
                .entryTtl(Duration.ofHours(6)); // default TTL: 6 hours

        return configuration;
    }

    /**
     * RedisTemplate for the application: keys and hash keys use the key serializer declared in this
     * file, values and hash values use fastjson.
     */
    @Bean(name = "redisTemplate")
    @ConditionalOnMissingBean(name = "redisTemplate")
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<Object, Object> template = new RedisTemplate<>();

        // Value serializer
        FastJsonRedisSerializer<Object> fastJsonRedisSerializer = new FastJsonRedisSerializer<>(Object.class);

        // Values and hash values are serialized with fastjson
        template.setValueSerializer(fastJsonRedisSerializer);
        template.setHashValueSerializer(fastJsonRedisSerializer);

        // SECURITY(review): this turns on fastjson autoType for the global parser config, i.e. for every
        // fastjson parse in the JVM, not only Redis values. An "@type" field in parsed JSON then selects the
        // class fastjson instantiates, which is the vector of fastjson's known deserialization CVEs.
        // Kept to preserve existing behavior; prefer the package whitelist below and drop this line.
        ParserConfig.getGlobalInstance().setAutoTypeSupport(true);

        // Narrower alternative: accept only the packages whose types are actually cached, e.g.
        // ParserConfig.getGlobalInstance().addAccept("club.wadreamer.domain");

        // Keys and hash keys use the StringRedisSerializer declared in this file
        template.setKeySerializer(new StringRedisSerializer());
        template.setHashKeySerializer(new StringRedisSerializer());
        template.setConnectionFactory(redisConnectionFactory);

        return template;
    }

    /**
     * Default cache key generator, used whenever a cache annotation does not specify a key.
     * The key is the SHA-256 hex digest of a JSON document describing the target class, method name,
     * package and call arguments.
     */
    @Bean
    @Override
    public KeyGenerator keyGenerator() {
        return (target, method, params) -> {
            Map<String, Object> container = new HashMap<>(3);
            Class<?> targetClassClass = target.getClass();
            // Class descriptor (Class#toGenericString)
            container.put("class", targetClassClass.toGenericString());
            // Method name
            container.put("methodName", method.getName());
            // Package
            container.put("package", targetClassClass.getPackage());
            // Arguments, keyed by position
            for (int i = 0; i < params.length; i++) {
                container.put(String.valueOf(i), params[i]);
            }
            // Render as JSON
            String jsonString = JSON.toJSONString(container);
            // SHA-256 digest of the JSON is the cache key
            return DigestUtils.sha256Hex(jsonString);
        };
    }

    /**
     * Cache error handler: each callback only logs the failure and does not rethrow,
     * so Redis cache errors are not propagated to the caller.
     */
    @Bean
    @Override
    public CacheErrorHandler errorHandler() {
        log.info("初始化 -> [{}]", "Redis CacheErrorHandler");
        return new CacheErrorHandler() {
            @Override
            public void handleCacheGetError(RuntimeException e, Cache cache, Object key) {
                log.error("Redis occur handleCacheGetError:key -> [{}]", key, e);
            }

            @Override
            public void handleCachePutError(RuntimeException e, Cache cache, Object key, Object value) {
                log.error("Redis occur handleCachePutError:key -> [{}];value -> [{}]", key, value, e);
            }

            @Override
            public void handleCacheEvictError(RuntimeException e, Cache cache, Object key) {
                log.error("Redis occur handleCacheEvictError:key -> [{}]", key, e);
            }

            @Override
            public void handleCacheClearError(RuntimeException e, Cache cache) {
                log.error("Redis occur handleCacheClearError:", e);
            }
        };
    }

}

/**
 * Redis value serializer backed by fastjson. Values are stored as UTF-8 JSON carrying an
 * {@code @type} property (SerializerFeature.WriteClassName) so the stored type is recorded.
 */
class FastJsonRedisSerializer<T> implements RedisSerializer<T> {

    private final Class<T> targetType;

    FastJsonRedisSerializer(Class<T> clazz) {
        this.targetType = clazz;
    }

    /** Serializes {@code t} to UTF-8 JSON with type information; {@code null} becomes an empty array. */
    @Override
    public byte[] serialize(T t) {
        // Plain JSON is not self-describing, so WriteClassName adds an "@type" field naming the class
        return t == null
                ? new byte[0]
                : JSON.toJSONString(t, SerializerFeature.WriteClassName).getBytes(StandardCharsets.UTF_8);
    }

    /** Parses UTF-8 JSON into the configured type; {@code null} or empty input yields {@code null}. */
    @Override
    public T deserialize(byte[] bytes) {
        boolean nothingStored = bytes == null || bytes.length == 0;
        if (nothingStored) {
            return null;
        }
        String json = new String(bytes, StandardCharsets.UTF_8);
        return JSON.parseObject(json, targetType);
    }
}

/**
 * Redis key serializer, used by the RedisTemplate for keys and hash keys.
 * <p>
 * {@link String} keys are written verbatim in the configured charset. Any other key object is rendered
 * with fastjson and stripped of double quotes (e.g. an Integer {@code 5} becomes {@code 5}); a blank
 * rendering, or a {@code null} key, serializes to {@code null}.
 */
class StringRedisSerializer implements RedisSerializer<Object> {

    private final Charset charset;

    StringRedisSerializer() {
        this(StandardCharsets.UTF_8);
    }

    private StringRedisSerializer(Charset charset) {
        Assert.notNull(charset, "Charset must not be null!");
        this.charset = charset;
    }

    /** Decodes the raw key bytes with the configured charset; {@code null} stays {@code null}. */
    @Override
    public String deserialize(byte[] bytes) {
        return (bytes == null ? null : new String(bytes, charset));
    }

    @Override
    public byte[] serialize(Object object) {
        if (object == null) {
            // JSON.toJSONString(null) is the text "null", which would otherwise become a literal "null" key
            return null;
        }
        if (object instanceof String) {
            // Write String keys as-is: a JSON round trip escapes characters such as '\\', '"' or tabs,
            // and stripping the quotes afterwards would leave those escape backslashes inside the key.
            return ((String) object).getBytes(charset);
        }
        String string = JSON.toJSONString(object);
        if (StringUtils.isBlank(string)) {
            return null;
        }
        string = string.replace("\"", "");
        return string.getBytes(charset);
    }
}
2.6 IPUtils 工具类
/**
 * Resolves the client IP address of an HTTP request, consulting common proxy headers first.
 * <p>
 * SECURITY(review): unless a trusted reverse proxy overwrites them, all headers consulted here are
 * client-controlled, so a caller can choose the IP this class reports. Anything keyed on it (e.g. IP-based
 * rate limiting) can be bypassed by rotating a forged {@code X-Forwarded-For}. Only trust these headers
 * behind a proxy that sets them.
 */
public class IPUtils {
    private static final Logger logger = LoggerFactory.getLogger(IPUtils.class);
    private static final String IP_UTILS_FLAG = ",";
    private static final String UNKNOWN = "unknown";
    private static final String LOCALHOST_IP = "0:0:0:0:0:0:0:1";
    private static final String LOCALHOST_IP1 = "127.0.0.1";

    /**
     * Headers checked in order; the first one yielding a usable address wins.
     * X-Original-Forwarded-For comes first because, in the k8s setup this was written for, the real client
     * IP is put there while X-Forwarded-For carries the WAF's origin address.
     */
    private static final String[] IP_HEADERS = {
            "X-Original-Forwarded-For",
            "X-Forwarded-For",
            "x-forwarded-for",
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_CLIENT_IP",
            "HTTP_X_FORWARDED_FOR"
    };

    /**
     * Returns the client IP address of {@code request}.
     * <p>
     * Behind Nginx or another reverse proxy, {@code request.getRemoteAddr()} is the proxy's address, so the
     * proxy headers are consulted first. With several proxy hops a header holds a comma-separated list; the
     * first entry that is neither blank nor "unknown" is taken as the client IP. If no header yields an
     * address, {@code getRemoteAddr()} is used, with a loopback address replaced by this host's own address.
     *
     * @param request the current HTTP request
     * @return the resolved IP, or {@code null} if none could be determined (errors are logged, not thrown)
     */
    public static String getIpAddr(HttpServletRequest request) {
        String ip = null;
        try {
            for (String header : IP_HEADERS) {
                ip = firstUsableIp(request.getHeader(header));
                if (ip != null) {
                    return ip;
                }
            }

            // No proxy header: use the socket peer address (also covers k8s-internal calls)
            ip = request.getRemoteAddr();
            if (LOCALHOST_IP1.equalsIgnoreCase(ip) || LOCALHOST_IP.equalsIgnoreCase(ip)) {
                // Local request: report the address configured on this machine's network interface
                try {
                    ip = InetAddress.getLocalHost().getHostAddress();
                } catch (UnknownHostException e) {
                    // Keep the loopback address when the local host name cannot be resolved
                    logger.error("getClientIp error", e);
                }
            }
        } catch (Exception e) {
            logger.error("IPUtils ERROR ", e);
        }
        return ip;
    }

    /**
     * Returns the first comma-separated entry of {@code headerValue} that is non-blank and not "unknown"
     * (case-insensitive), trimmed; {@code null} if there is none.
     */
    private static String firstUsableIp(String headerValue) {
        if (StringUtils.isEmpty(headerValue)) {
            return null;
        }
        for (String candidate : headerValue.split(IP_UTILS_FLAG)) {
            String trimmed = candidate.trim();
            if (!trimmed.isEmpty() && !UNKNOWN.equalsIgnoreCase(trimmed)) {
                return trimmed;
            }
        }
        return null;
    }
}
2.7 限流核心逻辑
/**
 * Aspect enforcing {@link Limit}.
 * <p>
 * Each annotated call increments a Redis counter through an atomic Lua script. The call is rejected with
 * {@code ResultDto.error("接口访问次数受限")} once the counter exceeds {@code count} within the key's
 * {@code period}-second window.
 */
@Aspect
@Component
public class LimitAspect {

    /**
     * Check-and-increment script, built once because its text and result type never change.
     * Safe to share across requests.
     */
    private static final DefaultRedisScript<Number> LIMIT_SCRIPT =
            new DefaultRedisScript<>(buildLuaScript(), Number.class);

    private final RedisTemplate<Object, Object> redisTemplate;

    public LimitAspect(RedisTemplate<Object, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /** Matches every method annotated with {@link Limit}. */
    @Pointcut("@annotation(club.wadreamer.limit.annotation.Limit)")
    public void pointcut() {
    }

    /**
     * Counts the call against its Redis key and either proceeds or short-circuits.
     * <p>
     * Key layout: {@code prefix + "_" + key [+ "_" + clientIp] + requestURI with '/' replaced by '_'}.
     * {@code key} defaults to {@code "_" + methodName} when the annotation's key is empty, and the client
     * IP part is present only for {@link LimitType#IP}.
     *
     * @return the target method's result; or {@code ResultDto.error("接口访问次数受限")} when the limit is
     *         exceeded or the script returned no count
     * @throws Throwable whatever the target method throws
     */
    @Around("pointcut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {

        // Current HTTP request (fails fast when called outside a request context)
        HttpServletRequest httpServletRequest = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();

        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();
        Limit limit = method.getAnnotation(Limit.class);

        LimitType limitType = limit.limitType();

        String key = limit.key();
        if (StringUtils.isEmpty(key)) {
            key = "_" + methodSignature.getName();
        }
        if (limitType == LimitType.IP) {
            key += "_" + IPUtils.getIpAddr(httpServletRequest);
        }

        // Redis key of the counter
        ImmutableList<Object> keys = ImmutableList.of(StringUtils.join(limit.prefix(), "_", key,
                httpServletRequest.getRequestURI().replace('/', '_')));

        Number count = redisTemplate.execute(LIMIT_SCRIPT, keys, limit.count(), limit.period());

        if (Objects.nonNull(count) && count.intValue() <= limit.count()) {
            return joinPoint.proceed();
        }
        return ResultDto.error("接口访问次数受限");
    }

    /*
     * KEYS[1] is the counter key; ARGV = [limit.count(), limit.period()] (Lua arrays are 1-based).
     * If the counter already exceeds the limit it is returned unchanged, so a blocked client stops
     * incrementing. Otherwise it is INCRed, and the first increment (value 1) also sets the
     * ARGV[2]-second expiry.
     * GET yields a string; it is converted with tonumber so that both branches return an integer reply,
     * matching the script's declared Number result type.
     */
    private static String buildLuaScript() {
        return "local c" +
                "\nc = redis.call('get', KEYS[1])" +
                "\nif c and tonumber(c) > tonumber(ARGV[1]) then" +
                "\nreturn tonumber(c)" +
                "\nend" +
                "\nc = redis.call('incr', KEYS[1])" +
                "\nif tonumber(c) == 1 then" +
                "\nredis.call('expire', KEYS[1], ARGV[2])" +
                "\nend" +
                "\nreturn c";
    }
}
2.8 application.yml
server:
  port: 8080

spring:
  redis:
    # Each ${NAME:default} placeholder reads NAME from the environment/properties,
    # falling back to the value after the colon.
    # Redis database index
    database: ${REDIS_DB:1}
    host: ${REDIS_HOST:192.168.125.140}
    port: ${REDIS_PORT:6379}
    password: ${REDIS_PWD:}
    # Connection timeout (no unit given; presumably milliseconds - confirm)
    timeout: 5000

三、测试

  接下来编写一个测试接口,并在其上标注限流注解。

@RestController
@RequestMapping("/api/limit")
public class LimitController {

    /** Number of calls whose handler body actually ran, i.e. calls admitted by the rate limiter. */
    private static final AtomicInteger PASSED_CALLS = new AtomicInteger();

    /**
     * Rate-limit demo endpoint, limited per client IP to at most 10 calls every 60 seconds.
     * The Redis counter key has the form {@code limit_test_<client IP>_api_limit}.
     *
     * @return a success result whose data is the running count of admitted calls
     */
    @GetMapping
    @Limit(key = "test", period = 60, count = 10, name = "testLimit", prefix = "limit", limitType = LimitType.IP)
    public Object test() {
        int admitted = PASSED_CALLS.incrementAndGet();
        return ResultDto.success(200, admitted);
    }
}

  使用 Postman 进行测试,结果如下:

【图:Postman 测试结果(springboot_json_02)】


  在 Redis Desktop Manager 中查看到的数据如下:

【图:Redis Desktop Manager 中的限流 key(springboot_03)】