一、Spring Boot 针对富文本和非富文本字段分别添加 XSS 过滤（如果富文本字段名是唯一的，即不与任何非富文本字段同名，那么实际上只需编写一个 HttpServletRequestWrapper 即可）

1. XSS 过滤器

package com.doctortech.tmc.filter;

import com.doctortech.tmc.support.xss.XssHttpServletRequestWrapper;
import com.doctortech.tmc.support.xss.XssRichTextHttpServletRequestWrapper;
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

/**
 * @author zxb
 * @version 1.0
 * @date 2021/08/10 11:43
 * @description XSS filter. It chooses how each request is sanitised from the servlet path, using a
 * {@code contains} match against fixed path fragments:
 * <ul>
 *   <li>paths containing an entry of {@code EXCLUSION_URLS} are passed on unwrapped (no XSS filtering);</li>
 *   <li>paths containing an entry of {@code RICH_TEXT_URLS} are wrapped in
 *       {@link XssRichTextHttpServletRequestWrapper};</li>
 *   <li>every other request is wrapped in {@link XssHttpServletRequestWrapper}.</li>
 * </ul>
 *
 * <p>NOTE(review): {@code @WebFilter} declares no urlPatterns and only takes effect together with
 * {@code @ServletComponentScan}; registration presumably comes from {@code @Component}. Confirm that the
 * filter is not registered twice.
 */
@WebFilter
@Component
public class XssFilter implements Filter {

    /** Path fragments whose requests skip XSS filtering entirely (file upload endpoints). */
    private static final String[] EXCLUSION_URLS = {"/fileUpload/upload", "/fileUpload/upload/img"};

    /** Path fragments of endpoints whose payload contains rich-text fields. */
    private static final String[] RICH_TEXT_URLS = {
            "/admin/adminArticle/add", "/admin/adminArticle/update", "/admin/adminPolicy/add"};

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // Nothing to initialise.
    }

    /**
     * Wraps the request in the appropriate XSS wrapper (or none) and continues the chain.
     *
     * @param servletRequest  the incoming request; must be an {@link HttpServletRequest}
     * @param servletResponse the response, passed through unchanged
     * @param filterChain     the remaining filter chain
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) servletRequest;
        String path = req.getServletPath();

        // Excluded endpoints (e.g. uploads) get the original, unwrapped request.
        if (containsAny(path, EXCLUSION_URLS)) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        // Rich-text endpoints: the wrapper is only built when it is actually needed.
        if (containsAny(path, RICH_TEXT_URLS)) {
            filterChain.doFilter(new XssRichTextHttpServletRequestWrapper(req), servletResponse);
            return;
        }
        // Everything else gets the strict, non-rich-text filtering.
        filterChain.doFilter(new XssHttpServletRequestWrapper(req), servletResponse);
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

    /**
     * Returns whether {@code path} contains any of the given fragments.
     *
     * @param path      the servlet path of the request
     * @param fragments path fragments to look for
     * @return true if at least one fragment occurs in {@code path}
     */
    private static boolean containsAny(String path, String[] fragments) {
        for (String fragment : fragments) {
            if (path.contains(fragment)) {
                return true;
            }
        }
        return false;
    }
}

2. 针对富文本接口进行过滤

package com.doctortech.tmc.support.xss;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.StringUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.util.*;

/**
 * @author zxb
 * @version 1.0
 * @date 2021/08/10 10:58
 * @description XSS filtering for endpoints whose payload contains rich-text fields.
 *
 * <p>Within a JSON body, string values whose key is listed in {@code RICH_TEXT_KEY} are cleaned with
 * {@code cleanRichTextXSS}. That method keeps HTML markup and only strips script elements,
 * {@code eval(...)} calls and the word {@code alert}. Every other string goes through the stricter
 * {@code cleanXSS}, which also escapes angle brackets. This includes other JSON values, request
 * parameters, headers and the query string.
 *
 * <p>NOTE(review): {@code cleanRichTextXSS} is a block-list. It does not remove event-handler attributes
 * such as {@code onerror=} or {@code javascript:} URLs, so an allow-list HTML sanitizer would be safer.
 */
