通过aop来限制接口的请求
1.定义注解
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Limit {
/**
* 限流key
*/
String key() default "rate_limit:";
/**
* 限流时间,单位秒
*/
int time() default 60;
/**
* 限流次数
*/
int count() default 100;
/**
* 限流类型
*/
LimitType limitType() default LimitType.DEFAULT;
}
2.配置redis
@Configuration
public class RedisConfig {
@Bean
public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
redisTemplate.setConnectionFactory(connectionFactory);
// 使用Jackson2JsonRedisSerialize 替换默认序列化(默认采用的是JDK序列化)
Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
ObjectMapper om = new ObjectMapper();
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
jackson2JsonRedisSerializer.setObjectMapper(om);
redisTemplate.setKeySerializer(jackson2JsonRedisSerializer);
redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer);
redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
return redisTemplate;
}
@Bean
public DefaultRedisScript<Long> limitScript() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/redis_limit.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}
@Bean
public DefaultRedisScript<Boolean> limitScript1() {
DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/redis_limit1.lua")));
redisScript.setResultType(Boolean.class);
return redisScript;
}
@Bean
public DefaultRedisScript<Long> limitScript2() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/redis_limit2.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}
}
3.定义枚举类
public enum LimitType {
DEFAULT,IP
}
4.aop切面
/**
* 限流处理
*
*/
@Aspect
@Component
public class LimitAspect {
private static final Logger log = LoggerFactory.getLogger(LimitAspect.class);
@Resource
private RedisTemplate<Object, Object> redisTemplate;
@Resource
private RedisScript<Long> limitScript;
@Resource
private RedisScript<Boolean> limitScript1;
@Resource
private RedisScript<Long> limitScript2;
//简单限流
// @Before("@annotation(rateLimiter)")
// public void doBefore(JoinPoint point, Limit rateLimiter) throws Throwable {
// String key = rateLimiter.key();
// int time = rateLimiter.time();
// int count = rateLimiter.count();
//
// String combineKey = getCombineKey(rateLimiter, point);
// List<Object> keys = Collections.singletonList(combineKey);
// try {
// Long number = redisTemplate.execute(limitScript, keys, count, time);
// if (number==null || number.intValue() > count) {
// throw new ServiceException("访问过于频繁,请稍候再试");
// }
// log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key);
// } catch (ServiceException e) {
// throw e;
// } catch (Exception e) {
// throw new RuntimeException("服务器限流异常,请稍候再试");
// }
// }
//动态时间窗口限流
// @Before("@annotation(rateLimiter)")
// public void doBefore(JoinPoint point, Limit rateLimiter) throws Throwable {
// String key = rateLimiter.key();
// int time = rateLimiter.time();
// int count = rateLimiter.count();
//
// String combineKey = getCombineKey(rateLimiter, point);
// List<Object> keys = Collections.singletonList(combineKey);
// long now = System.currentTimeMillis();
// final long ms = TimeUnit.MILLISECONDS.convert(time, TimeUnit.SECONDS);
// try {
// Boolean number = redisTemplate.execute(limitScript1, keys, count, now,ms);
// if (!number) {
// throw new ServiceException("访问过于频繁,请稍候再试");
// }
// log.info("限制请求'{}',缓存key'{}'", count, key);
// } catch (ServiceException e) {
// throw e;
// } catch (Exception e) {
// throw new RuntimeException("服务器限流异常,请稍候再试");
// }
// }
//令牌桶限流
@Before("@annotation(rateLimiter)")
public void doBefore(JoinPoint point, Limit rateLimiter) throws Throwable {
String key = rateLimiter.key();
int time = rateLimiter.time();
int count = rateLimiter.count();
String combineKey = getCombineKey(rateLimiter, point);
List<Object> keys = Collections.singletonList(combineKey);
long now = System.currentTimeMillis();
final long ms = TimeUnit.MILLISECONDS.convert(time, TimeUnit.SECONDS);
try {
Long number = redisTemplate.execute(limitScript2, keys, time,count);
if (number<0) {
throw new ServiceException("访问过于频繁,请稍候再试");
}
log.info("限制请求'{}',当前令牌数'{}',缓存key'{}'", count, number,key);
} catch (ServiceException e) {
throw e;
} catch (Exception e) {
throw new RuntimeException("服务器限流异常,请稍候再试");
}
}
public String getCombineKey(Limit rateLimiter, JoinPoint point) {
StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
if (rateLimiter.limitType() == LimitType.IP) {
stringBuffer.append(IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest())).append("-");
}
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
return stringBuffer.toString();
}
}
5.limit.lua
---
--- Created by ct.
--- DateTime: 2022/6/29 17:06
---
---@Override
--- public <T> T execute(RedisScript<T> script, List<K> keys, Object... args) {
--- return scriptExecutor.execute(script, keys, args);
--- }
local key = KEYS[1] ---代表需要操作的key
local count = tonumber(ARGV[1]) ---第一个参数 次数
local time = tonumber(ARGV[2]) ---第二个参数 间隔时间
local current = redis.call('get', key) ---获取key 的值
if current and tonumber(current) > count then ---当这个key存在且小于规定的次数
return tonumber(current) ---返回当前key的值
end
current = redis.call('incr', key) ---把key可以加1 如果不存在key就把key赋值为1
if tonumber(current) == 1 then ---如果key是1
redis.call('expire', key, time) ---设置key的过期时间
end
return tonumber(current) ---返回当前的key的值
limit1.lua
---
---
--- Created by Ct.
--- DateTime: 2022/7/1 9:02
---
local key = KEYS[1] ---代表需要操作的key
local count = tonumber(ARGV[1]) ---第一个参数 次数
local time = tonumber(ARGV[2]) ---第二个参数 当前时间
local period = tonumber(ARGV[3]) ---第三个参数 间隔时间
redis.call('ZREMRANGEBYSCORE', key, 0, tonumber(time) - tonumber(period)); ---移除时间窗外的值
if (redis.call('ZCARD', key) >= tonumber(count)) then ---
return nil;
end ;
redis.call('ZADD', key, time, time);
redis.call('pexpire',key, period);
return true;
limit2.lua
---
--- Generated by EmmyLua(https://github.com/EmmyLua)
--- Created by Administrator.
--- DateTime: 2022/7/1 17:21
---
---[[
---根据key(参数) 查询 对应的 value(令牌数)
--- 如果为null 说明该key 是第一次进入
--- {
--- 初始化 令牌桶(参数)数量;记录初始化时间 ->返回 剩余令牌数
--- }
---
--- 如果不为null
--- {
--- 判断 value 是否大于1
--- {
--- 大于1 ->value - 1 -> 返回 剩余令牌数
--- 小于1 -> 判断 补充令牌时间间隔是否足够
--- {
--- 足够 -> 补充令牌;更新补充令牌时间-> 返回 剩余令牌数
--- 不足够 -> 返回 -1 (说明超过限流访问次数)
--- }
--- }
--- }]]
---
---脚本里使用了redis.call(time)这样的命令,获取当前服务器时间,传递到从节点执行的时候这个时间肯定会不一样,那不就造成了主从节点数据不一致的情况么
---抛出异常 Write commands not allowed after non deterministic commands
---redis.replicate_commands(); 在脚本第一行执行这个函数,Redis会将修改数据的命令收集起来,然后用MULTI/EXEC包裹起来,
---这种方式称为script effects replication,这个类似于mysql中的基于行的复制模式,将非纯函数的值计算出来,用来持久化和主从复制。
---
redis.replicate_commands();
-- 参数中传递的key
local key = KEYS[1]
-- 令牌桶填充 最小时间间隔
local update_len = tonumber(ARGV[1])
-- 记录 当前key上次更新令牌桶的时间的 key
local key_time = 'ratetokenprefix'..key
-- 获取当前时间(这里的curr_time_arr 中第一个是 秒数,第二个是 秒数后毫秒数),由于我是按秒计算的,这里只要curr_time_arr[1](注意:redis数组下标是从1开始的)
--如果需要获得毫秒数 则为 tonumber(arr[1]*1000 + arr[2])
local curr_time_arr = redis.call('TIME')
-- 当前时间秒数
local nowTime = tonumber(curr_time_arr[1])
-- 从redis中获取当前key 对应的上次更新令牌桶的key 对应的value
local curr_key_time = tonumber(redis.call('get',key_time) or 0)
-- 获取当前key对应令牌桶中的令牌数
local token_count = tonumber(redis.call('get',KEYS[1]) or -1)
-- 当前令牌桶的容量
local token_size = tonumber(ARGV[2])
-- 令牌桶数量小于0 说明令牌桶没有初始化
if token_count < 0 then
redis.call('set',key_time,nowTime)
redis.call('set',key,token_size -1)
return token_size -1
else
if token_count > 0 then --当前令牌桶中令牌数够用
redis.call('set',key,token_count - 1)
return token_count -1 --返回剩余令牌数
else --当前令牌桶中令牌数已清空
if curr_key_time + update_len < nowTime then --判断一下,当前时间秒数 与上次更新时间秒数 的间隔,是否大于规定时间间隔数 (update_len)
redis.call('set',key_time,nowTime)
redis.call('set',key,token_size -1)
return token_size - 1
else
return -1
end
end
end