package org.tuckey.web.filters.urlrewrite;
import org.tuckey.web.filters.urlrewrite.utils.Log;
import org.tuckey.web.filters.urlrewrite.utils.StringUtils;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.InvocationTargetException;
import java.net.URLDecoder;
import java.util.List;
/**
 * Applies the rewrite configuration held in a {@link Conf} to requests and outbound URLs.
 *
 * <p>Inbound requests are processed through a {@link RuleChain} built per request
 * (see {@link #processRequest(HttpServletRequest, HttpServletResponse)} and
 * {@link #processRequest(HttpServletRequest, HttpServletResponse, FilterChain)}).
 * Outbound URLs are passed through the configured {@link OutboundRule}s
 * (see {@link #processEncodeURL}). Exceptions raised by rule invocations can be routed to the
 * configured {@link CatchElem}s via {@link #handleInvocationTargetException}.
 */
public class UrlRewriter {

    private static final Log log = Log.getLog(UrlRewriter.class);

    /** Supplies rules, outbound rules, catch elements and decoding options; only set in the constructor. */
    private final Conf conf;

    /**
     * Creates a rewriter for the given configuration.
     *
     * @param conf the configuration to apply (stored by reference, not copied)
     */
    public UrlRewriter(Conf conf) {
        this.conf = conf;
    }

    /**
     * Runs the inbound rules for a request and returns the final rewritten request.
     *
     * @param hsRequest  the current request
     * @param hsResponse the current response
     * @return the value of {@code RuleChain.getFinalRewrittenRequest()} after processing, or
     *         {@code null} when no chain could be built (no request path, configuration not ok,
     *         or no rules configured)
     * @throws InvocationTargetException propagated from {@code RuleChain.process}
     */
    public RewrittenUrl processRequest(final HttpServletRequest hsRequest,
                                       final HttpServletResponse hsResponse)
            throws IOException, ServletException, InvocationTargetException {
        RuleChain chain = getNewChain(hsRequest, null);
        if (chain == null) return null;
        chain.process(hsRequest, hsResponse);
        return chain.getFinalRewrittenRequest();
    }

    /**
     * Runs the inbound rules for a request inside a servlet filter chain.
     *
     * @param hsRequest   the current request
     * @param hsResponse  the current response
     * @param parentChain the enclosing filter chain, handed to the {@link RuleChain}
     * @return {@code RuleChain.isResponseHandled()} after the rules ran, or {@code false} when
     *         no chain could be built (no request path, configuration not ok, or no rules)
     */
    public boolean processRequest(final HttpServletRequest hsRequest, final HttpServletResponse hsResponse,
                                  FilterChain parentChain)
            throws IOException, ServletException {
        RuleChain chain = getNewChain(hsRequest, parentChain);
        if (chain == null) return false;
        chain.doRules(hsRequest, hsResponse);
        return chain.isResponseHandled();
    }

    /**
     * Computes the path that rules are matched against from the request URI.
     *
     * <p>The request URI is decoded with {@link #decodeRequestString}. A {@code null} URI is
     * treated as {@code ""}. The context path is then handled according to
     * {@code conf.isUseContext()}:
     * <ul>
     *   <li>false, and the decoded URI starts (case-insensitively) with the decoded context
     *       path: the context path is stripped;</li>
     *   <li>true, and the decoded URI does not start with it: the context path is prepended;</li>
     *   <li>otherwise: the decoded URI is used unchanged.</li>
     * </ul>
     *
     * @param request the current request
     * @return the computed path, or {@code "/"} when it is blank (never {@code null})
     */
    public String getPathWithinApplication(HttpServletRequest request) {
        String requestUri = request.getRequestURI();
        if (requestUri == null) requestUri = "";
        String decodedRequestUri = decodeRequestString(request, requestUri);
        String contextPath = getContextPath(request);
        // Evaluate the prefix test once; both branches below depend on it.
        boolean startsWithContext = StringUtils.startsWithIgnoreCase(decodedRequestUri, contextPath);
        boolean useContext = conf.isUseContext();
        String path;
        if (startsWithContext && !useContext) {
            path = decodedRequestUri.substring(contextPath.length());
        } else if (!startsWithContext && useContext) {
            path = contextPath + decodedRequestUri;
        } else {
            path = decodedRequestUri;
        }
        return StringUtils.isBlank(path) ? "/" : path;
    }