public class XssRichTextHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /**
     * SQL-injection keywords separated by "|". They are only used by the currently disabled body of
     * {@code cleanSqlKeyWords}.
     */
    private static final String SQL_KEY =
            "and|exec|insert|select|delete|update|count|*|%|chr|mid|master|truncate|char|declare|;|or|-|+";
    /**
     * Names of rich-text JSON fields; separate multiple names with "|".
     */
    private static final String RICH_TEXT_KEY = "content";
    private static final Set<String> NOT_ALLOWED_KEY_WORDS = splitKeys(SQL_KEY);
    private static final Set<String> RICH_TEXT_KEY_SET = splitKeys(RICH_TEXT_KEY);

    /** Charset used both to decode the incoming body and to re-encode the cleaned body. */
    private static final Charset BODY_CHARSET = Charset.forName("UTF-8");
    /** Shared mapper; it is only used for readTree/writeValueAsString and is never reconfigured. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Matches a complete {@code <script ...>...</script>} element. */
    private static final String SCRIPT_ELEMENT_REGEX =
            "<[\\s]*?script[^>]*?>[\\s\\S]*?<[\\s]*?\\/[\\s]*?script[\\s]*?>";
    /** Matches a complete {@code <javascript ...>...</javascript>} element. */
    private static final String JAVASCRIPT_ELEMENT_REGEX =
            "<[\\s]*?javascript[^>]*?>[\\s\\S]*?<[\\s]*?\\/[\\s]*?javascript[\\s]*?>";
    /** Matches from "eval(" up to the last ")" on the same line (greedy). */
    private static final String EVAL_CALL_REGEX = "eval\\((.*)\\)";

    /**
     * Creates a wrapper around the given request.
     *
     * @param request the request to wrap
     */
    public XssRichTextHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Returns the XSS-cleaned value of a request parameter.
     *
     * @param name parameter name
     * @return the cleaned value, or the original value when it is null or empty
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (!StringUtils.isEmpty(value)) {
            value = cleanXSS(value);
            value = cleanSqlKeyWords(value);
        }
        return value;
    }

    /**
     * Returns the XSS-cleaned values of a request parameter.
     *
     * @param name parameter name
     * @return a new array of cleaned values, or null when the parameter is absent
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] parameterValues = super.getParameterValues(name);
        if (parameterValues == null) {
            return null;
        }
        // Clean into a copy so the container's own parameter array is never modified.
        String[] cleaned = new String[parameterValues.length];
        for (int i = 0; i < parameterValues.length; i++) {
            cleaned[i] = cleanSqlKeyWords(cleanXSS(parameterValues[i]));
        }
        return cleaned;
    }

    /**
     * Returns the XSS-cleaned value of a header.
     *
     * @param name header name
     * @return the cleaned value, or null when the header is absent (callers rely on null for "missing")
     */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        // cleanXSS maps null to "", so the null check must happen before cleaning.
        if (value == null) {
            return null;
        }
        return cleanSqlKeyWords(cleanXSS(value));
    }

    /**
     * Returns the XSS-cleaned query string.
     *
     * @return the cleaned query string, or null when the request has none
     */
    @Override
    public String getQueryString() {
        String queryString = super.getQueryString();
        return queryString == null ? null : cleanXSS(queryString);
    }

    /**
     * Returns the request body with XSS filtering applied to every JSON string value.
     *
     * <p>Each call re-reads {@code super.getInputStream()}. A second call will typically find that stream
     * already consumed and therefore return an empty body.
     *
     * @return an in-memory stream over the cleaned body, encoded as UTF-8
     * @throws IOException if the original body cannot be read
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        String bodyStr = getRequestBody(super.getInputStream());
        return newServletInputStream(bodyStr.getBytes(BODY_CHARSET));
    }

    /**
     * Returns a reader over the same filtered body as {@link #getInputStream()}. Without this override,
     * reading through the character API would return the unfiltered body of the wrapped request.
     *
     * @return a reader over the cleaned body
     * @throws IOException if the original body cannot be read
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), BODY_CHARSET));
    }

    /**
     * Wraps a byte array in a blocking {@link ServletInputStream} that signals end-of-stream with -1.
     *
     * @param content the bytes to serve
     * @return the stream
     */
    private static ServletInputStream newServletInputStream(byte[] content) {
        final ByteArrayInputStream source = new ByteArrayInputStream(content);
        return new ServletInputStream() {
            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public int available() {
                return source.available();
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // All data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reading is not supported by this in-memory stream.
            }
        };
    }

    /**
     * Reads the whole request body and returns its XSS-cleaned form.
     *
     * @param stream the raw request body
     * @return "" for an empty or blank body, otherwise the result of {@code transJsonNode}
     * @throws IOException if the body cannot be read. The exception is propagated instead of letting a
     *                     truncated body be parsed.
     */
    private String getRequestBody(InputStream stream) throws IOException {
        StringBuilder body = new StringBuilder();
        Reader reader = new InputStreamReader(stream, BODY_CHARSET);
        char[] buffer = new char[4096];
        int count;
        while ((count = reader.read(buffer)) != -1) {
            body.append(buffer, 0, count);
        }
        String raw = body.toString();
        if (raw.trim().isEmpty()) {
            // Keep an empty body empty. Parsing "" would otherwise yield a JSON "" literal.
            return "";
        }
        return transJsonNode(raw);
    }

    /**
     * Strict XSS cleaning for ordinary (non-rich-text) values.
     *
     * <p>Removes script/javascript elements, escapes remaining angle brackets, and strips
     * {@code eval(...)} and the word {@code alert}.
     *
     * @param valueP the raw value
     * @return the cleaned value; "" when {@code valueP} is null or blank
     */
    private String cleanXSS(String valueP) {
        if (StringUtils.isBlank(valueP)) {
            return "";
        }
        String value = valueP.replaceAll(SCRIPT_ELEMENT_REGEX, "");
        value = value.replaceAll(JAVASCRIPT_ELEMENT_REGEX, "");
        // Escape any remaining tag brackets, the same way the non-rich-text wrapper does. "& lt;" is not a
        // valid HTML entity (note the space), so browsers display it literally.
        value = value.replace("<", "& lt;").replace(">", "& gt;");
        value = value.replaceAll(EVAL_CALL_REGEX, "");
        value = value.replaceAll("alert", "");
        return cleanSqlKeyWords(value);
    }

    /**
     * Lenient XSS cleaning for rich-text fields. HTML markup is kept; only script/javascript elements,
     * {@code eval(...)} calls and the word {@code alert} are removed.
     *
     * @param valueP the raw rich-text value
     * @return the cleaned value; "" when {@code valueP} is null or blank
     */
    private String cleanRichTextXSS(String valueP) {
        if (StringUtils.isBlank(valueP)) {
            return "";
        }
        String value = valueP.replaceAll(EVAL_CALL_REGEX, "");
        value = value.replaceAll(SCRIPT_ELEMENT_REGEX, "");
        value = value.replaceAll(JAVASCRIPT_ELEMENT_REGEX, "");
        value = value.replaceAll("alert", "");
        return cleanSqlKeyWords(value);
    }

    /**
     * SQL keyword filtering hook. It is currently disabled and returns {@code value} unchanged.
     *
     * @param value the value to check
     * @return {@code value}, unchanged
     */
    private String cleanSqlKeyWords(String value) {
        String paramValue = value;
        // SQL keyword filtering is temporarily disabled.
//        for (String keyword : notAllowedKeyWords) {
//            if (paramValue.length() > keyword.length() + 4
//                    && (paramValue.contains(" "+keyword)||paramValue.contains(keyword+" ")||paramValue.contains(" "+keyword+" "))) {
//                paramValue = StringUtils.replace(paramValue, keyword, replacedString);
//            }
//        }
        return paramValue;
    }

    /**
     * Parses the body as JSON, cleans every string value, and re-serialises it.
     *
     * @param jsonStr the raw (non-blank) body
     * @return the cleaned JSON text. If the body is not valid JSON, the whole text is cleaned with
     *         {@code cleanXSS}, so the body is not silently discarded.
     */
    private String transJsonNode(String jsonStr) {
        try {
            JsonNode jsonNode = OBJECT_MAPPER.readTree(jsonStr);
            if (jsonNode == null || jsonNode.isMissingNode()) {
                return cleanXSS(jsonStr);
            }
            return OBJECT_MAPPER.writeValueAsString(cleanJsonNodeXSS(jsonNode));
        } catch (JsonProcessingException ignored) {
            // Not JSON: fall back to plain-text cleaning.
            return cleanXSS(jsonStr);
        }
    }

    /**
     * Walks a JSON tree depth-first and cleans every string value. Structure and non-string values are
     * preserved.
     *
     * @param jsonNode root of the JSON tree
     * @return a Map/List/String/JsonNode structure that Jackson can serialise back to equivalent JSON
     */
    private Object cleanJsonNodeXSS(JsonNode jsonNode) {
        return cleanJsonValue(jsonNode, null);
    }

    /**
     * Recursive worker for {@link #cleanJsonNodeXSS(JsonNode)}.
     *
     * @param node      the node to clean
     * @param fieldName the object key this node is stored under, or null for the root and array elements
     * @return the cleaned value
     */
    private Object cleanJsonValue(JsonNode node, String fieldName) {
        if (node.isTextual()) {
            String text = node.asText();
            // Rich-text handling applies only to strings stored directly under a rich-text key.
            return RICH_TEXT_KEY_SET.contains(fieldName) ? cleanRichTextXSS(text) : cleanXSS(text);
        }
        if (node.isObject()) {
            // LinkedHashMap keeps the client's field order.
            Map<String, Object> cleaned = new LinkedHashMap<>();
            Iterator<Map.Entry<String, JsonNode>> fields = node.fields();
            while (fields.hasNext()) {
                Map.Entry<String, JsonNode> field = fields.next();
                cleaned.put(field.getKey(), cleanJsonValue(field.getValue(), field.getKey()));
            }
            return cleaned;
        }
        if (node.isArray()) {
            List<Object> cleaned = new ArrayList<>(node.size());
            for (JsonNode element : node) {
                cleaned.add(cleanJsonValue(element, null));
            }
            return cleaned;
        }
        // Numbers, booleans and null carry no markup and are returned untouched.
        return node;
    }

    /**
     * Splits a "|"-separated list into an unmodifiable set.
     *
     * @param keys the list, e.g. "a|b|c"
     * @return the set of entries
     */
    private static Set<String> splitKeys(String keys) {
        return Collections.unmodifiableSet(new HashSet<>(Arrays.asList(keys.split("\\|"))));
    }
}

