需求:
限制每个IP在指定时间内对某一个接口的请求次数,如果请求次数超过指定阈值,就返回拒绝信息。
没做IP防刷之前,请求量一大服务就崩了。
做防刷之后
当然,除了IP防刷,还对接口做了限流(基于Guava RateLimiter令牌桶)。
直接上代码
接口注解代码:
import java.lang.annotation.*;
/**
 * Marks a method whose invocations are throttled by a token bucket.
 * <p>
 * The annotation is retained at runtime so that an aspect can read it. Tokens are replenished at
 * {@link #permitsPerSecond()} per second; a caller that cannot obtain a token within
 * {@link #timeout()} milliseconds is sent down the degraded (fallback) path.
 *
 * @author lixx
 * @version 1.0
 * @date 2020-07-14 15:58
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface GuavaRateLimiter {

    /**
     * Rate at which tokens are generated, in permits per second.
     */
    double permitsPerSecond() default 30.0;

    /**
     * Maximum time to wait for a token before falling back, in milliseconds.
     */
    long timeout() default 500L;
}
注解的逻辑实现代码:
import com.familylinkedu.common.annotation.GuavaRateLimiter;
import com.familylinkedu.common.config.redis.RedisService;
import com.familylinkedu.common.enums.Constant;
import com.familylinkedu.common.exception.CustomRuntimeException;
import com.familylinkedu.common.utils.IpUtils;
import com.google.common.util.concurrent.RateLimiter;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
/**
 * Aspect that enforces {@link GuavaRateLimiter}.
 * <p>
 * Every call to an annotated method goes through two checks, in order:
 * <ol>
 *   <li>a per-IP, per-URI request counter kept in Redis (classpath script {@code ratelimiter.lua});</li>
 *   <li>a per-URI Guava token bucket created with the annotation's {@code permitsPerSecond}.</li>
 * </ol>
 * Failing either check raises {@link CustomRuntimeException}; otherwise the call proceeds.
 * <p>
 * NOTE(review): the pointcut uses {@code @annotation(...)}, which only matches method-level
 * annotations, so placing {@code @GuavaRateLimiter} on a type has no effect here.
 *
 * @author lixx
 * @version 1.0
 * @date 2020-07-14 16:22
 */
@Aspect
@Component
public class GuavaRateLimiterAspect {

    /**
     * Redis script for the per-IP check. Built once and shared; it used to be rebuilt on every request.
     */
    private static final DefaultRedisScript<Boolean> IP_LIMIT_SCRIPT = createIpLimitScript();

    /**
     * One token bucket per request URI, shared by every caller of that URI.
     * <p>
     * The first request for a URI fixes the bucket's rate.
     * <p>
     * NOTE(review): entries are keyed by the raw URI and never evicted, so URIs containing path
     * variables (e.g. /user/1, /user/2) each get their own bucket and the map grows without bound.
     * Consider keying by handler method instead.
     */
    private final ConcurrentHashMap<String, RateLimiter> rateLimiterMap = new ConcurrentHashMap<>();

    @Autowired
    private RedisService redisService;

    /**
     * Pointcut matching any method annotated with {@link GuavaRateLimiter}; binds the annotation instance.
     */
    @Pointcut("@annotation(guavaRateLimiter)")
    public void rateLimiterAspect(GuavaRateLimiter guavaRateLimiter) {
    }

    /**
     * Around advice: applies the per-IP limit, then the per-URI token bucket, then invokes the target.
     *
     * @param proceedingJoinPoint the intercepted invocation
     * @param guavaRateLimiter    the annotation on the intercepted method, supplying rate and timeout
     * @return whatever the intercepted method returns
     * @throws CustomRuntimeException if the IP limit is exceeded, or no token is obtained within
     *                                {@code timeout} milliseconds
     * @throws Throwable              anything thrown by the intercepted method
     */
    @Around(value = "rateLimiterAspect(guavaRateLimiter)", argNames = "proceedingJoinPoint,guavaRateLimiter")
    public Object doBefore(ProceedingJoinPoint proceedingJoinPoint, GuavaRateLimiter guavaRateLimiter) throws Throwable {
        HttpServletRequest request = getRequest();
        if (!ipLimit(request)) {
            throw new CustomRuntimeException("检测出您的IP异常访问此服务,已被系统限流!温馨提示:为营造良好的网络环境,请规范网络行为!");
        }
        double perSecond = guavaRateLimiter.permitsPerSecond();
        long timeout = guavaRateLimiter.timeout();
        String uri = request.getRequestURI();
        // One bucket per URL, not per request. computeIfAbsent is atomic, so concurrent first requests
        // cannot create two buckets. The bucket is a local variable: the old shared field could be
        // overwritten by another thread between lookup and tryAcquire.
        RateLimiter rateLimiter = rateLimiterMap.computeIfAbsent(uri, key -> RateLimiter.create(perSecond));
        // Wait up to `timeout` ms for a token; if none is available, take the degraded path.
        if (!rateLimiter.tryAcquire(timeout, TimeUnit.MILLISECONDS)) {
            throw new CustomRuntimeException("系统火爆,请稍候重试!");
        }
        return proceedingJoinPoint.proceed();
    }

    /**
     * Returns the servlet request bound to the current thread.
     *
     * @return the current HTTP request
     * @throws IllegalStateException if the advice runs outside a servlet request thread
     */
    private HttpServletRequest getRequest() {
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (requestAttributes == null) {
            throw new IllegalStateException("@GuavaRateLimiter can only be enforced on a thread bound to an HTTP request");
        }
        return requestAttributes.getRequest();
    }

    /**
     * Per-IP, per-URI limit check backed by the Redis script.
     *
     * @param request the current HTTP request; its client IP and URI form the Redis key
     * @return {@code true} if the request is allowed, {@code false} if it has been limited
     */
    private boolean ipLimit(HttpServletRequest request) {
        // Key is IP + URI, so each IP is limited separately on each endpoint.
        String ipKey = Constant.DefaultRedisKeyEnum.IP_LIMIT + ":" + IpUtils.getIpAddr(request) + ":" + request.getRequestURI();
        List<Object> keyList = new ArrayList<>(1);
        keyList.add(ipKey);
        // Redis converts a Lua `false` into a nil reply, which a driver may hand back as null instead of
        // Boolean.FALSE. Treat anything but TRUE as "limited" rather than risk an NPE on unboxing.
        return Boolean.TRUE.equals(redisService.execute(IP_LIMIT_SCRIPT, keyList));
    }

    /**
     * Builds the script that counts requests per key ({@code ratelimiter.lua} on the classpath).
     */
    private static DefaultRedisScript<Boolean> createIpLimitScript() {
        DefaultRedisScript<Boolean> script = new DefaultRedisScript<>();
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("ratelimiter.lua")));
        script.setResultType(Boolean.class);
        return script;
    }
}
代码中用到的获取ip工具类:
import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;
/**
 * Resolves the client IP address of an HTTP request.
 * <p>
 * Lookup order:
 * <ol>
 *   <li>{@code x-forwarded-for}</li>
 *   <li>{@code Proxy-Client-IP}</li>
 *   <li>{@code WL-Proxy-Client-IP}</li>
 *   <li>{@link HttpServletRequest#getRemoteAddr()}</li>
 * </ol>
 * A loopback remote address is replaced by this host's own address when it can be resolved.
 * <p>
 * NOTE(review): the forwarding headers are client-supplied unless a trusted proxy overwrites them.
 * A caller can therefore forge them to dodge per-IP limits. Only rely on them behind a proxy that sets them.
 */
public class IpUtils {

    /** Placeholder some proxies write when they do not know the client address. */
    private static final String UNKNOWN = "unknown";

    /**
     * Returns the best-effort client IP for {@code request}.
     *
     * @param request the incoming request
     * @return the first address in the forwarding chain, or the remote address. The result is the
     *         empty string if an unexpected exception occurs, and {@code null} only if the container
     *         reports no remote address.
     */
    public static String getIpAddr(HttpServletRequest request) {
        try {
            String ipAddress = request.getHeader("x-forwarded-for");
            if (isMissing(ipAddress)) {
                ipAddress = request.getHeader("Proxy-Client-IP");
            }
            if (isMissing(ipAddress)) {
                ipAddress = request.getHeader("WL-Proxy-Client-IP");
            }
            if (isMissing(ipAddress)) {
                ipAddress = request.getRemoteAddr();
                if (isLoopback(ipAddress)) {
                    ipAddress = localHostAddressOr(ipAddress);
                }
            }
            if (ipAddress != null) {
                // Through several proxies the value is "client, proxy1, proxy2"; the first entry is the client.
                // Split whenever a comma is present: a length threshold misses short chains such as "1.1.1.1,2.2.2.2".
                int comma = ipAddress.indexOf(',');
                if (comma > 0) {
                    ipAddress = ipAddress.substring(0, comma);
                }
                ipAddress = ipAddress.trim();
            }
            return ipAddress;
        } catch (Exception e) {
            // Deliberate best-effort: an unresolvable address must not fail the request.
            return "";
        }
    }

    /** True when a header value carries no usable address. */
    private static boolean isMissing(String value) {
        return value == null || value.isEmpty() || UNKNOWN.equalsIgnoreCase(value);
    }

    /** True for the IPv4 and IPv6 loopback addresses as reported by servlet containers. */
    private static boolean isLoopback(String address) {
        return "127.0.0.1".equals(address) || "0:0:0:0:0:0:0:1".equals(address) || "::1".equals(address);
    }

    /** This host's configured address, or {@code fallback} if the local host name cannot be resolved. */
    private static String localHostAddressOr(String fallback) {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            return fallback;
        }
    }
}
代码中用到的redis:
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
/**
 * Redis template configuration.
 * <p>
 * Declares a {@link StringRedisTemplate}, unless one already exists. It also declares a
 * {@code RedisTemplate<Object, Object>} named {@code redisObjectTemplate}, which serializes:
 * <ul>
 *   <li>keys, values and hash keys as plain strings;</li>
 *   <li>hash values as JSON with embedded type information.</li>
 * </ul>
 *
 * @author lixx
 * @version 1.0
 * @date 2019-08-07 10:58
 */
@Configuration
public class RedisConfig {

    /**
     * String-only template, registered only when no other {@link StringRedisTemplate} bean exists.
     */
    @Bean
    @ConditionalOnMissingBean(StringRedisTemplate.class)
    public StringRedisTemplate stringRedisTemplate(RedisConnectionFactory redisConnectionFactory) {
        StringRedisTemplate stringTemplate = new StringRedisTemplate();
        stringTemplate.setConnectionFactory(redisConnectionFactory);
        return stringTemplate;
    }

    /**
     * General-purpose template, registered only when no bean named {@code redisObjectTemplate} exists.
     */
    @Bean
    @ConditionalOnMissingBean(name = "redisObjectTemplate")
    public RedisTemplate<Object, Object> redisObjectTemplate(RedisConnectionFactory redisConnectionFactory) {
        StringRedisSerializer plainText = new StringRedisSerializer();
        Jackson2JsonRedisSerializer<Object> typedJson = typedJsonSerializer();

        RedisTemplate<Object, Object> objectTemplate = new RedisTemplate<>();
        objectTemplate.setConnectionFactory(redisConnectionFactory);
        // Plain keys and values.
        objectTemplate.setKeySerializer(plainText);
        objectTemplate.setValueSerializer(plainText);
        // Plain hash keys, JSON hash values.
        objectTemplate.setHashKeySerializer(plainText);
        objectTemplate.setHashValueSerializer(typedJson);
        objectTemplate.afterPropertiesSet();
        return objectTemplate;
    }

    /**
     * JSON serializer that sees all fields regardless of visibility and records type info for
     * non-final types.
     * <p>
     * NOTE(review): {@code enableDefaultTyping} is deprecated since Jackson 2.10 in favour of
     * {@code activateDefaultTyping(PolymorphicTypeValidator, ...)}. Default typing lets stored JSON
     * choose the class to instantiate, so only read data this application wrote itself.
     */
    private static Jackson2JsonRedisSerializer<Object> typedJsonSerializer() {
        ObjectMapper mapper = new ObjectMapper();
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        serializer.setObjectMapper(mapper);
        return serializer;
    }
}
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Service;
/**
 * Thin wrapper around the {@code redisObjectTemplate} bean.
 * <p>
 * The other RedisTemplate wrappers were trimmed for brevity; only the Lua-script helpers remain.
 *
 * @author lixx
 * @version 1.0
 * @date 2019-08-07 11:07
 */
@Service
public class RedisService {

    @Autowired
    private RedisTemplate<Object, Object> redisObjectTemplate;

    // =============================== lua script =================================

    /**
     * Executes a Lua script with the given keys and no extra arguments.
     *
     * @param script  the script to run; its result type controls how the reply is converted
     * @param objects the keys, visible to the script as {@code KEYS[1..n]}
     * @param <T>     the script's result type
     * @return the converted result. Whether a nil reply becomes {@code null} depends on the driver.
     */
    public <T> T execute(RedisScript<T> script, java.util.List<Object> objects) {
        return redisObjectTemplate.execute(script, objects);
    }

    /**
     * Executes a Lua script with the given keys and arguments.
     * <p>
     * Arguments are serialized with the template's value serializer. In the RedisConfig shown that is a
     * {@code StringRedisSerializer}, so pass {@code String} values.
     *
     * @param script the script to run; its result type controls how the reply is converted
     * @param keys   the keys, visible to the script as {@code KEYS[1..n]}
     * @param args   the arguments, visible to the script as {@code ARGV[1..n]}
     * @param <T>    the script's result type
     * @return the converted result. Whether a nil reply becomes {@code null} depends on the driver.
     */
    public <T> T execute(RedisScript<T> script, java.util.List<Object> keys, Object... args) {
        return redisObjectTemplate.execute(script, keys, args);
    }
}
lua脚本 ratelimiter.lua(直接放在resources目录下):
-- IP rate limiting: a fixed-window counter allowing at most `limit` requests per `expire` seconds per key.
--
-- KEYS[1] : counter key (client IP + request URI in the Java caller)
-- ARGV[1] : window length in seconds, optional (default 10)
-- ARGV[2] : maximum requests allowed per window, optional (default 10)
--
-- Returns 1 while the key is within its limit, and 0 once the limit is exceeded.
-- Integers are returned instead of Lua booleans because Redis turns Lua `false` into a nil reply,
-- which a client may surface as null rather than false.
local expire = tonumber(ARGV[1]) or 10
local limit = tonumber(ARGV[2]) or 10
local ipKey = KEYS[1]

-- INCR creates a missing key at 1, so one call both creates and counts; it also returns the new value.
local count = redis.call('INCR', ipKey)

-- Start the window on the first request. Also re-arm the TTL if the key somehow has none
-- (e.g. a previous EXPIRE never ran); otherwise the key would never reset and the IP stays blocked forever.
if count == 1 or redis.call('TTL', ipKey) == -1 then
    redis.call('EXPIRE', ipKey, expire)
end

if count > limit then
    return 0
end
return 1
使用:
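PS:注解来自军哥的代码,我在此基础上加了IP防刷。其实IP防刷更适合放在网关中(每一个请求都需要防止IP刷请求,但并不是每一个接口都需要限流)。这里使用的是Lua脚本的方式,大家也可以使用布隆过滤器,推荐放在过滤器(Filter)里面。另外,还需要把ConcurrentHashMap换成带容量上限和过期时间的LoadingCache,代码如下。不过CacheLoader的load方法只能拿到key,所以在这个内部类中拿不到注解里配置的每秒生成令牌的速率;可以改用Cache.get(key, Callable)在调用处传入速率。大佬们看出需要改正的地方可以在下面留言。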
/** Rate used by the default loader when the caller does not supply one. */
private static final double DEFAULT_PERMITS_PER_SECOND = 50;

/**
 * Per-key token buckets.
 * <p>
 * Bounded to 1000 entries; each entry is dropped one day after it was created. Assigned in
 * {@link #initRateLimit()}.
 */
private com.google.common.cache.LoadingCache<String, RateLimiter> rateLimiterCache;

/**
 * Builds {@link #rateLimiterCache}.
 * <p>
 * {@code rateLimiterCache.getUnchecked(key)} returns a bucket at {@link #DEFAULT_PERMITS_PER_SECOND}.
 * <p>
 * A {@link CacheLoader} only receives the key, so it cannot see an annotation's rate. To use the rate
 * from {@code @GuavaRateLimiter}, call
 * {@code rateLimiterCache.get(key, () -> RateLimiter.create(annotation.permitsPerSecond()))} instead.
 * That overload throws a checked {@code ExecutionException}.
 */
@PostConstruct
public void initRateLimit() {
    // The result of build(...) must be kept: building a cache and discarding it has no effect.
    rateLimiterCache = CacheBuilder.newBuilder()
            .maximumSize(1000)
            .expireAfterWrite(1, TimeUnit.DAYS)
            .build(new CacheLoader<String, RateLimiter>() {
                @Override
                public RateLimiter load(String ipKey) {
                    return RateLimiter.create(DEFAULT_PERMITS_PER_SECOND);
                }
            });
}