1. Overview
Rate limiting can be applied per client IP, or per interface by how frequently an endpoint is called. The core idea is to keep, for each interface, a counter in Redis together with an expiry time on the key; once the counter exceeds the allowed number of calls within that window, further requests are rejected, for example by redirecting the page or returning a hint message. This article implements only the core logic in a simple form.
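As a rough sketch of that idea (the class and method names below are made up for illustration; the actual implementation later in the article moves this logic into a Lua script so that the increment and the expiry are applied atomically):
import java.util.concurrent.TimeUnit;
import org.springframework.data.redis.core.StringRedisTemplate;

// Hypothetical helper showing the fixed-window counter idea; not part of the project code
class NaiveRateLimiter {
    private final StringRedisTemplate redisTemplate;

    NaiveRateLimiter(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    boolean allowRequest(String key, int limit, int periodSeconds) {
        // INCR creates the key with value 1 on the first call and returns the new count
        Long count = redisTemplate.opsForValue().increment(key, 1);
        if (count != null && count == 1) {
            // First hit in this window: start the expiry clock
            redisTemplate.expire(key, periodSeconds, TimeUnit.SECONDS);
        }
        // Allow the request only while the counter is within the limit
        return count != null && count <= limit;
    }
}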
The project structure is shown in the figure below:
2. Implementation
2.1 Project dependencies
<properties>
<java.version>1.8</java.version>
<hutool.version>5.3.4</hutool.version>
<fastjson.version>1.2.60</fastjson.version>
<commons-pool2.version>2.5.0</commons-pool2.version>
<commons.codec>1.11</commons.codec>
</properties>
<dependencies>
<!-- Spring Web starter -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- AOP starter -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Apache Commons utilities -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>${commons.codec}</version>
</dependency>
<!-- Lombok -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- Hutool utilities -->
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>${hutool.version}</version>
</dependency>
<!-- commons-pool2, required by the Spring Boot Redis connection pool -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
<version>${commons-pool2.version}</version>
</dependency>
<!-- fastjson -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>${fastjson.version}</version>
</dependency>
<!-- Spring Data Redis starter -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Guava (provides ImmutableList) -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>29.0-jre</version>
</dependency>
</dependencies>
2.2 Rate-limit type enum
public enum LimitType {
// rate-limit by a custom key defined on the annotation
CUSTOMER,
// rate-limit by the client's IP address
IP;
2.3 Rate-limit annotation
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Limit {
// Resource name, describing what the interface does
String name() default "";
// Key that identifies the resource
String key() default "";
// Prefix of the key
String prefix() default "";
// Length of the time window, in seconds
int period() default 30;
// Maximum number of calls allowed within one period
int count() default 5;
// Rate-limit type (custom key or client IP)
LimitType limitType() default LimitType.CUSTOMER;
}
2.4 Generic response entity
public class ResultDto extends HashMap<String, Object> {
private static final long serialVersionUID = -4228527911405974961L;
public ResultDto() {
}
public static ResultDto error(String msg) {
return error(500, msg);
}
public static ResultDto error(int code, String msg) {
ResultDto json = new ResultDto();
json.put("code", code);
json.put("msg", msg);
return json;
}
public static ResultDto success(int code, Object value) {
ResultDto json = new ResultDto();
json.put("code", code);
json.put("data", value);
return json;
}
@Override
public ResultDto put(String key, Object value) {
super.put(key, value);
return this;
}
}
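Since ResultDto extends HashMap and put() returns this, a response can be extended fluently and serializes straight to JSON. A quick sketch of what the factory methods produce (the extra field below is made up purely for illustration):
// ResultDto.error("Too many requests to this interface") serializes to
// {"code":500,"msg":"Too many requests to this interface"}
ResultDto limited = ResultDto.error("Too many requests to this interface");
// put() returns the ResultDto itself, so extra fields can be chained
ResultDto ok = ResultDto.success(200, 1).put("traceId", "demo");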
2.5 Redis configuration
@Slf4j
@Configuration
@EnableCaching
@ConditionalOnClass(RedisOperations.class)
@EnableConfigurationProperties(RedisProperties.class)
public class RedisConfig extends CachingConfigurerSupport {
/*
* Default cache configuration: sets the serializer used for values cached via @Cacheable
* and a default TTL of 6 hours for cached entries
*/
@Bean
public RedisCacheConfiguration redisCacheConfiguration() {
FastJsonRedisSerializer<Object> fastJsonRedisSerializer = new FastJsonRedisSerializer<>(Object.class);
// Start from the default Redis cache configuration
RedisCacheConfiguration configuration = RedisCacheConfiguration.defaultCacheConfig();
configuration = configuration
.serializeValuesWith(RedisSerializationContext.SerializationPair.fromSerializer(fastJsonRedisSerializer)) // value serializer
.entryTtl(Duration.ofHours(6)); // default TTL for cache entries
return configuration;
}
/*
* RedisTemplate used to access Redis
*/
@SuppressWarnings("all")
@Bean(name = "redisTemplate")
@ConditionalOnMissingBean(name = "redisTemplate")
public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
RedisTemplate<Object, Object> template = new RedisTemplate<>();
// Serializer for values
FastJsonRedisSerializer<Object> fastJsonRedisSerializer = new FastJsonRedisSerializer<>(Object.class);
// Values (including hash values) are serialized with fastJsonRedisSerializer
template.setValueSerializer(fastJsonRedisSerializer);
template.setHashValueSerializer(fastJsonRedisSerializer);
// Enable fastjson AutoType globally; convenient during development
ParserConfig.getGlobalInstance().setAutoTypeSupport(true);
// Recommended alternative: whitelist only the packages that actually need it
// ParserConfig.getGlobalInstance().addAccept("club.wadreamer.domain");
// Keys (including hash keys) are serialized with StringRedisSerializer
template.setKeySerializer(new StringRedisSerializer());
template.setHashKeySerializer(new StringRedisSerializer());
template.setConnectionFactory(redisConnectionFactory);
return template;
}
/*
* Custom cache key generator; used by default when no key is specified on the caching annotation
*/
@Bean
@Override
public KeyGenerator keyGenerator() {
return (target, method, params) -> {
Map<String, Object> container = new HashMap<>(3);
Class<?> targetClassClass = target.getClass();
// Class signature
container.put("class", targetClassClass.toGenericString());
// Method name
container.put("methodName", method.getName());
// Package
container.put("package", targetClassClass.getPackage());
// Parameter list
for (int i = 0; i < params.length; i++) {
container.put(String.valueOf(i), params[i]);
}
// Convert the map to a JSON string
String jsonString = JSON.toJSONString(container);
// Use its SHA-256 digest as the cache key
return DigestUtils.sha256Hex(jsonString);
};
}
/*
* Error handling: when Redis fails, log the error but let the application continue
*/
@Bean
@Override
public CacheErrorHandler errorHandler() {
log.info("Initializing -> [{}]", "Redis CacheErrorHandler");
return new CacheErrorHandler() {
@Override
public void handleCacheGetError(RuntimeException e, Cache cache, Object key) {
log.error("Redis occur handleCacheGetError:key -> [{}]", key, e);
}
@Override
public void handleCachePutError(RuntimeException e, Cache cache, Object key, Object value) {
log.error("Redis occur handleCachePutError:key -> [{}];value -> [{}]", key, value, e);
}
@Override
public void handleCacheEvictError(RuntimeException e, Cache cache, Object key) {
log.error("Redis occur handleCacheEvictError:key -> [{}]", key, e);
}
@Override
public void handleCacheClearError(RuntimeException e, Cache cache) {
log.error("Redis occur handleCacheClearError:", e);
}
};
}
}
/*
* Value serializer based on fastjson
*/
class FastJsonRedisSerializer<T> implements RedisSerializer<T> {
private final Class<T> clazz;
FastJsonRedisSerializer(Class<T> clazz) {
super();
this.clazz = clazz;
}
@Override
public byte[] serialize(T t) {
if (t == null) {
return new byte[0];
}
// JSON itself carries no type information, so SerializerFeature.WriteClassName adds an @type field
// to the serialized string that records the concrete class, allowing it to be deserialized later
return JSON.toJSONString(t, SerializerFeature.WriteClassName).getBytes(StandardCharsets.UTF_8);
}
@Override
public T deserialize(byte[] bytes) {
if (bytes == null || bytes.length <= 0) {
return null;
}
String str = new String(bytes, StandardCharsets.UTF_8);
return JSON.parseObject(str, clazz);
}
}
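For reference, this is roughly what WriteClassName does to the stored JSON (the payload class below is made up purely for demonstration):
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature;

class WriteClassNameDemo {
    // Hypothetical payload class used only to show the @type field
    public static class User {
        private String name = "tom";
        public String getName() { return name; }
        public void setName(String name) { this.name = name; }
    }

    public static void main(String[] args) {
        // Prints something like: {"@type":"WriteClassNameDemo$User","name":"tom"}
        System.out.println(JSON.toJSONString(new User(), SerializerFeature.WriteClassName));
    }
}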
/*
* Key serializer
*/
class StringRedisSerializer implements RedisSerializer<Object> {
private final Charset charset;
StringRedisSerializer() {
this(StandardCharsets.UTF_8);
}
private StringRedisSerializer(Charset charset) {
Assert.notNull(charset, "Charset must not be null!");
this.charset = charset;
}
@Override
public String deserialize(byte[] bytes) {
return (bytes == null ? null : new String(bytes, charset));
}
@Override
public byte[] serialize(Object object) {
String string = JSON.toJSONString(object);
if (StringUtils.isBlank(string)) {
return null;
}
// Strip the surrounding quotes added by JSON serialization so plain string keys stay readable in Redis
string = string.replace("\"", "");
return string.getBytes(charset);
}
}
2.6 The IPUtils utility class
public class IPUtils {
private static Logger logger = LoggerFactory.getLogger(IPUtils.class);
private static final String IP_UTILS_FLAG = ",";
private static final String UNKNOWN = "unknown";
private static final String LOCALHOST_IP = "0:0:0:0:0:0:0:1";
private static final String LOCALHOST_IP1 = "127.0.0.1";
/**
* Get the client IP address.
* Behind a reverse proxy such as Nginx, request.getRemoteAddr() does not return the real client IP.
* With multiple layers of proxies, X-Forwarded-For holds a comma-separated list of addresses;
* the first valid, non-unknown entry in that list is the real client IP.
*/
public static String getIpAddr(HttpServletRequest request) {
String ip = null;
try {
// In Kubernetes the real client IP may be placed in x-Original-Forwarded-For, while X-Forwarded-For holds the WAF back-to-origin address
ip = request.getHeader("X-Original-Forwarded-For");
if (StringUtils.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("X-Forwarded-For");
}
// IP forwarded by proxies such as Nginx
if (StringUtils.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("x-forwarded-for");
}
if (StringUtils.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (StringUtils.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (StringUtils.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (StringUtils.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
// Fall back to the remote address (also covers requests from inside a k8s cluster)
if (StringUtils.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
if (LOCALHOST_IP1.equalsIgnoreCase(ip) || LOCALHOST_IP.equalsIgnoreCase(ip)) {
// Resolve the IP configured on the local network interface
InetAddress iNet = null;
try {
iNet = InetAddress.getLocalHost();
} catch (UnknownHostException e) {
logger.error("getClientIp error", e);
}
if (iNet != null) {
ip = iNet.getHostAddress();
}
}
}
} catch (Exception e) {
logger.error("IPUtils ERROR ", e);
}
// If the request went through proxies, keep only the first IP in the list
if (!StringUtils.isEmpty(ip) && ip.indexOf(IP_UTILS_FLAG) > 0) {
ip = ip.substring(0, ip.indexOf(IP_UTILS_FLAG));
}
return ip;
}
}
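For instance, after two proxy hops X-Forwarded-For may look like the made-up value below; the final step of getIpAddr() keeps only the first entry, which is the original client:
class ForwardedForDemo {
    public static void main(String[] args) {
        // Hypothetical X-Forwarded-For value: client, first proxy, second proxy
        String forwarded = "203.0.113.7, 10.0.0.1, 10.0.0.2";
        // Same trimming that getIpAddr() performs when the header contains a comma
        String clientIp = forwarded.substring(0, forwarded.indexOf(","));
        System.out.println(clientIp); // prints 203.0.113.7
    }
}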
2.7 Core rate-limiting logic
@Slf4j
@Aspect
@Component
public class LimitAspect {
private final RedisTemplate<Object, Object> redisTemplate;
public LimitAspect(RedisTemplate<Object, Object> redisTemplate) {
this.redisTemplate = redisTemplate;
}
@Pointcut("@annotation(club.wadreamer.limit.annotation.Limit)")
public void pointcut() {
}
@Around("pointcut()")
public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
// Get the current HTTP request
HttpServletRequest httpServletRequest = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
// Get the signature of the intercepted method
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
// Get the intercepted Method object
Method method = methodSignature.getMethod();
// Read the @Limit annotation declared on the method
Limit limit = method.getAnnotation(Limit.class);
LimitType limitType = limit.limitType();
String key = limit.key();
if (StringUtils.isEmpty(key)) {
key = "_" + methodSignature.getName();
}
if (limitType == LimitType.IP) {
key += "_" + IPUtils.getIpAddr(httpServletRequest);
}
// Build the key used in Redis
ImmutableList<Object> keys = ImmutableList.of(StringUtils.join(limit.prefix(), "_", key,
httpServletRequest.getRequestURI().replaceAll("/", "_")));
// Build the Lua script
String luaScript = buildLuaScript();
// Wrap it in a RedisScript whose return type is Number
DefaultRedisScript<Number> redisScript = new DefaultRedisScript<>(luaScript, Number.class);
// Execute the Lua script
Number count = redisTemplate.execute(redisScript, keys, limit.count(), limit.period());
if (Objects.nonNull(count) && count.intValue() <= limit.count()) {
log.info("Access #{} to key {} for interface [{}]", count, keys, limit.name());
return joinPoint.proceed();
} else {
return ResultDto.error("Too many requests to this interface");
}
}
/*
* KEYS holds the key list; ARGV holds [limit.count(), limit.period()].
* Lua arrays are 1-based.
* The script reads the counter and returns it unchanged once it exceeds the limit;
* otherwise it increments the counter and starts the expiry window when the counter is first created.
*/
private String buildLuaScript() {
return "local c" +
"\nc = redis.call('get', KEYS[1])" +
"\nif c and tonumber(c) > tonumber(ARGV[1]) then" +
"\nreturn c" +
"\nend" +
"\nc = redis.call('incr', KEYS[1])" +
"\nif tonumber(c) == 1 then" +
"\nredis.call('expire', KEYS[1], ARGV[2])" +
"\nend" +
"\nreturn c";
}
}
2.8 application.yml
server:
port: 8080
spring:
redis:
# database index
database: ${REDIS_DB:1}
host: ${REDIS_HOST:192.168.125.140}
port: ${REDIS_PORT:6379}
password: ${REDIS_PWD:}
# connection timeout (ms)
timeout: 5000
3. Testing
Next, create a test endpoint and mark it with the rate-limit annotation.
@RestController
@RequestMapping("/api/limit")
public class LimitController {
private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger();
/**
* Tests the rate-limit annotation. With the configuration below the endpoint is limited per IP:
* at most 10 calls are allowed within 60 seconds, and the key stored in Redis has the form
* limit_test_<client IP>_<request URI with slashes replaced by underscores>.
*/
@GetMapping
@Limit(key = "test", period = 60, count = 10, name = "testLimit", prefix = "limit", limitType = LimitType.IP)
public Object test() {
return ResultDto.success(200, ATOMIC_INTEGER.incrementAndGet());
}
}
Testing with Postman gives the following result:
The corresponding data in Redis Desktop Manager looks like this:
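If you would rather drive the test from code than from Postman, a minimal sketch with Spring's RestTemplate (assuming the application is running locally on port 8080) is enough: calling the endpoint more often than the limit allows should return the incrementing counter for the first 10 calls within the window and the error payload afterwards.
import org.springframework.web.client.RestTemplate;

class LimitSmokeTest {
    public static void main(String[] args) {
        RestTemplate restTemplate = new RestTemplate();
        // Call the endpoint a few more times than the configured limit (10 calls per 60 seconds)
        for (int i = 1; i <= 12; i++) {
            String body = restTemplate.getForObject("http://localhost:8080/api/limit", String.class);
            System.out.println(i + " -> " + body);
        }
    }
}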