3. 针对非富文本接口进行过滤

package com.doctortech.tmc.support.xss;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.util.*;

/**
 * @author zxb
 * @version 1.0
 * @date 2021/08/10 11:29
 * @description xss过滤(针对不含富文本变量的接口进行过滤)
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper  {
    private static String key = "and|exec|insert|select|delete|update|count|*|%|chr|mid|master|truncate|char|declare|;|or|-|+";
    private static Set<String> notAllowedKeyWords = new HashSet<String>(0);

    /**
     * 初始化sql注入关键词
     */
    static {
        String[] keyStr = key.split("\\|");
        for (String str : keyStr) {
            notAllowedKeyWords.add(str);
        }
    }

    /**
     * 构造函数
     * @param servletRequest
     * @throws IOException
     */
    public XssHttpServletRequestWrapper(HttpServletRequest servletRequest) throws IOException {
        super(servletRequest);
    }

    /**
     * 重写getInputStream方法,过滤json数据
     * @return
     * @throws IOException
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        String bodyStr = getRequestBody(super.getInputStream());
        if ("".equals(bodyStr)) {
            return new ServletInputStream() {
                @Override
                public int read() throws IOException {
                    return 0;
                }

                @Override
                public boolean isFinished() {
                    return false;
                }

                @Override
                public boolean isReady() {
                    return false;
                }

                @Override
                public void setReadListener(ReadListener readListener) {

                }
            };
        }
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bodyStr.getBytes());
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return byteArrayInputStream.read();
            }

            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {

            }

        };
    }

    /**
     * 获取json数据
     * @param stream
     * @return
     */
    private String getRequestBody(InputStream stream) {
        String line = "";
        StringBuilder body = new StringBuilder();
        // 读取POST提交的数据内容
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream, Charset.forName("UTF-8")));
        try {
            while ((line = reader.readLine()) != null) {
                //拼接读取到的数据
                body.append(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (body == null) {
            return "";
        }
        //最后返回数据
        String data = transJsonNode(body.toString());
        return data;
    }

    /**
     * 将容易引起xss漏洞的半角字符替换成全角字符
     * @param s
     * @param type
     * @return
     */
    private static String xssEncode(String s, int type) {
        if (s == null || s.isEmpty()) {
            return s;
        }
        StringBuilder sb = new StringBuilder(s.length() + 16);
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (type == 0) {
                switch (c) {
                    case '\'':
                        // 全角单引号
                        sb.append('‘');
                        break;
                    case '\"':
                        // 全角双引号
                        sb.append('“');
                        break;
                    case '>':
                        // 全角大于号
                        sb.append('>');
                        break;
                    case '<':
                        // 全角小于号
                        sb.append('<');
                        break;
                    case '&':
                        // 全角&符号
                        sb.append('&');
                        break;
                    case '\\':
                        // 全角斜线
                        sb.append('\');
                        break;
                    case '#':
                        // 全角井号
                        sb.append('#');
                        break;
                    // < 字符的 URL 编码形式表示的 ASCII 字符(十六进制格式) 是: %3c
                    case '%':
                        processUrlEncoder(sb, s, i);
                        break;
                    default:
                        sb.append(c);
                        break;
                }
            } else {
                switch (c) {
                    case '>':
                        // 全角大于号
                        sb.append('>');
                        break;
                    case '<':
                        // 全角小于号
                        sb.append('<');
                        break;
                    case '&':
                        // 全角&符号
                        sb.append('&');
                        break;
                    case '#':
                        // 全角井号
                        sb.append('#');
                        break;
                    // < 字符的 URL 编码形式表示的 ASCII 字符(十六进制格式) 是: %3c
                    case '%':
                        processUrlEncoder(sb, s, i);
                        break;
                    default:
                        sb.append(c);
                        break;
                }
            }

        }
        return sb.toString();
    }

    /**
     * 针对特殊字符的编码进行处理
     * @param sb
     * @param s
     * @param index
     */
    public static void processUrlEncoder(StringBuilder sb, String s, int index) {
        if (s.length() >= index + 2) {
            // %3c, %3C
            if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'c' || s.charAt(index + 2) == 'C')) {
                sb.append('<');
                return;
            }
            // %3c (0x3c=60)
            if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '0') {
                sb.append('<');
                return;
            }
            // %3e, %3E
            if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'e' || s.charAt(index + 2) == 'E')) {
                sb.append('>');
                return;
            }
            // %3e (0x3e=62)
            if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '2') {
                sb.append('>');
                return;
            }
        }
        sb.append(s.charAt(index));
    }

    /**
     * 重写getParameter,对参数进行xss过滤
     * @param parameter
     * @return
     */
    @Override
    public String getParameter(String parameter) {
        String value = super.getParameter(parameter);
        if (value == null) {
            return null;
        }
        return cleanXSS(value);
    }

    /**
     * 重写getParameterValues,对参数进行xss过滤
     * @param parameter
     * @return
     */
    @Override
    public String[] getParameterValues(String parameter) {
        String[] values = super.getParameterValues(parameter);
        if (values == null) {
            return null;
        }
        int count = values.length;
        String[] encodedValues = new String[count];
        for (int i = 0; i < count; i++) {
            encodedValues[i] = cleanXSS(values[i]);
        }
        return encodedValues;
    }

    /**
     * 重写getParameterMap,对参数进行xss过滤
     * @return
     */
    @Override
    public Map<String, String[]> getParameterMap(){
        Map<String, String[]> values = super.getParameterMap();
        if (values == null) {
            return null;
        }
        Map<String, String[]> result = new HashMap<>();
        for(String key:values.keySet()){
            String encodedKey = cleanXSS(key);
            int count = values.get(key).length;
            String[] encodedValues = new String[count];
            for (int i = 0; i < count; i++){
                encodedValues[i] = cleanXSS(values.get(key)[i]);
            }
            result.put(encodedKey,encodedValues);
        }
        return result;
    }

    /**
     * 重写getHeader,对参数进行xss过滤
     * @param name
     * @return
     */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        if (value == null) {
            return null;
        }
        return cleanXSS(value);
    }

    /**
     * xss过滤
     * @param valueP 内容
     * @return
     */
    private String cleanXSS(String valueP) {
        String value = valueP.replaceAll("<[\\s]*?script[^>]*?>[\\s\\S]*?<[\\s]*?\\/[\\s]*?script[\\s]*?>", "");
        value = value.replaceAll("<[\\s]*?javascript[^>]*?>[\\s\\S]*?<[\\s]*?\\/[\\s]*?javascript[\\s]*?>", "");
        value = value.replaceAll("<", "<").replaceAll(">", ">");
        value = value.replaceAll("<", "& lt;").replaceAll(">", "& gt;");
        value = value.replaceAll("eval\\((.*)\\)", "");
        value = value.replaceAll("alert", "");
        value = cleanSqlKeyWords(value);
        return value;
    }

    /**
     * sql注入过滤
     * @param value 内容
     * @return
     */
    private String cleanSqlKeyWords(String value) {
        String paramValue = value;
        //暂时
//        for (String keyword : notAllowedKeyWords) {
//            if (paramValue.length() > keyword.length() + 3
//                    && (paramValue.contains(" "+keyword)||paramValue.contains(keyword+" ")||paramValue.contains(" "+keyword+" "))) {
//                paramValue = StringUtils.replace(paramValue, keyword, replacedString);
//                System.out.println(this.currentUrl + "已被过滤,因为参数中包含不允许sql的关键词(" + keyword
//                        + ")"+";参数:"+value+";过滤后的参数:"+paramValue);
//            }
//        }
        return paramValue;
    }

    /**
     * 将json字符串数据转成json树,再深度遍历去除xss
     * @param jsonStr json字符串
     * @return
     */
    private String transJsonNode(String jsonStr) {
        String str = "";
        try {
            ObjectMapper objectMapper = new ObjectMapper();
            JsonNode jsonNode = objectMapper.readTree(jsonStr);
            str = objectMapper.writeValueAsString(cleanJsonNodeXSS(jsonNode));
        } catch (JsonProcessingException e) {
            e.printStackTrace();
        }
        return str;
    }

    /**
     * 对json树深度遍历去除xss
     * @param jsonNode json树
     * @return
     */
    private Object cleanJsonNodeXSS(JsonNode jsonNode) {
        Iterator<Map.Entry<String, JsonNode>> fields = jsonNode.fields();
        if (!fields.hasNext()) {
            String value = jsonNode.asText();
            return cleanXSS(value);
        }
        Map<String, Object> map = new HashMap<>();
        while(fields.hasNext()) {
            Map.Entry<String, JsonNode> next = fields.next();
            if (next.getValue().isTextual()) {
                String value = next.getValue().asText();
                String str = cleanXSS(value);
                map.put(next.getKey(),str);
            }else if (next.getValue().isObject()){
                map.put(next.getKey(),cleanJsonNodeXSS(next.getValue()));
            }else if(next.getValue().isArray()) {
                List<Object> elementList = new ArrayList<>();
                Iterator<JsonNode> elements = next.getValue().elements();
                while (elements.hasNext()) {
                    JsonNode childrenNext = elements.next();
                    Object nodeMap = cleanJsonNodeXSS(childrenNext);
                    elementList.add(nodeMap);
                }
                map.put(next.getKey(),elementList);
            }
            else {
                map.put(next.getKey(),next.getValue());
            }

        }
        return map;
    }
}