SpringBoot如何防止会话重放攻击呢

  • 一、什么是会话重放攻击漏洞
  • 二、如何避免会话重放攻击
  • 1、环境和框架
  • 2、创建SignAuthServletRequestWrapper
  • 3、创建过滤器SignAuthFilter
  • 4、创建XSSFilterAndAuthFilterConfig配置过滤器
  • 5、postman调用脚本


一、什么是会话重放攻击漏洞

客户端对服务端的网络请求如果被攻击者拦截并且抓取,攻击者可以用抓取到的参数不断对接口发起重复请求,造成大量重复操作且可能产生重复数据。

二、如何避免会话重放攻击

在SpringBoot框架里,可以通过过滤器校验请求头里的Sign(签名)和Timestamp(时间戳)以及ReqId(前端生成的唯一请求id),这样代码改动范围比较小且对业务模块的侵入性比较低。

主要思路:
前端在提交时将请求参数按key升序拼接成字符串(key1=val1&key2=val2…),在字符串尾部依次追加Timestamp(时间戳)、ReqId(前端生成的唯一请求id)和Secret(密钥)后做md5计算得到签名Sign,再把Sign、ReqId、Timestamp添加到请求头;后端在过滤器里判断请求id是否唯一、请求时间是否过期、签名是否正确。注意:Secret会随前端代码一起下发,可被提取,因此该方案主要用于防止抓包后原样重放请求,并不能替代身份认证。

1、环境和框架

  • 开发语言:Java
  • 框架: SpringBoot
  • 依赖组件: hutool-all:5.0.3、fastjson、lombok

2、创建SignAuthServletRequestWrapper

/**
 * Request wrapper that HTML-filters header/parameter values and verifies the anti-replay
 * signature carried by each request.
 *
 * <p>The client hashes {@code k1=v1&k2=v2&...&timestamp=T&reqId=R&secret=S} (keys in ascending
 * order) with MD5 and sends the result in the {@code Sign} header together with
 * {@code Timestamp} and {@code ReqId}. For query strings, form bodies and JSON-object bodies the
 * three values may instead be supplied as same-named parameters/fields. Verification checks
 * that all three are present, that the timestamp lies within {@link #timeoutMs} of the server
 * clock, that the signature matches, and only then records the reqId as used.
 *
 * <p>NOTE(review): verification is lazy. It runs only when downstream code calls
 * {@link #getParameterValues(String)} (GET or urlencoded requests) or {@link #getInputStream()}
 * (other non-multipart bodies). A handler that reads neither (e.g. a parameterless GET) is never
 * verified, and multipart uploads are never verified.
 */
@Slf4j
public class SignAuthServletRequestWrapper extends HttpServletRequestWrapper {

    // Header names; the Servlet API treats header names case-insensitively.
    private static final String HEADER_SIGN = "Sign";
    private static final String HEADER_TIMESTAMP = "Timestamp";
    private static final String HEADER_REQ_ID = "ReqId";

    /**
     * Shared signing secret.
     * NOTE(review): the same value is embedded in the client-side script, so it only deters
     * naive tampering/replay; it does not authenticate the caller.
     */
    private static final String ENCRYPT_SECRET = "350cb4303db847788eee5b125f7fdbaa";

    private static final String METHOD_GET = "GET";

    private static final String METHOD_POST = "POST";

    private static final String CONTENT_TYPE_MULTIPART_FORM_DATA = "multipart/form-data";
    private static final String CONTENT_TYPE_APPLICATION_X_WWW_FORM_URLENCODED = "application/x-www-form-urlencoded";

    // Replay window in milliseconds (60 seconds), applied to both past and future timestamps.
    private static final Integer timeoutMs = 60 * 1000;

    private HttpServletResponse response;
    private HttpServletRequest request;
    // Whether the signature must be verified for this request.
    private boolean isRequireSign;
    // Whether the signature has already been verified successfully for this request.
    private boolean isHandle = false;
    // Raw body read once from the container stream so it can be replayed on every read.
    private byte[] cachedBody;

    /**
     * @param httpServletRequest  the request to wrap
     * @param httpServletResponse the current response (stored, not used by this class)
     * @param isRequireSign       whether signatures must be verified; {@code null} means verify
     */
    public SignAuthServletRequestWrapper(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Boolean isRequireSign) {
        super(httpServletRequest);
        this.response = httpServletResponse;
        this.request = httpServletRequest;
        // Honour the caller's flag; a missing value falls back to the secure default.
        this.isRequireSign = isRequireSign == null || isRequireSign;
    }

