大家都知道Java的Servlet请求分为GET和POST两种方式,无论是在原生Servlet中,还是在集成了Spring MVC、Struts2等框架的情况下,获取请求参数都很方便。但有时我们需要在拦截器(过滤器)中获取ServletRequest的参数,这就不那么容易了。在ServletRequest中,对于GET请求或者通过form表单提交的POST请求,我们可以通过request.getParameter("")来获取参数;但如果是ajax以application/json方式提交的POST请求,getParameter就无法获取到值了。有人会想:可以读取request的输入流,自己解析JSON文本。这样做是可行的,但有个问题:如果在拦截器中调用了ServletRequest的getInputStream方法,那么后面的Servlet或者所集成框架的Controller层就无法再调用getInputStream方法来读取并解析参数了,因为请求流只能被读取一次。
有了上面的疑问,我们就可以着手分析并寻找解决办法。通过分析HttpServletRequest并结合相关资料,最终得出的结论是:改写ServletRequest的getInputStream方法便可以解决问题。HttpServletRequest中的stream只能被读取一次,那么我们可以在filter中调用getInputStream获取JSON字符串,然后用获取到的JSON文本生成新的stream,并通过HttpServletRequestWrapper包装成新的ServletRequest,这样后面的Controller就可以继续获取stream(即我们用JSON文本重新生成的stream)。有了这个思路,我们就来看看代码。
一.改写ServletRequest
PostServletRequest.java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StringReader;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import org.apache.commons.lang3.StringUtils;
/**
 * Request wrapper that replays a body which has already been read (and therefore consumed)
 * from the wrapped request, so that downstream components can call
 * {@link #getInputStream()} / {@link #getReader()} again.
 *
 * <p>All other methods are inherited from {@link HttpServletRequestWrapper}.
 *
 * <p>NOTE(review): {@code getContentLength()} is not overridden, so it still reports the
 * original header value; if the body is re-encoded with a different charset than the client
 * used, the byte count of {@link #getInputStream()} may differ from it.
 */
public class PostServletRequest extends HttpServletRequestWrapper {
    /** The captured request body; never null ({@code null} is normalized to ""). */
    private final String body;

    /**
     * Constructs a request object wrapping the given request.
     * @param request the request whose body has already been read
     * @param body the captured body text; {@code null} is treated as an empty body
     * @throws IllegalArgumentException if the request is null
     */
    public PostServletRequest(HttpServletRequest request, String body) {
        super(request);
        this.body = body == null ? "" : body;
    }

    /**
     * Returns a fresh stream over the captured body on every call. Never returns {@code null}:
     * an empty body yields an empty stream (the previous version returned null, which made
     * {@link #getReader()} and downstream readers fail with a NullPointerException).
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new PostServletInputStream(body);
    }

    /**
     * Returns a reader over the captured body text. The text is read directly rather than
     * round-tripping through bytes, so no encode/decode charset mismatch can occur.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new StringReader(body));
    }
}
二.ServletInputStream的改写
PostServletInputStream.java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import javax.servlet.ServletInputStream;
public class PostServletInputStream extends ServletInputStream {
private InputStream inputStream;
private String body ;//解析json之后的文本
public PostServletInputStream(String body) throws IOException {
this.body=body;
inputStream = null;
}
private InputStream acquireInputStream() throws IOException {
if(inputStream == null) {
inputStream = new ByteArrayInputStream(body.getBytes());//通过解析之后传入的文本生成inputStream以便后面control调用
}
return inputStream;
}
public void close() throws IOException {
try {
if(inputStream != null) {
inputStream.close();
}
}
catch(IOException e) {
throw e;
}
finally {
inputStream = null;
}
}
public int read() throws IOException {
return acquireInputStream().read();
}
public boolean markSupported() {
return false;
}
public synchronized void mark(int i) {
throw new UnsupportedOperationException("mark not supported");
}
public synchronized void reset() throws IOException {
throw new IOException(new UnsupportedOperationException("reset not supported"));
}
}
三.在filter中的调用
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Filter that only admits {@code *.do} requests. When the request body (parsed as JSON)
 * carries a {@code headMerId}, the request is only let through if the current URL is in that
 * merchant id's whitelist.
 *
 * <p>Reading the body drains the container's input stream, so once it has been read the
 * request is forwarded through {@link #getRequest(ServletRequest, String)}, which wraps
 * {@code application/json} requests in a {@link PostServletRequest} that replays the body.
 */
public class UrlFilter extends AbstractWebFilter {
    private final static Logger LOGGER = LoggerFactory.getLogger(UrlFilter.class);