    /**
     * Returns the request's context path, decoded with {@link #decodeRequestString}.
     * A context path of exactly {@code "/"} is normalised to {@code ""} before decoding.
     *
     * @param request the current request
     * @return the decoded context path
     */
    public String getContextPath(HttpServletRequest request) {
        String contextPath = request.getContextPath();
        if ("/".equals(contextPath)) {
            contextPath = "";
        }
        return decodeRequestString(request, contextPath);
    }

    /**
     * Decodes a request string (URI or context path) according to the configuration.
     *
     * <p>Strategies are tried in order; the first one that succeeds wins:
     * <ol>
     *   <li>if {@code conf.isDecodeUsingEncodingHeader()} and the request has a character
     *       encoding, decode with that encoding;</li>
     *   <li>if {@code conf.isDecodeUsingCustomCharsetRequired()} and {@code conf.getDecodeUsing()}
     *       is non-null, decode with that charset;</li>
     *   <li>otherwise return {@code source} unchanged.</li>
     * </ol>
     * An unsupported encoding is logged at warn level and the next strategy is tried.
     *
     * <p>NOTE(review): {@link URLDecoder} applies application/x-www-form-urlencoded rules. It
     * therefore turns {@code '+'} into a space, and it throws an unchecked
     * {@code IllegalArgumentException} for malformed {@code %} escapes. That exception is not
     * caught here. Confirm both are intended for request paths.
     *
     * @param request the current request (used for its character encoding)
     * @param source  the string to decode
     * @return the decoded string, or {@code source} when no decoding strategy applied
     */
    public String decodeRequestString(HttpServletRequest request, String source) {
        if (conf.isDecodeUsingEncodingHeader()) {
            String enc = request.getCharacterEncoding();
            if (enc != null) {
                try {
                    return URLDecoder.decode(source, enc);
                } catch (UnsupportedEncodingException ex) {
                    if (log.isWarnEnabled()) {
                        log.warn("Could not decode: " + source + " (header encoding: '" + enc + "'); exception: " + ex.getMessage());
                    }
                }
            }
        }
        if (conf.isDecodeUsingCustomCharsetRequired()) {
            String enc = conf.getDecodeUsing();
            if (enc != null) {
                try {
                    return URLDecoder.decode(source, enc);
                } catch (UnsupportedEncodingException ex) {
                    if (log.isWarnEnabled()) {
                        log.warn("Could not decode: " + source + " (encoding: '" + enc + "') using default encoding; exception: " + ex.getMessage());
                    }
                }
            }
        }
        return source;
    }

    /**
     * Builds the rule chain for a request, or returns {@code null} when rewriting should not
     * happen.
     *
     * <p>The URL matched by the rules is {@link #getPathWithinApplication}. When
     * {@code conf.isUseQueryString()} is set and that path contains no {@code '?'}, the trimmed,
     * non-empty query string is appended to it.
     *
     * @param hsRequest   the current request
     * @param parentChain the enclosing filter chain, or {@code null}
     * @return a new {@link RuleChain}, or {@code null} if there is no request path, the
     *         configuration is not ok, or there are no rules
     */
    private RuleChain getNewChain(final HttpServletRequest hsRequest, FilterChain parentChain) {
        String originalUrl = getPathWithinApplication(hsRequest);
        // getPathWithinApplication itself never returns null. It is public and overridable,
        // though, so keep the defensive check.
        if (originalUrl == null) {
            log.debug("unable to fetch request uri from request. This shouldn't happen, it may indicate that " +
                    "the web application server has a bug or that the request was not pased correctly.");
            return null;
        }
        if (log.isDebugEnabled()) {
            log.debug("processing request for " + originalUrl);
        }

        // originalUrl is known to be non-null here; only append a query string if none is present yet.
        if (originalUrl.indexOf('?') == -1 && conf.isUseQueryString()) {
            String query = hsRequest.getQueryString();
            if (query != null) {
                query = query.trim();
                if (query.length() > 0) {
                    originalUrl = originalUrl + "?" + query;
                    log.debug("query string added");
                }
            }
        }

        if (!conf.isOk()) {
            log.debug("configuration is not ok. not rewriting request.");
            return null;
        }

        final List rules = conf.getRules();
        if (rules.isEmpty()) {
            log.debug("there are no rules setup. not rewriting request.");
            return null;
        }

        return new RuleChain(this, originalUrl, parentChain);
    }

