关于 JWT Token 自动续期的解决

关于 JWT Token 自动续期的解决方案,最近在做token,因为session无法适应小程序,所以改为token,找的了一篇很好的文章,参考实现了之后想分享给大家,里面添加了部分自己的代码。文章最后是我参考的文献。

在前后端分离的开发模式下,前端用户登录成功后后端服务会给用户颁发一个 jwt token。前端(如 vue)在接收到 jwt token 后会将 token 存储到 LocalStorage 中。

后续每次请求都会将此 token 放在请求头中传递到后端服务,后端服务会有一个过滤器对 token 进行拦截校验,校验 token 是否过期,如果 token 过期则会让前端跳转到登录页面重新登录。

因为 jwt token 中一般会包含用户的基础信息,为了保证 token 的安全性,一般会将 token 的过期时间设置的比较短。

但是这样又会导致前端用户需要频繁登录(token 过期),甚至有的表单比较复杂,前端用户在填写表单时需要思考较长时间,等真正提交表单时后端校验发现 token 过期失效了不得不跳转到登录页面。

如果真发生了这种情况前端用户肯定是要骂人的,用户体验非常不友好。本篇内容就是在前端用户无感知的情况下实现 token 的自动续期,避免频繁登录、表单填写内容丢失情况的发生。

实现原理

jwt token 自动续期的实现原理如下:

登录成功后将用户生成的 jwt token 作为 key、value 存储到 cache 缓存里面 (这时候 key、value 值一样),将缓存有效期设置为 token 有效时间的 2 倍。
当该用户再次请求时,通过后端的一个 jwt Filter 校验前端 token 是否是有效 token,如果 token 无效表明是非法请求,直接抛出异常即可;
根据规则取出 cache token,判断 cache token 是否存在,此时主要分以下几种情况:
cache token 不存在
这种情况表明该用户账户空闲超时,返回用户信息已失效,请重新登录。
cache token 存在,则需要使用 jwt 工具类验证该 cache token 是否过期超时,不过期无需处理。
过期则表示该用户一直在操作只是 token 失效了,后端程序会给 token 对应的 key 映射的 value 值重新生成 jwt token 并覆盖 value 值,该缓存生命周期重新计算。
实现逻辑的核心原理:
前端请求 Header 中设置的 token 保持不变,校验有效性以缓存中的 token 为准。

代码实现(伪码)

0、pom.xml文件

<!-- jwt -->
        <dependency>
            <groupId>com.auth0</groupId>
            <artifactId>java-jwt</artifactId>
            <version>3.8.2</version>
        </dependency>

1、登录拦截+成功后给用户签发 token,并设置 token 的有效期