    /** Whitelist spec: entries separated by ';', each of the form "merId:url1,url2,...". */
    private final static String MERID_WHITE_LIST = "merABC:test/login.do,test/getList.do;merDEF:test/hello.do,test/greet.do";

    /** merId -> URLs (servlet path without leading '/') that merchant may call; filled by init. */
    private static final Map<String, List<String>> merIdWhiteListMap = new HashMap<String, List<String>>();

    /**
     * Parses {@link #MERID_WHITE_LIST} into {@link #merIdWhiteListMap}.
     * Entries that do not split into exactly "merId:urls" are skipped.
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // String.split never returns null, so its results need no null checks.
        for (String merIdUrls : MERID_WHITE_LIST.split(";")) {
            String[] merIdUrl = merIdUrls.split(":");
            if (merIdUrl.length == 2) {
                String merId = merIdUrl[0];
                String[] urlList = merIdUrl[1].split(",");
                if (urlList.length > 0) {
                    merIdWhiteListMap.put(merId, Arrays.asList(urlList));
                }
            }
        }
        LOGGER.info("merIdWhiteListMap:{}", JsonUtil.toJsonStr(merIdWhiteListMap));
    }

    /**
     * Applies the URL and merchant-id whitelist.
     * <ul>
     *   <li>non-{@code .do} URL: 403 JSON response;</li>
     *   <li>empty body: passed through unchanged;</li>
     *   <li>body without {@code headMerId}: passed through (wrapped so the body is re-readable);</li>
     *   <li>{@code headMerId} whose whitelist contains the URL: passed through (wrapped);</li>
     *   <li>any other {@code headMerId}, including one absent from the whitelist map: 403.</li>
     * </ul>
     * NOTE(review): a non-JSON body is still handed to ParamsReflectUtil, whose JSON parser
     * presumably throws on it — confirm whether form posts should reach this filter.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse rsp = (HttpServletResponse) response;
        String url = getRequestPath(req);
        if (url == null || !url.endsWith(".do")) {
            LOGGER.info("非法url请求: {}", url);
            forbiddenJson(rsp);
            return;
        }

        // Parse the POST JSON body to validate against its parameters.
        String body = getBody(req);
        if (StringUtils.isEmpty(body)) {
            chain.doFilter(request, response);
            return;
        }

        // The original stream is now drained: every branch below must forward the wrapper,
        // otherwise the controller cannot read the body again.
        ServletRequest replayable = getRequest(request, body);
        String headMerId = (String) ParamsReflectUtil.getFieldValueRecursive(body, "headMerId");
        if (StringUtils.isEmpty(headMerId)) {
            chain.doFilter(replayable, response);
            return;
        }

        List<String> urlList = merIdWhiteListMap.get(headMerId);
        LOGGER.info("urlList:{}, curretn url:{}", urlList, url);
        // An unknown merId has no whitelist entry (null); treat it as not allowed rather than NPE.
        if (urlList != null && urlList.contains(url)) {
            chain.doFilter(replayable, response);
        } else {
            LOGGER.info("非法url请求: {},请求入参:{}", url, body);
            forbiddenJson(rsp);
        }
    }

    @Override
    public void destroy() {
    }

    /**
     * Writes a 403 status with the JSON body {@code {"error":"403"}} (serialized by JsonUtil).
     */
    private void forbiddenJson(HttpServletResponse httpResponse) throws IOException {
        Map<String,Object> param = new HashMap<String,Object>();
        param.put("error", "403");
        httpResponse.setStatus(403);
        httpResponse.setCharacterEncoding("utf-8");
        httpResponse.setContentType("application/json");
        httpResponse.getWriter().print(JsonUtil.toJsonStr(param));
    }

    /**
     * Reads the whole request body as text. This consumes the request's input stream.
     * The request's declared character encoding is used, falling back to UTF-8 (previously the
     * JVM platform default was used, which garbles non-ASCII JSON on non-UTF-8 servers).
     * @return the body text, or "" if the container returned no stream
     * @throws IOException on read failure or an unsupported declared encoding
     */
    private String getBody(HttpServletRequest request) throws IOException {
        InputStream inputStream = request.getInputStream();
        if (inputStream == null) {
            return "";
        }
        String enc = request.getCharacterEncoding();
        if (enc == null) {
            enc = "UTF-8";
        }
        StringBuilder stringBuilder = new StringBuilder();
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream, enc));
        try {
            char[] charBuffer = new char[1024];
            int charsRead;
            while ((charsRead = bufferedReader.read(charBuffer)) != -1) {
                stringBuilder.append(charBuffer, 0, charsRead);
            }
        } finally {
            bufferedReader.close();
        }
        return stringBuilder.toString();
    }

    /**
     * Wraps an {@code application/json} request whose body was already read so the body can be
     * read again downstream; any other content type is returned as-is.
     * @param request the request whose body has been consumed
     * @param body the captured body text
     * @return a {@link PostServletRequest} for JSON requests, otherwise {@code request}
     */
    private ServletRequest getRequest(ServletRequest request, String body) {
        String enctype = request.getContentType();
        if (StringUtils.isNotEmpty(enctype) && enctype.contains("application/json")) {
            return new PostServletRequest((HttpServletRequest) request, body);
        }
        return request;
    }
}
AbstractWebFilter.java
import java.io.IOException;
import javax.servlet.Filter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Shared helpers for this project's servlet filters.
 */
