大家都知道Java的servlet分get和post请求方式,在servlet中或者在集成了springMVC、Struts2等框架的情况下,都可以方便地获取请求参数。但有时候我们需要在拦截器(Filter)中获取ServletRequest的参数,这就不那么容易了。因为在ServletRequest中,如果是get请求,或者是form表单提交的post请求,我们可以通过request.getParameter("")来获取参数;但如果是ajax以application/json方式提交的post请求,那么通过getParameter就无法获取到值了。有人会想到通过request的输入流来解析json文本,这样做是可以的,但是有个问题:如果在拦截器中调用了ServletRequest的getInputStream方法,那么在后面的servlet中,或者你集成的框架的controller层中,就无法再调用getInputStream方法来解析获取参数了。 

  有了上面的疑问,我们就可以着手分析并寻找解决办法。通过分析HttpServletRequest并结合相关资料,最后得出的结论是:改写ServletRequest的getInputStream方法便可以解决问题。HttpServletRequest中的stream只能被read一次,那么我们可以在filter中调用getInputStream获取json字符串,然后用获取到的json文本生成新的stream,交给包装后的ServletRequest,后面的controller就可以继续获取stream(由我们自己用json文本生成)。有了这个思路,我们就来看看代码。
一.改写ServletRequest

PostServletRequest.java

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StringReader;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.apache.commons.lang3.StringUtils;

/**
 * Request wrapper that replays a request body which has already been read from the
 * wrapped request. The container's own body stream can only be consumed once, so a
 * filter that inspects the body must hand this wrapper downstream instead.
 */
public class PostServletRequest extends HttpServletRequestWrapper {

    /** Body text to replay; never null (a null constructor argument is stored as ""). */
    private final String body;

    /**
     * Constructs a request object wrapping the given request.
     *
     * @param request the original request whose body has already been consumed
     * @param body    the body text previously read from {@code request}; null means empty
     * @throws IllegalArgumentException if the request is null
     */
    public PostServletRequest(HttpServletRequest request, String body) {
        super(request);
        this.body = body == null ? "" : body;
    }

    /**
     * Returns a new stream over the saved body on every call, so more than one consumer can
     * read it. An empty body yields an empty stream rather than {@code null}, which callers
     * (and frameworks reading the request) would otherwise dereference.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new PostServletInputStream(body);
    }

    /**
     * Returns a reader over the saved body text. Reading the String directly avoids an
     * encode/decode round trip in which the encoding charset and the request's declared
     * character encoding could disagree and garble non-ASCII content.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new StringReader(body));
    }
}

 

二.ServletInputStream的改写

PostServletInputStream.java

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import javax.servlet.ServletInputStream;

public class PostServletInputStream  extends ServletInputStream {
    
    private InputStream inputStream;
    private String body ;//解析json之后的文本

    public PostServletInputStream(String body) throws IOException {
        this.body=body;
        inputStream = null;
    }


    private InputStream acquireInputStream() throws IOException {
        if(inputStream == null) {
            inputStream = new ByteArrayInputStream(body.getBytes());//通过解析之后传入的文本生成inputStream以便后面control调用
        }

        return inputStream;
    }


    public void close() throws IOException {
        try {
            if(inputStream != null) {
                inputStream.close();
            }
        }
        catch(IOException e) {
            throw e;
        }
        finally {
            inputStream = null;
        }
    }


    public int read() throws IOException {
        return acquireInputStream().read();
    }


    public boolean markSupported() {
        return false;
    }


    public synchronized void mark(int i) {
        throw new UnsupportedOperationException("mark not supported");
    }


    public synchronized void reset() throws IOException {
        throw new IOException(new UnsupportedOperationException("reset not supported"));
    }
}

 

三.在filter中的调用

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Filter that only admits {@code *.do} requests. When the request body carries a
 * {@code headMerId} value, the request path must also be whitelisted for that merchant id.
 *
 * <p>Reading the body consumes the container's input stream. For that reason every request
 * whose non-empty body was read is forwarded wrapped in a {@link PostServletRequest}, which
 * replays the body to the servlet/controller.
 */
public class UrlFilter extends AbstractWebFilter {