/**
     * 添加拦截器
     * @param registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        String[] excludes=new String[]{"/","/doc/**","/css/**","/js/**","/fonts/**","/HUI/**","/img/**","/ztree/**","/register/**"
                ,"/login/**","/files/**","/test/**"};
        registry.addInterceptor(getLoginHandlerInterceptor()).addPathPatterns("/**")
                .excludePathPatterns(excludes);

    }
//生成签名 用于客户端验证
        String token = JwtUtil.sign(users.getContactNumber(),users.getUserPassword());
        // 设置到redis  5分钟后redis的token过期
        token = tokenUtil.setToken(TokenUtil.REDIS_EXPIRE_TIME,token);
        // 响应带token
        response.setHeader("token",token);

2、将 token 存入 redis,并设定过期时间,将 redis 的过期时间设置成 token 过期时间的两倍

String tokenKey = "sys:user:token" + token;
        if (redisUtil.hasKey(tokenKey)){
            redisUtil.del(tokenKey);
        }
        boolean result = redisUtil.set(tokenKey, token,time);

3、过滤器校验 token,校验 token 有效性

public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        //从header中获取token
        String token = request.getHeader("token");
        UserTable user = tokenUtil.getUserByToken(request);
        if(user == null){
            // token的验证方法
            rejectForJson(ErrCodeEnum.TOKEN_NOT_EXIST, response, request);
            return false;
        }
        //校验token是否失效,自动续期
        if(!tokenUtil.refreshToken(token,user)){
            rejectForJson(ErrCodeEnum.TOKEN_NOT_EXPIRE, response, request);
            return false;
        }
        String me= request.getMethod();
         if("OPTIONS".equals(me)){
             return true;
         }
        return true;
    }

4、实现 token 的自动续期

public boolean refreshToken(String token, UserTable userTable) {
        // 是否是redis里的token
        String tokenKey = "sys:user:token" + token;
        String cacheToken = String.valueOf(redisUtil.get(tokenKey));
        if (StringUtils.isNotEmpty(cacheToken)) {
            // 校验token有效性,注意需要校验的是缓存中的token
            if (!JwtUtil.verify(cacheToken, userTable.getContactNumber(), userTable.getUserPassword())) {
                String newToken = JwtUtil.sign(userTable.getContactNumber(), userTable.getUserPassword());
                // 设置超时时间
                redisUtil.set(tokenKey, newToken,REDIS_EXPIRE_TIME*2);
            }
            return true;
        }
        return false;
    }

5、几个工具类

一、JwtUtil.java

package com.sinan.snwit.utils;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.interfaces.DecodedJWT;
import org.apache.commons.lang.StringUtils;

import javax.servlet.http.HttpServletRequest;
import java.util.Date;

public class JwtUtil {
    //设置签名的过期时间 60s
    public static final long EXPIRE_TIME = 60*1000;

    /**
     * 校验token是否正确
     * @param token  密钥
     * @param secret 用户的密码
     * @return 是否正确
     */
    public static boolean verify(String token, String contactNumber, String secret) {
        try {
            // 根据密码生成JWT效验器
            Algorithm algorithm = Algorithm.HMAC256(secret);
            JWTVerifier verifier = JWT.require(algorithm).withClaim("contactNumber", contactNumber).build();
            // 效验TOKEN
            DecodedJWT jwt = verifier.verify(token);
            return true;
        } catch (Exception exception) {
            return false;
        }
    }

    /**
     * 获得token中的信息无需secret解密也能获得
     * @return token中包含的用户名
     */
    public static String getContactNumber(String token) {
        try {
            DecodedJWT jwt = JWT.decode(token);
            return jwt.getClaim("contactNumber").asString();
        } catch (JWTDecodeException e) {
            return null;
        }
    }

    /**
     * 生成签名
     */
    public static String sign(String contactNumber, String secret) {
        Date date = new Date(System.currentTimeMillis() + EXPIRE_TIME);
        //使用HS256生成token,密钥则是用户的密码
        Algorithm algorithm = Algorithm.HMAC256(secret);
        // 附带contactNumber信息
        return JWT.create().withClaim("contactNumber", contactNumber).withExpiresAt(date).sign(algorithm);
    }

    /**
     * 根据request中的token获取用户账号
     * @param request
     * @return
     */
    public static String getContactNumberByToken(HttpServletRequest request) {
        String accessToken = request.getHeader("X-Access-Token");
        String contactNumber = getContactNumber(accessToken);
        if (StringUtils.isEmpty(contactNumber)) {
            throw new RuntimeException("无法获取有效用户!");
        }
        return contactNumber;
    }
}

二、TokenUtil.java

package com.sinan.snwit.utils;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.sinan.snwit.entity.UserTable;
import com.sinan.snwit.service.UserTableService;
import org.apache.commons.lang.StringUtils;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

@Component
public class TokenUtil {

    @Resource
    private RedisUtil redisUtil;

    @Resource
    private UserTableService userTableService;

    // 5分钟后redis的token过期
    public static final long REDIS_EXPIRE_TIME = 5*60;