    /**
     * Routes an exception thrown by a rule invocation to the configured catch elements.
     *
     * <p>The exception is unwrapped with {@link #getOriginalException}. It is then offered, in
     * configuration order, to each {@link CatchElem} that matches it, and the result of the
     * first {@code execute} call that does not throw an {@link InvocationTargetException} is
     * returned. If a catch element fails with an {@code InvocationTargetException}, its
     * unwrapped exception replaces the current one, and the remaining catch elements are
     * matched against that new exception.
     *
     * <p>If no catch element handles it, {@link Error}s and {@link RuntimeException}s are
     * rethrown unchanged.
     *
     * @param hsRequest  the current request
     * @param hsResponse the current response
     * @param e          the exception raised by the rule invocation
     * @return whatever the handling catch element's {@code execute} returned
     * @throws ServletException the unhandled exception if it is a ServletException; otherwise a
     *         new ServletException wrapping any other unhandled checked throwable. It may also
     *         be thrown directly by {@link #getOriginalException}.
     * @throws IOException      the unhandled exception if it is an IOException
     */
    public RewrittenUrl handleInvocationTargetException(final HttpServletRequest hsRequest,
                                                        final HttpServletResponse hsResponse, InvocationTargetException e)
            throws ServletException, IOException {

        Throwable originalThrowable = getOriginalException(e);

        if (log.isDebugEnabled()) {
            log.debug("attampting to find catch for exception " + originalThrowable.getClass().getName());
        }

        List catchElems = conf.getCatchElems();
        for (int i = 0; i < catchElems.size(); i++) {
            CatchElem catchElem = (CatchElem) catchElems.get(i);
            if (!catchElem.matches(originalThrowable)) continue;
            try {
                return catchElem.execute(hsRequest, hsResponse, originalThrowable);
            } catch (InvocationTargetException invocationExceptionInner) {
                // The catch itself failed: continue searching with the new exception.
                originalThrowable = getOriginalException(invocationExceptionInner);
                log.warn("had exception processing catch, trying the rest of the catches with " +
                        originalThrowable.getClass().getName());
            }
        }

        if (log.isDebugEnabled()) {
            log.debug("exception unhandled", e);
        }
        if (originalThrowable instanceof Error) throw (Error) originalThrowable;
        if (originalThrowable instanceof RuntimeException) throw (RuntimeException) originalThrowable;
        if (originalThrowable instanceof ServletException) throw (ServletException) originalThrowable;
        if (originalThrowable instanceof IOException) throw (IOException) originalThrowable;
        throw new ServletException(originalThrowable);
    }

