- 定义一个注解,设置限流属性(时间窗,次数,限流类型,存入redis的前缀)
- lua脚本,每访问一次,redis中key的值 incr,,在第一次的时候给key设置过期时间,最后将key的值返回给java判断
问题:
- 获取request
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest()
- 获取包名
method.getDeclaringClass().getName()
- Collections.singletonList被限定只被分配一个内存空间,也就是只能存放一个元素的内容。
- 写自己的redisTemplate
/*
* jdk序列化方案,存的时候会 一堆前缀
* redis提供的 对象的redisTemplate,序列化,会有很多前缀,,
* 自己写一个redisTemplate,更改他的序列化方式
* */
@Bean
RedisTemplate<Object,Object> redisTemplate(RedisConnectionFactory redisConnectionFactory){
RedisTemplate<Object, Object> template = new RedisTemplate<>();
template.setConnectionFactory(redisConnectionFactory);
// 设置序列化器 : 使用jackson的序列化器
Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
template.setKeySerializer(serializer);
template.setHashKeySerializer(serializer);
template.setValueSerializer(serializer);
template.setHashValueSerializer(serializer);
return template;
}
- redis定义脚本
@Bean
DefaultRedisScript<Long> limitScript(){
DefaultRedisScript<Long> script = new DefaultRedisScript<>();
script.setResultType(Long.class);
// 设置脚本位置
script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
return script;
}
- 执行redis脚本
Long number = (Long) redisTemplate.execute(redisScript, Collections.singletonList(combineKey), time, count);
代码:
lua脚本
local key = KEYS[1]
local time = tonumber(ARGV[1])
-- 限流次数
local count = tonumber(ARGV[2])
-- 获取redis中的key ,,, 如果接口没有调用过,,有可能没有值,
local current = redis.call('get',key)
-- current超过限流了
if current and tonumber(current)>count then
return tonumber(current)
end
-- 第一次访问 ,,自增1 。。 并发的时候,另外一个线程可能也执行了自增1
current = redis.call("incr",key)
-- 没有其他线程
if tonumber(current)==1 then
redis.call("expire",key,time)
end
-- 不等于1,,其他线程已经设置过期时间
return tonumber(current)
-- 返回current,,在java中判断是限流还是放行
自定义异常:
public class RateLimitException extends Exception {
public RateLimitException(String message) {
super(message);
}
}
@RestControllerAdvice
public class GlobalException {
@ExceptionHandler(RateLimitException.class)
public Map<String,Object> rateLimitException(RateLimitException e){
HashMap<String, Object> map = new HashMap<>();
map.put("status",500);
map.put("message",e.getMessage());
return map;
}
}
限流类型:
public enum LimitType {
/**
* 默认限流策略 , 针对某个接口进行限流,,某个接口在一定时间只能访问多少次
*/
DEFAULT,
/**
* 根据IP限流,,某个IP在一定时间只能访问多少次
*/
IP
}
限流注解:
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {
/**
* 限流 在redis中的 前缀
*/
String key() default "rate_limit";
/**
* 限流时间窗
*/
int time() default 60;
/**
* 在时间窗内的 限流次数
*/
int count() default 100;
/**
* 限流类型
*/
LimitType limitType() default LimitType.DEFAULT;
}
redis配置文件
@Configuration
public class RedisConfig {
/*
* jdk序列化方案,存的时候会 一堆前缀
* redis提供的 对象的redisTemplate,序列化,会有很多前缀,,
* 自己写一个redisTemplate,更改他的序列化方式
* */
@Bean
RedisTemplate<Object,Object> redisTemplate(RedisConnectionFactory redisConnectionFactory){
RedisTemplate<Object, Object> template = new RedisTemplate<>();
template.setConnectionFactory(redisConnectionFactory);
// 设置序列化器 : 使用jackson的序列化器
Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
template.setKeySerializer(serializer);
template.setHashKeySerializer(serializer);
template.setValueSerializer(serializer);
template.setHashValueSerializer(serializer);
return template;
}
@Bean
DefaultRedisScript<Long> limitScript(){
DefaultRedisScript<Long> script = new DefaultRedisScript<>();
script.setResultType(Long.class);
// 设置脚本位置
script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
return script;
}
}
ip工具类: 获取访问的ip
/**
* 获取用户访问ip地址
*/
public class IpUtils {
public static String getIpAddress(HttpServletRequest request) {
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
// 获取到多个ip时取第一个作为客户端真实ip
if(ip !=null && ip.length() != 0 && ip.contains(",")){
String[] ipArray = ip.split(",");
ip = ipArray[0];
}
return ip;
}
}
切面
/**
* 拦截 rateLimiter注解 限流的话,还没进入方法就拦截
* @param jp
* @param rateLimiter
*/
@Before("@annotation(rateLimiter)")
public void before(JoinPoint jp, RateLimiter rateLimiter) throws RateLimitException {
int time = rateLimiter.time();
int count = rateLimiter.count();
String combineKey = getCombineKey(rateLimiter,jp);
try {
// Collections.singletonList被限定只被分配一个内存空间,也就是只能存放一个元素的内容。
Long number = (Long) redisTemplate.execute(redisScript, Collections.singletonList(combineKey), time, count);
// number 什么情况下为 null
if(number == null || number.intValue() > count){
logger.info("当前接口已达到最大限流次数");
throw new RateLimitException("访问频繁,请稍后访问");
}
logger.info("一个时间窗内请求次数:{},当前请求次数:{} ,缓存key为:{}",count,number,combineKey);
} catch (Exception e) {
throw e;
}
}
private String getCombineKey(RateLimiter rateLimiter, JoinPoint jp) {
StringBuilder sb = new StringBuilder(rateLimiter.key());
LimitType limitType = rateLimiter.limitType();
// 拼接 ip
if(limitType == LimitType.IP){
// 获取ip
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
sb.append(IpUtils.getIpAddress(request)).append("-");
}
MethodSignature signature = (MethodSignature) jp.getSignature();
Method method = signature.getMethod();
sb.append(method.getDeclaringClass().getName()).append("-").append(method.getName());
return sb.toString();
}
测试
@RestController
public class HelloController {
/**
* 10s 内,访问 3 次
* @return
*/
@RateLimiter(time = 10,count = 3)
@GetMapping("/hello")
public String hello(HttpServletRequest request){
String ip = IpUtils.getIpAddress(request);
return "hello "+ip;
}
}