    /**
     * 设置token
     * @param time
     * @param token
     * @return
     */
    public String setToken(Long time,String token){
        // token+user存入redis
        String tokenKey = "sys:user:token" + token;
        if (redisUtil.hasKey(tokenKey)){
            redisUtil.del(tokenKey);
        }
        boolean result = redisUtil.set(tokenKey, token,time);
        if (result){
            return token;
        }
        return null;
    }

    /**
     * 根据token获取UserTable
     * @return
     */
    public UserTable getUserByToken(HttpServletRequest request) {
        String token = request.getHeader("token");
        //解析token获取用户名
        if (StringUtils.isNotEmpty(token)){
            String contactNumber = JwtUtil.getContactNumber(token);
            //根据用户名获取用户实体,在实际开发中从redis取
            UserTable user = userTableService.selectTel(contactNumber);
            return user;
        }
        return null;
    }

    /**
     * 刷新token
     * @param token
     * @param userTable
     * @return
     */
    public boolean refreshToken(String token, UserTable userTable) {
        // 是否是redis里的token
        String tokenKey = "sys:user:token" + token;
        String cacheToken = String.valueOf(redisUtil.get(tokenKey));
        if (StringUtils.isNotEmpty(cacheToken)) {
            // 校验token有效性,注意需要校验的是缓存中的token
            if (!JwtUtil.verify(cacheToken, userTable.getContactNumber(), userTable.getUserPassword())) {
                String newToken = JwtUtil.sign(userTable.getContactNumber(), userTable.getUserPassword());
                // 设置超时时间
                redisUtil.set(tokenKey, newToken,REDIS_EXPIRE_TIME*2);
            }
            return true;
        }
        return false;
    }

}

三、RedisUtil.java

package com.sinan.snwit.utils;


import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

@Component
public class RedisUtil {
    @Autowired
    RedisTemplate<String, Object> redisTemplate;