    private final static Logger LOGGER = LoggerFactory.getLogger(UrlFilter.class);

    /** Whitelist spec: entries separated by ';', each of the form "merId:path1,path2". */
    private final static String MERID_WHITE_LIST = "merABC:test/login.do,test/getList.do;merDEF:test/hello.do,test/greet.do";
    /** merId -> request paths (without leading '/') that merchant may call; filled in init(). */
    private static Map<String, List<String>> merIdWhiteListMap = new HashMap<String, List<String>>();

    /**
     * Parses {@link #MERID_WHITE_LIST} into {@link #merIdWhiteListMap}. Entries that are not
     * exactly "merId:urls" are skipped.
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        for (String merIdUrls : MERID_WHITE_LIST.split(";")) {
            String[] merIdUrl = merIdUrls.split(":");
            if (merIdUrl.length != 2) {
                continue;
            }
            String[] urlList = merIdUrl[1].split(",");
            if (urlList.length > 0) {
                merIdWhiteListMap.put(merIdUrl[0], Arrays.asList(urlList));
            }
        }

        LOGGER.info("merIdWhiteListMap:{}", JsonUtil.toJsonStr(merIdWhiteListMap));
    }

    /**
     * Rejects non-{@code .do} paths with 403.
     *
     * <p>If the body names a {@code headMerId}, the path must be in that merchant's whitelist;
     * an unknown merchant id is also rejected with 403 (it used to throw a NullPointerException).
     * Otherwise the request continues down the chain. When a non-empty body was read, the
     * request is forwarded with the body replayed.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse rsp = (HttpServletResponse) response;
        String url = getRequestPath(req);
        if (!url.endsWith(".do")) {
            LOGGER.info("非法url请求: {}", url);
            forbiddenJson(rsp);
            return;
        }

        // Parse the posted JSON so the merchant id can be checked. This drains the original stream.
        String body = getBody(req);
        if (StringUtils.isEmpty(body)) {
            chain.doFilter(request, response);
            return;
        }

        // From here on the original stream is exhausted: always forward the replaying wrapper.
        ServletRequest replayable = getRequest(request, body);
        String headMerId = (String) ParamsReflectUtil.getFieldValueRecursive(body, "headMerId");
        if (StringUtils.isEmpty(headMerId)) {
            chain.doFilter(replayable, response);
            return;
        }

        List<String> urlList = merIdWhiteListMap.get(headMerId);
        LOGGER.info("urlList:{}, curretn url:{}", urlList, url);
        if (urlList != null && urlList.contains(url)) {
            chain.doFilter(replayable, response);
        } else {
            LOGGER.info("非法url请求: {},请求入参:{}", url, body);
            forbiddenJson(rsp);
        }
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

    /**
     * Writes a 403 response with the JSON body {"error":"403"} for ajax callers.
     */
    private void forbiddenJson(HttpServletResponse httpResponse) throws IOException {
        Map<String,Object> param = new HashMap<String,Object>();
        param.put("error", "403");
        httpResponse.setStatus(403);
        httpResponse.setCharacterEncoding("utf-8");
        httpResponse.setContentType("application/json");
        httpResponse.getWriter().print(JsonUtil.toJsonStr(param));
    }

    /**
     * Reads the whole request body as text. This consumes the request's input stream.
     *
     * @return the body, or "" when there is none
     */
    private String getBody(HttpServletRequest request) throws IOException {
        InputStream inputStream = request.getInputStream();
        if (inputStream == null) {
            return "";
        }
        // Decodes with the platform default charset. PostServletInputStream re-encodes the body
        // with the same default, so the bytes replayed downstream round-trip consistently.
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
        try {
            StringBuilder stringBuilder = new StringBuilder();
            char[] charBuffer = new char[1024];
            int charsRead;
            while ((charsRead = bufferedReader.read(charBuffer)) > 0) {
                stringBuilder.append(charBuffer, 0, charsRead);
            }
            return stringBuilder.toString();
        } finally {
            bufferedReader.close();
        }
    }