    /**
     * Wraps {@code request} with signature verification enabled.
     *
     * @param request the request to wrap
     */
    public SignAuthServletRequestWrapper(HttpServletRequest request) {
        this(request, null, Boolean.TRUE);
    }

    /** Returns the header value with HTML filtered out (empty/null values are returned as-is). */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        if (StringUtils.isEmpty(value)) {
            return value;
        }
        return HtmlUtil.filter(value);
    }

    /** Returns the parameter value with HTML filtered out (empty/null values are returned as-is). */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (StringUtils.isEmpty(value)) {
            return value;
        }
        return HtmlUtil.filter(value);
    }

    /**
     * Records {@code reqId} as used, failing if it has been seen before.
     *
     * <p>TODO: not implemented yet. It currently accepts every reqId, so duplicate requests
     * inside the replay window are NOT rejected. Intended implementation (pseudo-code):
     * <pre>
     *   String redisKey = "unique-req-id-" + reqId;
     *   // Must be a single atomic operation (e.g. Redis SET key NX PX ttl);
     *   // a separate hasKey + add lets two concurrent replays both pass.
     *   if (!RedisCache.setIfAbsent(redisKey, ttl, TimeUnit.MILLISECONDS)) {
     *       throw new Exception("duplicate reqId: " + reqId);
     *   }
     * </pre>
     * Because timestamps up to {@code timeoutMs} in the future are accepted, and stay acceptable
     * until {@code timeoutMs} after that, the entry should live for at least
     * {@code 2 * timeoutMs}.
     *
     * @param reqId client-generated unique request id
     * @throws Exception when the reqId has already been used (once implemented)
     */
    private void addReqIdIfNotExists(String reqId) throws Exception {
        // Intentionally empty until a shared cache is wired in.
    }

    /**
     * Returns HTML-filtered parameter values. For GET and urlencoded requests, the first call
     * also verifies the request signature. Verification failures are rethrown through
     * {@code @SneakyThrows}.
     */
    @SneakyThrows
    @Override
    public String[] getParameterValues(String name) {
        if (isRequireSign && !isHandle && isQueryOrFormRequest()) {
            log.debug("verifying sign in getParameterValues, method={}", super.getMethod());
            verifyParamsMapSign(collectRequestParams());
        }

        String[] values = super.getParameterValues(name);
        if (values == null) {
            return null;
        }
        // Filter a copy so the container's own array is never modified in place.
        String[] filtered = values.clone();
        for (int i = 0; i < filtered.length; i++) {
            if (!StringUtils.isEmpty(filtered[i])) {
                filtered[i] = HtmlUtil.filter(filtered[i]);
            }
        }
        return filtered;
    }

    /** True for GET requests and for any request whose Content-Type is urlencoded form data. */
    private boolean isQueryOrFormRequest() {
        if (METHOD_GET.equals(super.getMethod())) {
            return true;
        }
        String contentType = super.getContentType();
        return contentType != null && contentType.contains(CONTENT_TYPE_APPLICATION_X_WWW_FORM_URLENCODED);
    }

    /** Copies the request parameters into a fresh, key-sorted, mutable map. */
    @SuppressWarnings("unchecked")
    private SortedMap<String, String> collectRequestParams() {
        Map<String, String> params = (Map<String, String>) HttpServletUtil.getParameter(request);
        return params == null ? new TreeMap<>() : new TreeMap<>(params);
    }

    /**
     * Returns the request body. Multipart uploads are passed through untouched. Any other body
     * (including one without a Content-Type) is buffered, verified on first read when required,
     * and replayed from memory.
     */
    @SneakyThrows
    @Override
    public ServletInputStream getInputStream() throws IOException {
        String contentType = super.getContentType();
        log.debug("getInputStream method={} contentType={}", super.getMethod(), contentType);

        // File uploads are not signature-checked.
        if (contentType != null && contentType.contains(CONTENT_TYPE_MULTIPART_FORM_DATA)) {
            return super.getInputStream();
        }
        return getInputStreamWithFilter();
    }

    /**
     * Tells whether {@code str} looks like a JSON object or array (by its first/last character).
     *
     * @param str candidate string, may be null
     * @return true for "[...]" or "{...}", false otherwise (including null)
     */
    public static Boolean isObjectOrArray(String str) {
        return isArrayStr(str) || isObjectStr(str);
    }

    /**
     * Removes every entry with a non-null key whose value looks like a JSON object or array.
     *
     * @param map map to modify in place
     */
    public static void removeObjectOrArrayItem(Map<String, String> map) {
        map.entrySet().removeIf(entry -> entry.getKey() != null && isObjectOrArray(entry.getValue()));
    }

    /**
     * Tells whether {@code str} is enclosed in "[" and "]".
     *
     * @param str candidate string, may be null
     * @return true if it looks like a JSON array
     */
    private static Boolean isArrayStr(String str) {
        if (str == null) {
            return false;
        }
        return str.startsWith("[") && str.endsWith("]");
    }

    /**
     * Tells whether {@code str} is enclosed in "{" and "}".
     *
     * @param str candidate string, may be null
     * @return true if it looks like a JSON object
     */
    private static Boolean isObjectStr(String str) {
        if (str == null) {
            return false;
        }
        return str.startsWith("{") && str.endsWith("}");
    }

    /**
     * Buffers the request body and, when required, verifies its signature once (JSON array or
     * object form). Returns a stream that replays the exact original bytes, so it can be called
     * repeatedly.
     *
     * @return a stream over the buffered body
     * @throws Exception if the body cannot be read or signature verification fails
     */
    public ServletInputStream getInputStreamWithFilter() throws Exception {
        if (cachedBody == null) {
            cachedBody = readBody();
        }

        // Trim so surrounding whitespace or a trailing newline does not hide the JSON shape.
        String json = new String(cachedBody, java.nio.charset.StandardCharsets.UTF_8).trim();
        if (isRequireSign && !isHandle) {
            if (isArrayStr(json)) {
                arrayVerifySignHandler(json);
            } else {
                objectVerifySignHandler(json);
            }
        }

        final ByteArrayInputStream bain = new ByteArrayInputStream(cachedBody);
        return new ServletInputStream() {
            @Override
            public int read() {
                return bain.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return bain.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return bain.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The whole body is already in memory, so reads never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // Non-blocking (async) reads are not supported for the buffered body.
            }
        };
    }

    /** Reads all remaining bytes from the container's input stream and closes it. */
    private byte[] readBody() throws IOException {
        try (InputStream in = super.getInputStream()) {
            java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
            byte[] chunk = new byte[4096];
            int n;
            while ((n = in.read(chunk)) != -1) {
                out.write(chunk, 0, n);
            }
            return out.toByteArray();
        }
    }

    /**
     * Verifies the signature of a JSON-object body. Sign, Timestamp and ReqId are read from the
     * headers. Any that is blank there is taken from the same-named top-level field, and that
     * field is then excluded from the signed string.
     *
     * @param json the request body
     * @throws Exception if a field is missing, the timestamp is out of range or the signature is wrong
     */
    public void objectVerifySignHandler(String json) throws Exception {
        SortedMap<String, String> paramsMap = JSON.parseObject(json, new TypeReference<SortedMap<String, String>>() {
        });
        if (paramsMap == null) {
            paramsMap = new TreeMap<>();
        }
        verifyParamsMapSign(paramsMap);
    }

    /**
     * Verifies the signature of a JSON-array body. Element {@code i} is signed as
     * {@code "i=<json>&"}. Sign, Timestamp and ReqId must come from headers.
     *
     * @param json the request body (a JSON array)
     * @throws Exception if a header is missing, the timestamp is out of range or the signature is wrong
     */
    public void arrayVerifySignHandler(String json) throws Exception {
        JSONArray objects = JSON.parseArray(json);
        if (log.isDebugEnabled()) {
            log.debug("objects:{}", JSON.toJSONString(objects));
        }
        String sign = super.getHeader(HEADER_SIGN);
        String timestampStr = super.getHeader(HEADER_TIMESTAMP);
        String reqId = super.getHeader(HEADER_REQ_ID);

        requireSignFields(sign, timestampStr, reqId);
        checkTimestamp(timestampStr);

        StringBuilder payload = new StringBuilder();
        // NOTE(review): fastjson's JSONObject is presumably HashMap-backed by default, so the key
        // order produced by toJSONString may differ from the client's sorted keys for multi-key
        // objects. Confirm, or serialise with sorted map keys.
        int size = objects == null ? 0 : objects.size();
        for (int i = 0; i < size; i++) {
            payload.append(i).append('=').append(JSON.toJSONString(objects.get(i))).append('&');
        }
        if (!signMatches(payload, timestampStr, reqId, sign)) {
            throw new Exception("签名校验失败");
        }
        // Record the reqId only after the request proved authentic.
        addReqIdIfNotExists(reqId);
        isHandle = true;
    }

    /**
     * Resolves Sign/Timestamp/ReqId (header first, then parameter), validates them and checks
     * the signature over {@code paramsMap}. Records the reqId only after that succeeds.
     */
    private void verifyParamsMapSign(SortedMap<String, String> paramsMap) throws Exception {
        String reqId = headerOrTakeParam(HEADER_REQ_ID, paramsMap);
        String sign = headerOrTakeParam(HEADER_SIGN, paramsMap);
        String timestampStr = headerOrTakeParam(HEADER_TIMESTAMP, paramsMap);

        requireSignFields(sign, timestampStr, reqId);
        checkTimestamp(timestampStr);

        if (!verifySign(paramsMap, sign, timestampStr, reqId)) {
            throw new Exception("签名校验失败");
        }
        addReqIdIfNotExists(reqId);
        isHandle = true;
    }

    /**
     * Returns the raw (unfiltered) header {@code name}. If it is blank, removes the same-named
     * entry from {@code paramsMap} and returns that value instead.
     */
    private String headerOrTakeParam(String name, Map<String, String> paramsMap) {
        String value = super.getHeader(name);
        return StringUtils.isBlank(value) ? paramsMap.remove(name) : value;
    }

    /** Throws if any of the three signature fields is missing or blank. */
    private static void requireSignFields(String sign, String timestampStr, String reqId) throws Exception {
        if (StringUtils.isBlank(reqId)) {
            throw new Exception("reqId不能为空");
        }
        if (StringUtils.isBlank(sign)) {
            throw new Exception("签名不能为空");
        }
        if (StringUtils.isBlank(timestampStr)) {
            throw new Exception("时间戳不能为空");
        }
    }

    /**
     * Rejects timestamps outside the open interval (now - timeoutMs, now + timeoutMs).
     * Future timestamps must be bounded too; otherwise a far-future value stays valid forever.
     */
    private static void checkTimestamp(String timestampStr) throws Exception {
        long timestamp;
        try {
            timestamp = Long.parseLong(timestampStr);
        } catch (NumberFormatException e) {
            throw new Exception("时间戳格式错误", e);
        }
        long now = System.currentTimeMillis();
        if (timestamp <= now - timeoutMs) {
            throw new Exception("签名已过期");
        }
        if (timestamp >= now + timeoutMs) {
            throw new Exception("时间戳无效");
        }
    }

    /**
     * Checks {@code sign} against md5("k=v&..." + "timestamp=T&reqId=R&secret=S").
     *
     * @param map          signed parameters, iterated in key order
     * @param sign         signature sent by the client (lower-case hex)
     * @param timestampStr timestamp string exactly as sent by the client
     * @param reqId        client-generated request id
     * @return true if the signature matches
     */
    public static Boolean verifySign(SortedMap<String, String> map, String sign, String timestampStr, String reqId) {
        StringBuilder payload = new StringBuilder();
        for (Map.Entry<String, String> entry : map.entrySet()) {
            payload.append(entry.getKey()).append('=').append(entry.getValue()).append('&');
        }
        return signMatches(payload, timestampStr, reqId, sign);
    }

    /** Appends timestamp/reqId/secret to {@code payload}, hashes it and compares with {@code sign}. */
    private static boolean signMatches(StringBuilder payload, String timestampStr, String reqId, String sign) {
        payload.append("timestamp=").append(timestampStr).append('&')
                .append("reqId=").append(reqId).append('&')
                .append("secret=").append(ENCRYPT_SECRET);
        String expected = DigestUtil.md5Hex(payload.toString()).toLowerCase(java.util.Locale.ROOT);
        if (sign == null) {
            return false;
        }
        // Constant-time comparison, so response timing does not reveal how much of the signature matched.
        return java.security.MessageDigest.isEqual(
                expected.getBytes(java.nio.charset.StandardCharsets.UTF_8),
                sign.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }
}