    public void setRedisTemplate(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }
    //=============================common============================
    /**
     * 指定缓存失效时间
     * @param key 键
     * @param time 时间(秒)
     * @return
     */
//	public boolean expire(String key,long time){
    public boolean setTime(String key,long time){
        try {
            //	if(time>0){
            redisTemplate.expire(key, time, TimeUnit.SECONDS);
            //	}
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 根据key 获取过期时间
     * @param key 键 不能为null
     * @return 时间(秒) 返回	-1:永久有效 ; -2:已失效或不存在
     */
    //public long getExpire(String key){
    public long getTime(String key){
        return redisTemplate.getExpire(key,TimeUnit.SECONDS);
    }

    /**
     * 判断key是否存在
     * @param key 键
     * @return true 存在 false不存在
     */
    public boolean hasKey(String key){
        try {
            return redisTemplate.hasKey(key);
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 删除缓存
     * @param key 可以传一个值 或多个
     */
    @SuppressWarnings("unchecked")
    public void del(String ... key){
        if(key!=null&&key.length>0){
            if(key.length==1){
                redisTemplate.delete(key[0]);
            }else{
                redisTemplate.delete(CollectionUtils.arrayToList(key));
            }
        }
    }

    //============================String=============================
    /**
     * 普通缓存获取
     * @param key 键
     * @return 值
     */
    public Object get(String key){
        return key==null?null:redisTemplate.opsForValue().get(key);
    }

    /**
     * 普通缓存放入
     * @param key 键
     * @param value 值
     * @return true成功 false失败
     */
    public boolean set(String key,Object value) {
        try {
            redisTemplate.opsForValue().set(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }

    }

    /**
     * 普通缓存放入并设置时间
     * @param key 键
     * @param value 值
     * @param time 时间(秒) time要大于0 如果time小于等于0 将设置无限期
     * @return true成功 false 失败
     */
    public boolean set(String key,Object value,long time){
        try {
            if(time>0){
                redisTemplate.opsForValue().set(key, value, time, TimeUnit.SECONDS);
            }else{
                set(key, value);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 增加<BR/>
     * 对key的value为数值的value进行增加运算
     * @param key 键
     * @param num 要增加几(大于0增加,小于0减少)
     * @return
     */
    public long increment(String key, long num){
        return redisTemplate.opsForValue().increment(key, num);
    }

    //================================Map=================================
    /**
     * HashGet
     * @param key 键 不能为null
     * @param hashKey 项 不能为null
     * @return 值
     */
    public Object hGetHashValue(String key,String hashKey){
        return redisTemplate.opsForHash().get(key, hashKey);
    }

    /**
     * 获取hashKey对应的所有键值
     * @param key 键
     * @return 对应的多个键值
     */
    public Map<Object,Object> hGetHashMap(String key){
        return redisTemplate.opsForHash().entries(key);
    }

    /**
     * HashSet
     * @param key 键
     * @param map 对应多个键值
     * @return true 成功 false 失败
     */
    public boolean hSetHashMap(String key, Map<String,Object> map){
        try {
            redisTemplate.opsForHash().putAll(key, map);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * HashSet 并设置时间
     * @param key 键
     * @param map 对应多个键值
     * @param time 时间(秒)
     * @return true成功 false失败
     */
    public boolean hSetHashMap(String key, Map<String,Object> map, long time){
        try {
            redisTemplate.opsForHash().putAll(key, map);
            if(time>0){
                //expire(key, time);
                setTime(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 向一张hash表中放入数据,如果不存在将创建
     * @param key 键
     * @param hashKey 项
     * @param value 值
     * @return true 成功 false失败
     */
    public boolean hSetHashValue(String key,String hashKey,Object value) {
        try {
            redisTemplate.opsForHash().put(key, hashKey, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 向一张hash表中放入数据,如果不存在将创建
     * @param key 键
     * @param hashKey 项
     * @param value 值
     * @param time 时间(秒)  注意:如果已存在的hash表有时间,这里将会替换原有的时间
     * @return true 成功 false失败
     */
    public boolean hSetHashValue(String key,String hashKey,Object value,long time) {
        try {
            redisTemplate.opsForHash().put(key, hashKey, value);
            if(time>0){
//				expire(key, time);
                setTime(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 删除hash表中的值
     * @param key 键 不能为null
     * @param hashKey 项 可以使多个 不能为null
     */
    public long hDeleteHashKey(String key, Object... hashKey){
        return redisTemplate.opsForHash().delete(key,hashKey);
    }

    /**
     * 判断hash表中是否有该项的值
     * @param key 键 不能为null
     * @param hashKey 项 不能为null
     * @return true 存在 false不存在
     */
    public boolean hHasHashKey(String key, String hashKey){
        return redisTemplate.opsForHash().hasKey(key, hashKey);
    }

    /**
     * Hash值增加num----value需为数值类型	<BR/>
     * 如果不存在,就会创建一个,并把 num 作为value返回
     * @param key 			键
     * @param hashKey 	HashMap的键
     * @param num 		要增加几(大于0增加,小于0减少)
     * @return
     */
    public double hPlusHashValue(String key, String hashKey,double num){
        return redisTemplate.opsForHash().increment(key, hashKey, num);
    }

    /**
     * Hash值减少num----value需为数值类型	<BR/>
     * 如果不存在,就会创建一个,并把 num 作为value返回
     * @param key			键
     * @param hashKey 	HashMap的键
     * @param num			要减少几(大于0减少,小于0增加)
     * @return
     */
    public double hMinusHashValue(String key, String hashKey,double num){
        return redisTemplate.opsForHash().increment(key, hashKey,-num);
    }

    //============================set=============================
    /**
     * 根据key获取Set中的所有值
     * @param key 键
     * @return
     */
    public Set<Object> sGetSet(String key){
        try {
            return redisTemplate.opsForSet().members(key);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * 查询该value是否在键为key的Set中
     * @param key 键
     * @param value 值
     * @return true 存在 false不存在
     */
    public boolean sHasValueOnSet(String key,Object value){
        try {
            return redisTemplate.opsForSet().isMember(key, value);
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 将数据放入Set缓存
     * @param key 键
     * @param values 值 可以是多个
     * @return 成功个数
     */
    public long sSetSet(String key, Object...values) {
        try {
            return redisTemplate.opsForSet().add(key, values);
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 将Set数据放入缓存
     * @param key 键
     * @param time 时间(秒)
     * @param values 值 可以是多个
     * @return 成功个数
     */
    public long sSetSets(String key,long time,Object...values) {
        try {
            Long count = redisTemplate.opsForSet().add(key, values);
            //	if(time>0) expire(key, time);
            if(time>0){
                setTime(key, time);
            }
            return count;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 获取Set缓存的长度
     * @param key 键
     * @return
     */
    public long sGetSetSize(String key){
        try {
            return redisTemplate.opsForSet().size(key);
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 从Set中移除值为value的值
     * @param key 键
     * @param values 值 可以是多个
     * @return 移除的个数
     */
    public long sRemoveSet(String key, Object ...values) {
        try {
            Long count = redisTemplate.opsForSet().remove(key, values);
            return count;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }
    //===============================list=================================

    /**
     * 获取list缓存的内容
     * @param key 键
     * @param start 开始
     * @param end 结束  0 到 -1代表所有值
     * @return
     */
    public List<Object> lGetList(String key, long start, long end){
        try {
            return redisTemplate.opsForList().range(key, start, end);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * 获取list缓存的长度
     * @param key 键
     * @return
     */
    public long lGetListSize(String key){
        try {
            return redisTemplate.opsForList().size(key);
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * 通过索引 获取list中的值
     * @param key 键
     * @param index 索引  index>=0时, 0 表头,1 第二个元素,依次类推;index<0时,-1,表尾,-2倒数第二个元素,依次类推
     * @return
     */
    public Object lGetListByIndex(String key,long index){
        try {
            return redisTemplate.opsForList().index(key, index);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * 将list放入缓存
     * @param key 键
     * @param value 值(Object)
     * @return
     */
    public boolean lSetlist(String key, Object value) {
        try {
            redisTemplate.opsForList().rightPush(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 将list放入缓存
     * @param key 键
     * @param value 值(List)
     * @return
     */
    @SuppressWarnings("rawtypes")
    public boolean lSetListAll(String key, List value) {
        try {
            redisTemplate.opsForList().rightPushAll(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 将list放入缓存
     * @param key 键
     * @param value 值(Object)
     * @param time 时间(秒)
     * @return
     */
    public boolean lSetlist(String key, Object value, long time) {
        try {
            redisTemplate.opsForList().rightPush(key, value);
            //        if (time > 0) expire(key, time);
            if (time > 0) {
                setTime(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 将list放入缓存
     * @param key 键
     * @param value 值(List)
     * @param time 时间(秒)
     * @return
     */
    @SuppressWarnings("rawtypes")
    public boolean lSetListAll(String key,  List value, long time) {
        try {
            redisTemplate.opsForList().rightPushAll(key, value);
            //		if (time > 0) expire(key, time);
            if (time > 0){
                setTime(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 根据索引修改list中的某条数据
     * @param key 键
     * @param index 索引
     * @param value 值
     * @return
     */
    public boolean lUpdateListByIndex(String key, long index,Object value) {
        try {
            redisTemplate.opsForList().set(key, index, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 移除N个值为value
     * @param key 键
     * @param count 移除多少个
     * @param value 值
     * @return 移除的个数
     */
    public long lRemoveList(String key,long count,Object value) {
        try {
            Long remove = redisTemplate.opsForList().remove(key, count, value);
            return remove;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

}

小结

jwt token 实现逻辑的核心原理是 前端请求 Header 中设置的 token 保持不变,校验有效性以缓存中的 token 为准,千万不要直接校验 Header 中的 token。实现原理部分大家好好体会一下,思路比实现更重要