    /**
     * Wraps a request whose body has already been read, so that downstream code can read
     * the body again.
     *
     * <p>Wrapping happens regardless of content type: once {@link #getBody} has drained the
     * stream, forwarding the unwrapped request would leave downstream code with an empty body.
     *
     * @param request the request whose body was consumed
     * @param body    the body text read from it
     * @return a {@link PostServletRequest} replaying {@code body}
     */
    private ServletRequest getRequest(ServletRequest request, String body) {
        return new PostServletRequest((HttpServletRequest) request, body);
    }
}

AbstractWebFilter.java

import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for web filters, providing request-inspection and response helpers.
 */
public abstract class AbstractWebFilter implements Filter {

    private static Logger log = LoggerFactory.getLogger(AbstractWebFilter.class);

    /**
     * Returns the client IP.
     *
     * <p>The {@code x-forwarded-for} header may hold a proxy chain such as
     * "client, proxy1, proxy2". In that case the first (originating) entry is returned. If the
     * header is absent or its first entry is blank, the socket's remote address is used.
     *
     * <p>NOTE: this header is client-controlled unless a trusted proxy overwrites it. Do not
     * rely on the value for access control without such a proxy.
     */
    protected String getRemortIP(HttpServletRequest request) {
        String forwardedFor = request.getHeader("x-forwarded-for");
        if (forwardedFor != null) {
            int comma = forwardedFor.indexOf(',');
            String first = (comma >= 0 ? forwardedFor.substring(0, comma) : forwardedFor).trim();
            if (!first.isEmpty()) {
                return first;
            }
        }
        return request.getRemoteAddr();
    }

    /**
     * Returns the request's servlet path (excluding the context path) with any leading "/"
     * removed.
     */
    protected String getRequestPath(HttpServletRequest request) {
        String requestPath = request.getServletPath();
        if (requestPath != null && requestPath.startsWith("/")) {
            return requestPath.substring(1);
        }
        return requestPath;
    }

    /**
     * Redirects to the URL configured under the {@code logoutUrl} property (intended as the
     * login page). This is protected so subclasses can use it; as a private method of an
     * abstract class it was unreachable.
     */
    protected void forbiddenRedirect(HttpServletResponse httpResponse) throws IOException {
        String logoutUrl = PropertiesUtils.getString("logoutUrl");
        httpResponse.sendRedirect(logoutUrl);
    }
}

ParamsReflectUtil.java

import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;

/**
 * 属性(对象)值反射获取工具类
 */
public class ParamsReflectUtil {

    private final static Logger logger = LoggerFactory.getLogger(ParamsReflectUtil.class);

    /**
     * Looks up the String value stored under {@code field} in a JSON object document.
     *
     * <p>A String value directly under {@code field} at a given level takes precedence. If
     * there is none, nested JSON objects at that level are searched depth-first, and the first
     * hit is returned. Previously a hit found in one nested object could be overwritten by
     * {@code null} from a later sibling. Values inside JSON arrays are not searched, and
     * non-String values under {@code field} are ignored.
     *
     * @param jsonStr JSON object text; null or blank yields {@code null}
     * @param field   key to look for; null yields {@code null}
     * @return the matching String value, or {@code null} if none is found
     */
    public static Object getFieldValueRecursive(String jsonStr, String field) {
        if (field == null || StringUtils.isBlank(jsonStr)) {
            return null;
        }
        return findStringField(JSON.parseObject(jsonStr), field);
    }

    /** Depth-first search of {@code jsonObject} (without re-serializing) for a String under {@code field}. */
    private static Object findStringField(JSONObject jsonObject, String field) {
        if (jsonObject == null) {
            // e.g. the literal document "null"
            return null;
        }
        Object direct = jsonObject.get(field);
        if (direct instanceof String) {
            return direct;
        }
        for (Map.Entry<String, Object> entry : jsonObject.entrySet()) {
            Object value = entry.getValue();
            if (value instanceof JSONObject) {
                Object nested = findStringField((JSONObject) value, field);
                if (nested != null) {
                    return nested;
                }
            }
        }
        return null;
    }
}