3、创建过滤器SignAuthFilter

/**
 * Blocks session replay attacks: wraps every request, except allow-listed paths, in a
 * {@link SignAuthServletRequestWrapper}, which HTML-filters values and verifies the
 * anti-replay signature.
 *
 * <p>NOTE(review): this class carries {@code @WebFilter} and is also registered through a
 * {@code FilterRegistrationBean}, under the names "signAuthFilter" and "SignAuthFilter". If
 * {@code @ServletComponentScan} is enabled, it is presumably installed twice and the request is
 * wrapped twice. Once reqId de-duplication is implemented, the second check would then reject
 * every request as a duplicate. Keep only one registration; confirm.
 *
 * @Author: zjhang
 */
@WebFilter(urlPatterns = "/*", filterName = "signAuthFilter")
public class SignAuthFilter implements Filter {

    // Trailing slashes are stripped from the path before allow-list matching; compiled once.
    private static final java.util.regex.Pattern TRAILING_SLASHES = java.util.regex.Pattern.compile("[/]+$");

    private static final AntPathMatcher antPathMatcher = new AntPathMatcher();

    // Paths that bypass the wrapper; Ant-style patterns with path variables are supported.
    private static final List<String> ALLOWED_PATHS = Arrays.asList(
            "/api/v1/file/upload",
            "/api/v1/file/vos/{ids}" // e.g. "/api/v1/file/vos/1" or "/api/v1/file/vos/2"
    );