public abstract class AbstractWebFilter implements Filter {
    private static Logger log = LoggerFactory.getLogger(AbstractWebFilter.class);

    /**
     * Returns the client IP address ("Remort" is a historical misspelling kept for compatibility).
     *
     * <p>If an {@code x-forwarded-for} header is present, its first comma-separated entry is
     * returned (proxies append to the list, so the first entry is the originating client);
     * if the header is absent or that entry is blank, {@link HttpServletRequest#getRemoteAddr()}
     * is used. Previously the whole list, or an empty string, could be returned.
     *
     * <p>NOTE(review): {@code x-forwarded-for} is client-controlled unless a trusted proxy
     * overwrites it; confirm the deployment before using this for security decisions.
     */
    protected String getRemortIP(HttpServletRequest request) {
        String forwardedFor = request.getHeader("x-forwarded-for");
        if (forwardedFor != null) {
            int comma = forwardedFor.indexOf(',');
            String first = (comma >= 0 ? forwardedFor.substring(0, comma) : forwardedFor).trim();
            if (first.length() > 0) {
                return first;
            }
        }
        return request.getRemoteAddr();
    }

    /**
     * Returns the request's servlet path (no context path) with a single leading "/" removed.
     * A null servlet path is returned as null.
     */
    protected String getRequestPath(HttpServletRequest request) {
        String requestPath = request.getServletPath();
        if (requestPath != null && requestPath.startsWith("/")) {
            return requestPath.substring(1);
        }
        return requestPath;
    }

    /**
     * Redirects to the URL configured under the "logoutUrl" property (the login/logout page).
     * NOTE(review): this method is private and not called in this class, so it is currently
     * unreachable dead code; subclasses cannot use it.
     */
    private void forbiddenRedirect(HttpServletResponse httpResponse) throws IOException {
        String logoutUrl = PropertiesUtils.getString("logoutUrl");
        httpResponse.sendRedirect(logoutUrl);
    }
}
ParamsReflectUtil.java
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
/**
 * Helpers for looking up values inside JSON text (despite the name, no Java reflection is used).
 */
public class ParamsReflectUtil {
    private final static Logger logger = LoggerFactory.getLogger(ParamsReflectUtil.class);

    /**
     * Finds the String value stored under {@code field} in a JSON object, searching nested
     * objects depth-first. A String value directly on an object wins over matches in its nested
     * objects; among nested objects the first match found is returned. JSON arrays are not
     * searched, and non-String values under {@code field} are ignored.
     *
     * <p>Previously a match found in one nested object was overwritten with null when a later
     * sibling object did not contain the field.
     *
     * @param jsonStr JSON object text
     * @param field key to look for
     * @return the matching String value, or null if not found or {@code jsonStr} is blank / "null"
     * NOTE(review): fastjson presumably throws a JSONException when jsonStr is not a JSON object
     * (e.g. a form body or an array); confirm callers only pass JSON objects.
     */
    public static Object getFieldValueRecursive(String jsonStr, String field) {
        if (StringUtils.isBlank(jsonStr)) {
            return null;
        }
        return findStringField(JSON.parseObject(jsonStr), field);
    }

    /** Depth-first search of an already-parsed object; avoids re-serializing nested objects. */
    private static Object findStringField(JSONObject jsonObject, String field) {
        if (jsonObject == null) {
            return null;
        }
        Object direct = jsonObject.get(field);
        if (direct instanceof String) {
            return direct;
        }
        for (Map.Entry<String, Object> entry : jsonObject.entrySet()) {
            Object value = entry.getValue();
            if (value instanceof JSONObject) {
                Object found = findStringField((JSONObject) value, field);
                if (found != null) {
                    return found;
                }
            }
        }
        return null;
    }
}