Redis支持Lua脚本的主要优势
在Redis中引入Lua脚本带来了更多的使用场景,主要优势如下:
- 高效性:减少网络开销和时延——原本需要多次请求Redis服务器才能完成的操作,用一个Lua脚本只需一次请求即可完成
- 原子性:Redis会将整个脚本作为一个整体执行,执行期间不会插入其他客户端的命令。注意这不等于事务回滚:脚本中途报错时,已经执行过的写命令不会被撤销。
- 复用性:执行过的脚本会按SHA1缓存在Redis服务器端,其他客户端可以通过EVALSHA直接复用;但该缓存不是持久化的,Redis重启或执行SCRIPT FLUSH后需要重新加载。
- 可嵌入性:Lua可以嵌入Java、C#等多种宿主语言,并且跨平台。
- 简单强大:小巧轻便,资源占用低,同时支持过程式和面向对象风格的编程。
自己也是第一次在工作中使用lua这种语言,记录一下
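创建Lua文件req_ratelimit.lua,放在classpath下的redis/目录中(切面里通过ClassPathResource("redis/req_ratelimit.lua")加载)。脚本参数:KEYS[1]为限流key,ARGV[1]为限流次数,ARGV[2]为限流时间窗口(秒,用于EXPIRE)。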
-- Fixed-window rate limiter.
-- KEYS[1]: counter key
-- ARGV[1]: maximum number of requests allowed per window
-- ARGV[2]: window length in seconds (passed to EXPIRE)
-- Returns the updated counter (1..limit) when the request is allowed, 0 when it is rejected.
local key = KEYS[1]
local limitCount = tonumber(ARGV[1])
local limitTime = tonumber(ARGV[2])

-- Inside Redis Lua, GET on a missing key returns false (not nil), hence the `or 0`.
local current = tonumber(redis.call('get', key) or 0)

-- Redis does not roll back a script that fails part-way, so a counter can be left
-- without a TTL (e.g. the write succeeded but EXPIRE failed). Such a key would never
-- reset and would block the caller forever once full, so give it a TTL again.
if current > 0 and redis.call('ttl', key) == -1 then
  redis.call('expire', key, limitTime)
end

-- Over the limit: reject without touching the counter.
if current + 1 > limitCount then
  return 0
end

-- Use INCR's own result as the new counter value.
local count = redis.call('incr', key)
if count == 1 then
  -- First request of the window: start the window.
  redis.call('expire', key, limitTime)
end
return count
自定义注解RateLimiter
package com.shinedata.ann;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Requests Redis-backed rate limiting for the annotated element.
 * <p>
 * At most {@link #count()} calls are allowed within one window of {@link #time()};
 * the Redis counter key starts with {@link #key()}.
 * <p>
 * NOTE(review): the annotation may also be placed on a type, but the visible aspect only
 * reads method-level annotations — confirm type-level usage is really supported.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface RateLimiter {

    /**
     * Prefix of the Redis counter key.
     *
     * @return the key prefix
     */
    String key() default "rate.limit:";

    /**
     * Length of the limiting window. The aspect hands this to the Lua script, which passes it
     * to Redis EXPIRE, so the unit is seconds.
     *
     * @return the window length
     */
    int time() default 1;

    /**
     * Maximum number of calls allowed within one window.
     *
     * @return the call limit
     */
    int count() default 100;

    /**
     * Whether the client IP becomes part of the counter key (one counter per IP).
     * Disabled by default.
     *
     * @return {@code true} to limit per client IP
     */
    boolean restrictionsIp() default false;
}
定义切面RateLimiterAspect
package com.shinedata.aop;
import com.shinedata.ann.RateLimiter;
import com.shinedata.config.redis.RedisUtils;
import com.shinedata.exception.RateLimiterException;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.annotation.PostConstruct;
import javax.servlet.http.HttpServletRequest;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
/**
 * Aspect that enforces {@link RateLimiter} on annotated methods using a Redis Lua counter script.
 * <p>
 * For each intercepted call the aspect builds a Redis key from:
 * <ul>
 *   <li>the annotation's key prefix;</li>
 *   <li>the client IP, only when {@link RateLimiter#restrictionsIp()} is set;</li>
 *   <li>the declaring class name and the method name.</li>
 * </ul>
 * It then runs {@code redis/req_ratelimit.lua} against that key with {@code count()} and
 * {@code time()} as arguments. The call proceeds only when the script returns a non-zero counter
 * that does not exceed {@code count()}; otherwise a {@link RateLimiterException} is thrown.
 * <p>
 * The client IP is kept in a local variable. It is only needed for the duration of one call, so
 * no ThreadLocal is required, and nested rate-limited calls on the same thread cannot read or
 * clear each other's IP.
 *
 * @ClassName RateLimiterAspect
 * @Author yupanpan
 * @Date 2020/5/6 13:46
 */
@Aspect
@Component
public class RateLimiterAspect {
    private final Logger logger = LoggerFactory.getLogger(this.getClass());
    private DefaultRedisScript<Number> redisScript;

    /** Loads the rate-limit Lua script from the classpath once, after the bean is constructed. */
    @PostConstruct
    public void init() {
        redisScript = new DefaultRedisScript<Number>();
        redisScript.setResultType(Number.class);
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("redis/req_ratelimit.lua")));
    }

    /**
     * Counts the call in Redis, then either lets it through or rejects it.
     *
     * @param joinPoint the intercepted method invocation
     * @return whatever the intercepted method returns
     * @throws RateLimiterException when the limit for the current window has been reached,
     *                              or when the script returned no counter
     * @throws Throwable            anything thrown by the intercepted method
     */
    @Around("@annotation(com.shinedata.ann.RateLimiter)")
    public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        RateLimiter rateLimit = method.getAnnotation(RateLimiter.class);
        if (rateLimit == null) {
            return joinPoint.proceed();
        }

        List<String> keys = Collections.singletonList(buildKey(rateLimit, method));
        Number number = RedisUtils.execute(redisScript, keys, rateLimit.count(), rateLimit.time());
        if (number != null && number.intValue() != 0 && number.intValue() <= rateLimit.count()) {
            logger.info("限流时间段内访问第:{} 次", number);
            return joinPoint.proceed();
        }
        // number may be null here: pass it as a placeholder argument instead of calling toString() on it.
        logger.error("已经到设置限流次数,当前次数:{}", number);
        throw new RateLimiterException("服务器繁忙,请稍后再试");
    }

    /**
     * Builds the Redis counter key for one call.
     * <p>
     * The format is {@code <prefix>[<ip>-]-<class>- <method>}. It is kept unchanged so that
     * counters already stored in Redis keep matching.
     */
    private String buildKey(RateLimiter rateLimit, Method method) {
        StringBuilder key = new StringBuilder(rateLimit.key());
        if (rateLimit.restrictionsIp()) {
            String ip = currentClientIp();
            if (StringUtils.isNotBlank(ip)) {
                key.append(ip).append("-");
            }
        }
        key.append("-").append(method.getDeclaringClass().getName()).append("- ").append(method.getName());
        return key.toString();
    }

    /**
     * Returns the client IP of the servlet request bound to the current thread.
     * Returns {@code null} when no servlet request is bound, instead of failing with an NPE.
     */
    private static String currentClientIp() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            return null;
        }
        return getIpAddr(((ServletRequestAttributes) attributes).getRequest());
    }

    /**
     * Best-effort client IP lookup.
     * <p>
     * Headers are checked in this order, falling back to {@code getRemoteAddr()}:
     * {@code x-forwarded-for}, then {@code Proxy-Client-IP}, then {@code WL-Proxy-Client-IP}.
     * When the value is a comma-separated proxy chain, the first (client) entry is returned, trimmed.
     * <p>
     * Security note: these headers are client-controlled unless a trusted proxy overwrites them,
     * so a client can forge them to obtain a fresh per-IP counter.
     *
     * @param request the current request
     * @return the client IP, {@code null} if nothing could be determined, or {@code ""} if the lookup threw
     */
    public static String getIpAddr(HttpServletRequest request) {
        String ipAddress;
        try {
            ipAddress = request.getHeader("x-forwarded-for");
            if (isUnknownIp(ipAddress)) {
                ipAddress = request.getHeader("Proxy-Client-IP");
            }
            if (isUnknownIp(ipAddress)) {
                ipAddress = request.getHeader("WL-Proxy-Client-IP");
            }
            if (isUnknownIp(ipAddress)) {
                ipAddress = request.getRemoteAddr();
            }
            // With several proxies the value is "client, proxy1, proxy2"; the first entry is the client.
            // Split whenever a comma is present: a length check misses short chains like "1.1.1.1,2.2.2.2".
            if (ipAddress != null) {
                int comma = ipAddress.indexOf(',');
                if (comma > 0) {
                    ipAddress = ipAddress.substring(0, comma);
                }
                ipAddress = ipAddress.trim();
            }
        } catch (Exception e) {
            ipAddress = "";
        }
        return ipAddress;
    }

    /** True when a header value carries no usable IP (missing, empty, or the literal "unknown"). */
    private static boolean isUnknownIp(String ip) {
        return ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip);
    }
}
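Spring Data Redis提供了DefaultRedisScript,用于在Redis中执行Lua脚本,具体用法网上有很多文章。这里用ThreadLocal保存当前请求的IP。其实IP只在一次方法调用内使用,用方法内的局部变量同样是线程安全的;而且嵌套调用多个带@RateLimiter的方法时,局部变量不会出现内层调用读到外层IP、又把它清掉的问题。如果使用ThreadLocal,切记在finally中调用remove()清理,防止内存泄漏,以及线程池复用线程时读到上一个请求的IP。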
RedisUtils工具类(方法太多,只展示execute方法)
package com.shinedata.config.redis;
import org.checkerframework.checker.units.qual.K;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import javax.annotation.PostConstruct;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
/**
 * Static facade over the {@code redisTemplate} bean, so code that is not a Spring bean can run
 * Redis scripts. Only the {@code execute} method is shown here.
 *
 * @ClassName RedisUtils
 * @Author yupanpan
 * @Date 2019/11/20 13:38
 */
@Component
public class RedisUtils {
    @Autowired
    @Qualifier("redisTemplate")
    private RedisTemplate<String, Object> redisTemplate;

    /** The Spring-managed instance, published by {@link #init()} for the static methods. */
    private static RedisUtils redisUtils;

    /** Publishes this Spring-managed instance (with its injected template) to the static accessor. */
    @PostConstruct
    public void init() {
        redisUtils = this;
    }

    /**
     * Executes a Lua script through the injected {@code redisTemplate}.
     *
     * @param script the script to run
     * @param keys   values passed to the script as {@code KEYS}
     * @param args   values passed to the script as {@code ARGV}
     * @return the script result converted to the script's result type
     * @throws IllegalStateException if called before Spring has initialized this bean
     */
    public static Number execute(DefaultRedisScript<Number> script, List keys, Object... args) {
        RedisUtils instance = redisUtils;
        if (instance == null) {
            // Fail with a clear message instead of an anonymous NullPointerException.
            throw new IllegalStateException("RedisUtils has not been initialized by Spring yet");
        }
        return instance.redisTemplate.execute(script, keys, args);
    }
}
自己配置的RedisTemplate
package com.shinedata.config.redis;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import redis.clients.jedis.JedisPoolConfig;
/**
 * Redis connection-pool, connection-factory and template configuration.
 * <p>
 * Host, port and password ({@code redisHost}, {@code redisPort}, {@code redisPassword}) are
 * inherited from {@code RedisProperties}.
 *
 * @ClassName RedisConfig
 * @Author yupanpan
 * @Date 2019/11/20 13:26
 */
@Configuration
public class RedisConfig extends RedisProperties{
    protected Logger log = LogManager.getLogger(RedisConfig.class);

    /**
     * Jedis connection-pool settings.
     *
     * @return the pool configuration used by the connection factory
     */
    @Bean("jedisPoolConfig")
    public JedisPoolConfig jedisPoolConfig() {
        JedisPoolConfig pool = new JedisPoolConfig();

        // Sizing: at most 6000 connections in total; keep between 100 and 500 idle ones.
        pool.setMaxTotal(6000);
        pool.setMaxIdle(500);
        pool.setMinIdle(100);

        // Wait at most 5 s for a free connection.
        pool.setMaxWaitMillis(5000);

        // Idle eviction: the evictor runs every 600 ms; connections idle for 100 ms are eligible.
        // NOTE(review): the commons-pool default for min-evictable-idle time is 30 minutes. 100 ms
        // together with minIdle=100 looks like it will keep closing and reopening connections —
        // confirm this is intended.
        pool.setMinEvictableIdleTimeMillis(100);
        pool.setTimeBetweenEvictionRunsMillis(600);

        // Validate a connection when it is borrowed, not while it sits idle.
        pool.setTestOnBorrow(true);
        pool.setTestWhileIdle(false);
        return pool;
    }

    /**
     * Jedis connection factory pointing at the configured Redis server.
     *
     * @param jedisPoolConfig the pool settings to use
     * @return the connection factory
     */
    @Bean("jedisConnectionFactory")
    public JedisConnectionFactory jedisConnectionFactory(@Qualifier("jedisPoolConfig")JedisPoolConfig jedisPoolConfig) {
        JedisConnectionFactory factory = new JedisConnectionFactory(jedisPoolConfig);
        factory.setPoolConfig(jedisPoolConfig);
        factory.setHostName(redisHost);
        factory.setPort(redisPort);
        factory.setPassword(redisPassword);
        // Client timeout, in milliseconds.
        factory.setTimeout(10000);
        return factory;
    }

    /**
     * The application's {@code RedisTemplate<String, Object>}, used instead of a
     * {@code RedisTemplate<String, String>}.
     *
     * @param redisConnectionFactory the Jedis connection factory
     * @return the configured template
     */
    @Bean("redisTemplate")
    public RedisTemplate<String, Object> functionDomainRedisTemplate(@Qualifier("jedisConnectionFactory") RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        initDomainRedisTemplate(template, redisConnectionFactory);
        return template;
    }

    /**
     * Configures serializers, transaction support and the connection factory of a template.
     * <p>
     * Without explicit serializers, storing non-String values fails with errors such as
     * "User can't cast to String".
     *
     * @param redisTemplate the template to configure
     * @param factory       the connection factory to attach
     */
    private void initDomainRedisTemplate(RedisTemplate<String, Object> redisTemplate, RedisConnectionFactory factory) {
        // Keys (plain and hash) are plain strings.
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        redisTemplate.setHashKeySerializer(new StringRedisSerializer());
        // Values (plain and hash) go through GenericJackson2JsonRedisSerializer (JSON).
        redisTemplate.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        redisTemplate.setHashValueSerializer(new GenericJackson2JsonRedisSerializer());
        // Transaction support off, so connections are released automatically. With it on, connections
        // must be released manually unless callers use @Transactional and let Spring manage them.
        redisTemplate.setEnableTransactionSupport(false);
        redisTemplate.setConnectionFactory(factory);
    }
}
全局Controller异常处理GlobalExceptionHandler
package com.shinedata.exception;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.shinedata.util.ResultData;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.annotation.RestControllerAdvice;
/**
 * Maps exceptions thrown by controllers to {@link ResultData} error payloads.
 * Every handled exception is answered with HTTP 200 and an error payload.
 */
@RestControllerAdvice
public class GlobalExceptionHandler {
    private Logger logger = LoggerFactory.getLogger(GlobalExceptionHandler.class);

    /**
     * Handles rejections raised by the rate limiter.
     *
     * @param e the rejection; its message is returned to the caller when not blank
     * @return an error payload
     */
    @ExceptionHandler(value = RateLimiterException.class)
    @ResponseStatus(HttpStatus.OK)
    public ResultData runtimeExceptionHandler(RateLimiterException e) {
        logger.error("系统错误:", e);
        return ResultData.getResultError(StringUtils.isNotBlank(e.getMessage()) ? e.getMessage() : "处理失败");
    }

    /**
     * Fallback handler for every other exception.
     * <p>
     * The parameter is declared as {@code Exception} to match the registered exception type.
     * A {@code RuntimeException} parameter cannot receive checked exceptions, so Spring would
     * fail to resolve the argument for them.
     *
     * @param e the exception thrown by the controller
     * @return "参数错误" when the cause is a {@link JsonMappingException}, otherwise the exception
     *         message (or "处理失败" when the message is blank)
     */
    @ExceptionHandler(value = Exception.class)
    @ResponseStatus(HttpStatus.OK)
    public ResultData runtimeExceptionHandler(Exception e) {
        // The stack-trace log already contains the message; no separate message-only log line.
        logger.error("系统错误:", e);
        if (e.getCause() instanceof JsonMappingException) {
            return ResultData.getResultError("参数错误");
        }
        return ResultData.getResultError(StringUtils.isNotBlank(e.getMessage()) ? e.getMessage() : "处理失败");
    }
}
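使用就很简单了,一个注解搞定。例如在Controller方法上加 @RateLimiter(time = 1, count = 10, restrictionsIp = true),表示同一IP在每个1秒的时间窗口内最多访问10次,超出后抛出RateLimiterException,由全局异常处理返回"服务器繁忙,请稍后再试"。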
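补充:后来把Lua脚本简化为下面的版本。脚本只负责计数,并在key首次创建时设置过期时间;是否超限交给切面中的 number.intValue() <= rateLimit.count() 来判断,因此脚本里不再需要ARGV[1](limitCount):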
-- Fixed-window request counter (simplified variant).
-- KEYS[1]: counter key
-- ARGV[1]: limit; ignored here, the caller compares the returned count with its limit
-- ARGV[2]: window length in seconds (passed to EXPIRE)
-- Returns the number of requests seen in the current window, including this one.
local key = KEYS[1]
local limitTime = tonumber(ARGV[2])

-- INCR treats a missing key as 0, so the first request of a window returns 1.
local count = redis.call('incr', key)

-- Start the window on the first request. Also re-arm the TTL when the key has none:
-- Redis does not roll back a script that fails part-way, and a counter without a TTL
-- would never reset.
if count == 1 or redis.call('ttl', key) == -1 then
  redis.call('expire', key, limitTime)
end
return count