    private FilterConfig filterConfig = null;

    // Whether signatures must be verified
    private Boolean isRequireSign = true;

    /**
     * Creates a filter that verifies signatures.
     */
    public SignAuthFilter() {
        this.isRequireSign = true;
    }

    /**
     * @param isRequireSign whether signatures must be verified; {@code null} means verify
     */
    public SignAuthFilter(Boolean isRequireSign) {
        this.isRequireSign = isRequireSign == null ? Boolean.TRUE : isRequireSign;
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        this.filterConfig = filterConfig;
    }

    /** Returns true when {@code path} matches one of the allow-listed patterns. */
    private static boolean isAllowedPath(String path) {
        for (String allowedPath : ALLOWED_PATHS) {
            if (antPathMatcher.match(allowedPath, path)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        // Context-relative path without trailing slashes.
        String relativeUri = request.getRequestURI().substring(request.getContextPath().length());
        String path = TRAILING_SLASHES.matcher(relativeUri).replaceAll("");

        if (isAllowedPath(path)) {
            // Allow-listed: pass through unwrapped
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        // Everything else goes through the signature-verifying wrapper
        filterChain.doFilter(new SignAuthServletRequestWrapper(request, (HttpServletResponse) servletResponse, isRequireSign), servletResponse);
    }

    @Override
    public void destroy() {
        this.filterConfig = null;
    }
}

4、创建XSSFilterAndAuthFilterConfig配置过滤器

/**
 * Spring configuration that registers {@link SignAuthFilter} for every URL.
 */
@Configuration
public class XSSFilterAndAuthFilterConfig {

    // Registration order of the sign filter (lower values run earlier)
    private static final int SIGN_FILTER_ORDER = 98;

    /**
     * Whether signatures must be verified, bound from {@code config.isRequireSign}
     * ({@code false} when the property is absent).
     */
    @Value("${config.isRequireSign:false}")
    private Boolean isRequireSign;

    /**
     * Registers the sign filter under the name "SignAuthFilter" for all URLs.
     *
     * @return the filter registration bean
     */
    @Bean
    public FilterRegistrationBean authSignFilterRegistration() {
        SignAuthFilter signAuthFilter = new SignAuthFilter(isRequireSign);
        FilterRegistrationBean registration = new FilterRegistrationBean(signAuthFilter);
        registration.setName("SignAuthFilter");
        registration.setOrder(SIGN_FILTER_ORDER);
        registration.addUrlPatterns("/*");
        return registration;
    }
}

5、postman调用脚本

// Replace (or add) header `key` on the outgoing Postman request.
function addHeader(key, value) {
    const headers = pm.request.headers
    headers.remove(key)
    headers.add({ key: key, value: value })
}

// Sign `params` (a plain object or an array) and attach the Timestamp, reqId and Sign headers.
// Signed string: "<params query string>timestamp=<ms>&reqId=<guid>&secret=<secret>",
// MD5-hashed and lower-cased.
// NOTE: the secret is shipped inside this script, so it is not actually secret.
function signHandler(params)
{
    const secret = "350cb4303db847788eee5b125f7fdbaa"
    const timestamp = Date.now()
    const reqId = guid()

    let queryStr
    if (Array.isArray(params)) {
        queryStr = buildQueryStrByArray(params)
    } else if (params instanceof Object) {
        queryStr = buildQueryStrByObj(params)
    } else {
        // JavaScript has no built-in `Exception`; `new Exception(...)` threw a ReferenceError.
        throw new Error('非对象或数组类型')
    }

    // Append timestamp, reqId and the secret, in the order the server expects
    queryStr += "timestamp=" + timestamp + "&"
    queryStr += "reqId=" + reqId + "&"
    queryStr += "secret=" + secret
    const sign = CryptoJS.MD5(queryStr).toString().toLowerCase()
    addHeader('Timestamp', timestamp)
    addHeader('reqId', reqId)
    addHeader('Sign', sign)
}

// 生成guid
function guid() {
    return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function (c) {
        let r = Math.random() * 16 | 0,
            v = c == 'x' ? r : (r & 0x3 | 0x8);
        return v.toString(16);
    });
}

// Header names for the auth and CSRF tokens attached to every request
const access_token_name = 'Access-Token'
const csrf_token_name = 'X-XSRF-TOKEN'

// pm.environment.get replaces the deprecated legacy postman.getEnvironmentVariable
var access_token = pm.environment.get("access_token");
var xsrf_token = pm.environment.get("xsrf-token");

// access token
addHeader(access_token_name, access_token)
// csrf token
addHeader(csrf_token_name, xsrf_token)

// Collect the enabled URL query parameters into a plain object (later duplicates win).
// Each query item is logged to the console for debugging.
function getQueryParams()
{
    const result = {}
    pm.request.url.query.each(function (param) {
        console.log(param)
        if (!param.disabled) {
            result[param.key] = param.value
        }
    })
    return result
}

// Parse the raw request body as JSON; throws if the body is not valid JSON.
function getRawParams()
{
    const rawBody = pm.request.body.raw
    return JSON.parse(rawBody)
}

// Collect the enabled x-www-form-urlencoded body fields into a plain object (later duplicates win).
function getFormParams()
{
    const fields = {}
    pm.request.body.urlencoded.each(function (field) {
        if (!field.disabled) {
            fields[field.key] = field.value
        }
    })
    return fields
}

function objKeySort(obj) {
    var sortedKeys = Object.keys(obj).sort();
    // console.log(sortedKeys)
    //创建一个新的对象,用于存放排好序的键值对
    var newObj = {};
    //遍历sortedKeys数组
    for(var i = 0; i < sortedKeys.length; i++) {
        //向新创建的对象中按照排好的顺序依次增加键值对
        newObj[sortedKeys[i]] = obj[sortedKeys[i]];
    }
    //返回排好序的新对象
    // console.log(newObj)
    return newObj;
}

function buildQueryStrByObj(params)
{
    // 获取参数对象所有key
    let paramKeys = []
    for (let key in params) {
        paramKeys.push(key)
    }

    // 按字母a-z排序
    paramKeys.sort()
    let queryStr = ""
    for (let i in paramKeys) {
        let key = paramKeys[i]
        let param = params[key]
        if (param instanceof Object || param instanceof Array) {
            queryStr += key + "=" + JSON.stringify(param) + "&"
        } else {
            queryStr += key + "=" + param + "&"
        }
    }
    return queryStr
}

function buildQueryStrByArray(paramsArr)
{
    let queryStr = ""

    for(let i= 0; i< paramsArr.length;i++)
    {
        let key = i
        let param = objKeySort(paramsArr[i])
        if (param instanceof Object || param instanceof Array) {
            queryStr += key + "=" + JSON.stringify(param) + "&"
        } else {
            queryStr += key + "=" + param + "&"
        }
    }

    return queryStr
}

// Entry point: collect the request's parameters and sign them.
// Only GET query params and POST raw-JSON / urlencoded bodies are signed; form-data uploads are skipped.
let method = pm.request.method;
// pm.request.body may be absent (e.g. a GET without a body); guard before reading .mode,
// since this line runs outside the try/catch below.
let mode = pm.request.body ? pm.request.body.mode : undefined;
console.log(pm.request)

if (method == 'GET' || method == 'POST')
{
    try
    {
        let params = null

        if (method == 'GET') {
            params = getQueryParams()
        } else if (mode == 'raw') {
            params = getRawParams()
        } else if (mode == 'urlencoded') {
            params = getFormParams()
        }

        // `!= null` is also false for undefined
        if (params != null) {
            signHandler(params)
        }
    }
    catch (e)
    {
        // NOTE: on failure the request is still sent, just without a signature
        console.log(e)
    }
}

本文思路受网络上已有方案启发,并对其中签名验证方法存在的问题进行了修改和补充(添加reqId、兼容对象和数组,并提供了Postman调用示例)。