    /**
     * Unwraps an {@link InvocationTargetException} to the exception thrown by the invoked code.
     *
     * <p>The target exception is used, falling back to {@code getCause()}. If the result is a
     * {@link ServletException}, up to five further nested ServletException causes are followed.
     * If the innermost one's cause is an InvocationTargetException, it is unwrapped
     * recursively. Otherwise the innermost ServletException is <em>thrown</em>, not returned.
     *
     * @param e the exception to unwrap
     * @return the unwrapped, non-ServletException throwable
     * @throws ServletException wrapping {@code e} when it has neither a target nor a cause, or
     *         the innermost ServletException as described above
     */
    private Throwable getOriginalException(InvocationTargetException e) throws ServletException {
        Throwable originalThrowable = e.getTargetException();
        if (originalThrowable == null) {
            originalThrowable = e.getCause();
            if (originalThrowable == null) {
                throw new ServletException(e);
            }
        }
        if (originalThrowable instanceof ServletException) {
            ServletException se = (ServletException) originalThrowable;
            // Bounded walk so a long or cyclic cause chain cannot loop indefinitely.
            for (int i = 0; i < 5 && se.getCause() instanceof ServletException; i++) {
                se = (ServletException) se.getCause();
            }
            if (se.getCause() instanceof InvocationTargetException) {
                return getOriginalException((InvocationTargetException) se.getCause());
            } else {
                throw se;
            }
        }
        return originalThrowable;
    }

    /**
     * Returns the configuration this rewriter applies.
     *
     * @return the configuration passed to the constructor
     */
    public Conf getConf() {
        return conf;
    }

    /**
     * Passes an outbound URL through the configured outbound rules.
     *
     * <p>A rule is considered only when its {@code isEncodeFirst()} equals
     * {@code encodeUrlHasBeenRun}. Each matching rule's target becomes the input to the next
     * rule, and a matching rule marked {@code isLast()} stops processing.
     *
     * <p>If a rule throws an {@link InvocationTargetException}, it is passed to
     * {@link #handleInvocationTargetException}. Checked exceptions from that call are logged
     * at error level, and the rewrites applied so far are returned. Runtime exceptions and
     * errors from that call propagate.
     *
     * @param hsResponse          the current response
     * @param hsRequest           the current request
     * @param encodeUrlHasBeenRun whether the container's URL encoding has already been applied
     * @param outboundUrl         the URL to rewrite; may be {@code null}
     * @return the rewritten URL and its encode flag, or {@code (null, true)} for a null URL
     */
    protected RewrittenOutboundUrl processEncodeURL(HttpServletResponse hsResponse, HttpServletRequest hsRequest,
                                                    boolean encodeUrlHasBeenRun, String outboundUrl) {
        if (log.isDebugEnabled()) {
            log.debug("processing outbound url for " + outboundUrl);
        }
        if (outboundUrl == null) {
            return new RewrittenOutboundUrl(null, true);
        }

        boolean finalEncodeOutboundUrl = true;
        String finalToUrl = outboundUrl;

        final List outboundRules = conf.getOutboundRules();
        try {
            for (int i = 0; i < outboundRules.size(); i++) {
                final OutboundRule outboundRule = (OutboundRule) outboundRules.get(i);
                if (!encodeUrlHasBeenRun && outboundRule.isEncodeFirst()) {
                    continue;
                }
                if (encodeUrlHasBeenRun && !outboundRule.isEncodeFirst()) {
                    continue;
                }
                final RewrittenOutboundUrl rewrittenUrl = outboundRule.execute(finalToUrl, hsRequest, hsResponse);
                if (rewrittenUrl != null) {
                    if (log.isDebugEnabled()) {
                        log.debug("\"" + outboundRule.getDisplayName() + "\" matched");
                    }
                    finalToUrl = rewrittenUrl.getTarget();
                    finalEncodeOutboundUrl = rewrittenUrl.isEncode();
                    if (outboundRule.isLast()) {
                        log.debug("rule is last");
                        break;
                    }
                }
            }
        } catch (InvocationTargetException e) {
            try {
                handleInvocationTargetException(hsRequest, hsResponse, e);
            } catch (ServletException e1) {
                log.error(e1);
            } catch (IOException e1) {
                log.error(e1);
            }
        }
        return new RewrittenOutboundUrl(finalToUrl, finalEncodeOutboundUrl);
    }

    /**
     * Delegates to {@code conf.destroy()}.
     */
    public void destroy() {
        conf.destroy();
    }
}