package org.jboss.webservice.handler;
import org.jboss.axis.AxisFault;
import org.jboss.axis.Constants;
import org.jboss.axis.message.SOAPElementAxisImpl;
import org.jboss.logging.Logger;
import javax.xml.namespace.QName;
import javax.xml.rpc.JAXRPCException;
import javax.xml.rpc.handler.Handler;
import javax.xml.rpc.handler.HandlerChain;
import javax.xml.rpc.handler.HandlerInfo;
import javax.xml.rpc.handler.MessageContext;
import javax.xml.rpc.handler.soap.SOAPMessageContext;
import javax.xml.soap.SOAPEnvelope;
import javax.xml.soap.SOAPException;
import javax.xml.soap.SOAPHeader;
import javax.xml.soap.SOAPHeaderElement;
import javax.xml.soap.SOAPPart;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
/**
 * Base implementation of a JAX-RPC {@link HandlerChain}.
 * <p>
 * The chain holds an ordered list of internal entries, each pairing a {@link HandlerWrapper} around a
 * handler instance with the {@link HandlerInfo} it was created from, together with the set of SOAP
 * actor roles this chain plays. {@link #handleRequest(MessageContext)} walks the handlers first to
 * last; {@link #handleResponse(MessageContext)} and {@link #handleFault(MessageContext)} walk them
 * last to first, starting at the handler that stopped the chain if one did.
 * <p>
 * The {@link List} methods at the end of this class delegate directly to the internal list, so the
 * elements they expose are the internal entry objects, not {@link Handler} instances.
 * <p>
 * NOTE(review): the stop position ({@link #falseIndex}) is per-instance mutable state with no
 * synchronization, so one chain instance presumably must not process concurrent messages — confirm
 * against the callers.
 */
public abstract class HandlerChainBaseImpl implements HandlerChain
{
   private static final Logger log = Logger.getLogger(HandlerChainBaseImpl.class);

   /** Lifecycle state: no chain exists. */
   public static final int STATE_DOES_NOT_EXIST = 0;
   /** Lifecycle state: the handlers have been instantiated by the constructor. */
   public static final int STATE_CREATED = 1;
   /** Lifecycle state: {@link #init(Map)} has initialized every handler. */
   public static final int STATE_READY = 2;
   /** Lifecycle state: {@link #destroy()} has run and the handler list has been cleared. */
   public static final int STATE_DESTROYED = 3;

   /** Ordered list of {@code Entry} objects (handler wrapper plus its HandlerInfo). */
   protected ArrayList handlers = new ArrayList();

   /** SOAP actor roles this chain processes headers for. */
   protected HashSet roles = new HashSet();

   /**
    * Index of the handler that stopped the most recent request or response pass, either by
    * returning false or by throwing a RuntimeException; -1 if no handler stopped the chain.
    * Response and fault processing start at this index instead of at the last handler.
    * Reset at the start of every {@link #handleRequest(MessageContext)}.
    */
   protected int falseIndex = -1;

   /** Current lifecycle state, one of the {@code STATE_*} constants. */
   protected int state;

   /**
    * Creates the chain by instantiating one handler per {@link HandlerInfo}.
    *
    * @param infos the handler descriptions ({@link HandlerInfo}), in chain order; null means no handlers
    * @param roles the SOAP actor role names this chain plays; null means no roles
    * @throws JAXRPCException if a handler class cannot be instantiated
    */
   public HandlerChainBaseImpl(List infos, Set roles)
   {
      log.debug("Create a handler chain for roles: " + roles);
      addHandlersToChain(infos, roles);
   }

   /**
    * Instantiates each handler class through its no-arg constructor, wraps it in a
    * {@link HandlerWrapper}, appends it to the chain, records the roles and moves the chain to
    * {@link #STATE_CREATED}. The handlers are not initialized here; see {@link #init(Map)}.
    *
    * @throws JAXRPCException wrapping any failure while creating the handlers
    */
   private void addHandlersToChain(List infos, Set roleSet)
   {
      try
      {
         if (infos != null)
         {
            for (int i = 0; i < infos.size(); i++)
            {
               HandlerInfo info = (HandlerInfo)infos.get(i);
               HandlerWrapper handler = new HandlerWrapper((Handler)info.getHandlerClass().newInstance());
               handlers.add(new Entry(handler, info));
            }
         }

         if (roleSet != null)
         {
            roles.addAll(roleSet);
         }
      }
      catch (Exception e)
      {
         throw new JAXRPCException("Cannot initialize handler chain", e);
      }

      state = STATE_CREATED;
   }

   /**
    * Returns the current lifecycle state.
    *
    * @return one of the {@code STATE_*} constants
    */
   public int getState()
   {
      return state;
   }

   /**
    * Initializes every handler with the {@link HandlerInfo} it was created from and moves the chain
    * to {@link #STATE_READY}.
    *
    * @param config only logged; it is not passed to the handlers
    */
   public void init(Map config)
   {
      log.debug("init: [config=" + config + "]");
      for (int i = 0; i < handlers.size(); i++)
      {
         Entry entry = (Entry)handlers.get(i);
         entry.handler.init(entry.info);
      }
      state = STATE_READY;
   }

   /**
    * Destroys every handler, removes them all from the chain and moves the chain to
    * {@link #STATE_DESTROYED}.
    */
   public void destroy()
   {
      log.debug("destroy");
      for (int i = 0; i < handlers.size(); i++)
      {
         Entry entry = (Entry)handlers.get(i);
         entry.handler.destroy();
      }
      handlers.clear();
      state = STATE_DESTROYED;
   }

   /**
    * Returns a copy of the SOAP actor roles this chain plays.
    *
    * @return a new array; changing it does not affect the chain
    */
   public String[] getRoles()
   {
      String[] arr = new String[roles.size()];
      roles.toArray(arr);
      return arr;
   }

   /**
    * Replaces all SOAP actor roles of this chain.
    *
    * @param soapActorNames the new role names; must not be null ({@link Arrays#asList} rejects null)
    */
   public void setRoles(String[] soapActorNames)
   {
      List newRoles = Arrays.asList(soapActorNames);
      log.debug("setRoles: " + newRoles);
      roles.clear();
      roles.addAll(newRoles);
   }

   /**
    * Passes the request through the handlers from first to last.
    * <p>
    * Processing stops at the first handler that returns false; that handler's index is stored in
    * {@link #falseIndex} so that response/fault processing starts with it. Before the first handler
    * runs, any handler that is no longer usable is replaced (see {@link #replaceDirtyHandlers()}).
    *
    * @param msgContext the message context; when trace logging is enabled it must be an
    *        {@code org.jboss.axis.MessageContext}
    * @return true if every handler returned true, false if a handler blocked the chain
    * @throws RuntimeException rethrown from a failing handler, after recording its index in
    *         {@link #falseIndex}
    */
   public boolean handleRequest(MessageContext msgContext)
   {
      boolean doNext = true;

      log.debug("Enter: doHandleRequest");

      // Each request starts a new request/response cycle; forget where a previous cycle stopped.
      falseIndex = -1;

      replaceDirtyHandlers();

      int handlerIndex = 0;
      // Kept across handlers so an unchanged message is traced only once.
      String lastMessageTrace = null;
      try
      {
         while (doNext && handlerIndex < handlers.size())
         {
            if (log.isTraceEnabled())
               lastMessageTrace = traceRequestMessage(msgContext, lastMessageTrace);

            Handler currHandler = ((Entry)handlers.get(handlerIndex)).getHandler();
            log.debug("Handle request: " + currHandler);
            doNext = currHandler.handleRequest(msgContext);

            if (log.isTraceEnabled())
               lastMessageTrace = traceRequestMessage(msgContext, lastMessageTrace);

            // Advance only while the chain continues, so handlerIndex keeps pointing at a blocking
            // handler instead of one past it.
            if (doNext)
               handlerIndex++;
         }
      }
      catch (RuntimeException e)
      {
         log.error("RuntimeException in request handler", e);
         doNext = false;
         throw e;
      }
      finally
      {
         // Response/fault processing starts at the handler that stopped the request chain.
         if (doNext == false)
            falseIndex = handlerIndex;
         log.debug("Exit: doHandleRequest with status: " + doNext);
      }
      return doNext;
   }

   /**
    * Passes the response through the handlers from last to first.
    * <p>
    * Processing starts at {@link #falseIndex} if a handler stopped an earlier pass, otherwise at
    * the last handler. Processing stops at the first handler that returns false; that handler's
    * index is stored in {@link #falseIndex}.
    *
    * @param msgContext the message context; when trace logging is enabled it must be an
    *        {@code org.jboss.axis.MessageContext}
    * @return true if every visited handler returned true, false if a handler blocked the chain
    * @throws RuntimeException rethrown from a failing handler, after recording its index in
    *         {@link #falseIndex}
    */
   public boolean handleResponse(MessageContext msgContext)
   {
      boolean doNext = true;

      log.debug("Enter: handleResponse");

      int handlerIndex = getReverseStartIndex();
      // Kept across handlers so an unchanged message is traced only once.
      String lastMessageTrace = null;
      try
      {
         while (doNext && handlerIndex >= 0)
         {
            if (log.isTraceEnabled())
               lastMessageTrace = traceResponseMessage(msgContext, lastMessageTrace);

            Handler currHandler = ((Entry)handlers.get(handlerIndex)).getHandler();
            log.debug("Handle response: " + currHandler);
            doNext = currHandler.handleResponse(msgContext);

            if (log.isTraceEnabled())
               lastMessageTrace = traceResponseMessage(msgContext, lastMessageTrace);

            // Move on only while the chain continues, so handlerIndex keeps pointing at a blocking
            // handler (and never collides with the -1 "no stop" sentinel).
            if (doNext)
               handlerIndex--;
         }
      }
      catch (RuntimeException rte)
      {
         log.error("RuntimeException in response handler", rte);
         doNext = false;
         throw rte;
      }
      finally
      {
         if (doNext == false)
            falseIndex = handlerIndex;
         log.debug("Exit: handleResponse with status: " + doNext);
      }
      return doNext;
   }

   /**
    * Passes a fault through the handlers from last to first, starting at {@link #falseIndex} if a
    * handler stopped an earlier pass, and stopping at the first handler that returns false.
    * Exceptions thrown by a handler propagate unchanged.
    *
    * @param msgContext the message context
    * @return true if every visited handler returned true, otherwise false
    */
   public boolean handleFault(MessageContext msgContext)
   {
      boolean doNext = true;
      int handlerIndex = getReverseStartIndex();
      while (doNext && handlerIndex >= 0)
      {
         Handler currHandler = ((Entry)handlers.get(handlerIndex)).getHandler();
         log.debug("Handle fault: " + currHandler);
         doNext = currHandler.handleFault(msgContext);
         handlerIndex--;
      }
      return doNext;
   }

   /**
    * Computes where a last-to-first walk begins: the recorded stop position if there is one and it
    * is still a valid index, otherwise the last handler (or -1 for an empty chain).
    */
   private int getReverseStartIndex()
   {
      int lastIndex = handlers.size() - 1;
      // The chain can shrink through the public List methods; never start past its end.
      if (falseIndex == -1 || falseIndex > lastIndex)
         return lastIndex;
      return falseIndex;
   }

   /** Traces the Axis request message; returns the value to pass as lastMessageTrace next time. */
   private String traceRequestMessage(MessageContext msgContext, String lastMessageTrace)
   {
      org.jboss.axis.MessageContext axisContext = (org.jboss.axis.MessageContext)msgContext;
      SOAPPart soapPart = axisContext.getRequestMessage().getSOAPPart();
      return traceSOAPPart(soapPart, lastMessageTrace);
   }

   /** Traces the Axis response message; returns the value to pass as lastMessageTrace next time. */
   private String traceResponseMessage(MessageContext msgContext, String lastMessageTrace)
   {
      org.jboss.axis.MessageContext axisContext = (org.jboss.axis.MessageContext)msgContext;
      SOAPPart soapPart = axisContext.getResponseMessage().getSOAPPart();
      return traceSOAPPart(soapPart, lastMessageTrace);
   }

   /**
    * Logs the SOAP envelope at trace level, unless its string form equals the previous trace.
    *
    * @param soapPart the part whose envelope is traced; the envelope must be a
    *        {@link SOAPElementAxisImpl}
    * @param lastMessageTrace the previously traced envelope string, or null
    * @return the envelope string now considered traced, or null if the envelope could not be read
    */
   protected String traceSOAPPart(SOAPPart soapPart, String lastMessageTrace)
   {
      try
      {
         SOAPEnvelope env = soapPart.getEnvelope();
         String envAsString = ((SOAPElementAxisImpl)env).getAsStringFromInternal();
         if (envAsString.equals(lastMessageTrace) == false)
         {
            log.trace(envAsString);
            lastMessageTrace = envAsString;
         }
         return lastMessageTrace;
      }
      catch (SOAPException e)
      {
         log.error("Cannot get SOAPEnvelope", e);
         return null;
      }
   }

   /**
    * Replaces every handler whose wrapper reports {@link HandlerWrapper#DOES_NOT_EXIST} with a new,
    * initialized instance created from the same {@link HandlerInfo}.
    * <p>
    * Failures are logged and the old handler is left in place; this is deliberately best-effort.
    */
   protected void replaceDirtyHandlers()
   {
      for (int i = 0; i < handlers.size(); i++)
      {
         Entry entry = (Entry)handlers.get(i);
         if (entry.handler.getState() == HandlerWrapper.DOES_NOT_EXIST)
         {
            log.debug("Replacing dirty handler: " + entry.handler);
            try
            {
               HandlerWrapper handler = new HandlerWrapper((Handler)entry.info.getHandlerClass().newInstance());
               entry.handler = handler;
               handler.init(entry.info);
            }
            catch (Exception e)
            {
               // Keep the cause so the failure can be diagnosed.
               log.error("Cannot create handler instance for: " + entry.info, e);
            }
         }
      }
   }

   /**
    * Returns the handler at the given chain position.
    *
    * @param pos zero-based position in the chain
    * @return the handler (its {@link HandlerWrapper}) at {@code pos}
    * @throws IllegalArgumentException if {@code pos} is out of range
    */
   protected Handler getHandlerAt(int pos)
   {
      if (pos < 0 || handlers.size() <= pos)
         throw new IllegalArgumentException("No handler at position: " + pos);

      Entry entry = (Entry)handlers.get(pos);
      return entry.handler;
   }

   /**
    * Throws a MustUnderstand fault if the message still carries a mustUnderstand header that is
    * treated as unprocessed.
    * <p>
    * A remaining mustUnderstand header counts as unprocessed when:
    * <ul>
    * <li>it has no actor or the SOAP 1.1 "next" actor, or</li>
    * <li>its actor is one of this chain's roles and some handler lists the header in
    *     {@link Handler#getHeaders()}.</li>
    * </ul>
    * NOTE(review): this presumably relies on handlers removing the headers they process — confirm.
    * A {@link SOAPException} while reading the headers is logged and the check is skipped.
    *
    * @param msgContext the message context; must be a {@link SOAPMessageContext}
    * @throws JAXRPCException wrapping an {@link AxisFault} with code
    *         {@code Constants.FAULT_MUSTUNDERSTAND}
    */
   protected void checkMustUnderstand(MessageContext msgContext)
   {
      String errorMsg = null;
      try
      {
         SOAPMessageContext msgCtx = (SOAPMessageContext)msgContext;
         SOAPPart soapPart = msgCtx.getMessage().getSOAPPart();
         SOAPHeader soapHeader = soapPart.getEnvelope().getHeader();
         if (soapHeader != null)
         {
            Iterator it = soapHeader.examineAllHeaderElements();
            while (errorMsg == null && it.hasNext())
            {
               SOAPHeaderElement headerElement = (SOAPHeaderElement)it.next();
               if (headerElement.getMustUnderstand() == false)
                  continue;

               QName headerName = new QName(headerElement.getNamespaceURI(), headerElement.getLocalName());
               String actor = headerElement.getActor();
               if (actor == null || Constants.URI_SOAP11_NEXT_ACTOR.equals(actor))
               {
                  errorMsg = "Unprocessed mustUnderstand header " + headerName;
               }
               else if (roles.contains(actor) && isHeaderClaimedByHandler(headerName))
               {
                  errorMsg = "Unprocessed mustUnderstand header " + headerName;
               }
            }
         }
      }
      catch (SOAPException e)
      {
         log.error("Cannot check mustUnderstand for headers", e);
      }

      if (errorMsg != null)
      {
         AxisFault fault = new AxisFault(errorMsg);
         fault.setFaultCode(Constants.FAULT_MUSTUNDERSTAND);
         throw new JAXRPCException(fault);
      }
   }

   /**
    * Tells whether any handler in the chain lists the given header in {@link Handler#getHeaders()}.
    * A handler returning null headers is treated as listing none.
    */
   private boolean isHeaderClaimedByHandler(QName headerName)
   {
      for (int i = 0; i < handlers.size(); i++)
      {
         Handler handler = ((Entry)handlers.get(i)).getHandler();
         QName[] headers = handler.getHeaders();
         if (headers != null && Arrays.asList(headers).contains(headerName))
            return true;
      }
      return false;
   }

   /**
    * Pairs a handler instance with the {@link HandlerInfo} it was created from, so that the handler
    * can be re-created by {@link #replaceDirtyHandlers()}.
    */
   private static class Entry
   {
      private HandlerWrapper handler;
      private HandlerInfo info;

      public Entry(HandlerWrapper handler, HandlerInfo info)
      {
         this.handler = handler;
         this.info = info;
      }

      public Handler getHandler()
      {
         return handler;
      }

      public HandlerInfo getInfo()
      {
         return info;
      }
   }

   // ---------------------------------------------------------------------------------------------
   // java.util.List implementation: every method delegates to the internal handlers list, whose
   // elements are Entry objects. Adding any other kind of object makes the handle* methods fail
   // with a ClassCastException.
   // ---------------------------------------------------------------------------------------------

   public boolean remove(Object o)
   {
      return handlers.remove(o);
   }

   public boolean containsAll(Collection c)
   {
      return handlers.containsAll(c);
   }

   public boolean removeAll(Collection c)
   {
      return handlers.removeAll(c);
   }

   public boolean retainAll(Collection c)
   {
      return handlers.retainAll(c);
   }

   public int hashCode()
   {
      return handlers.hashCode();
   }

   public boolean equals(Object o)
   {
      return handlers.equals(o);
   }

   public Iterator iterator()
   {
      return handlers.iterator();
   }

   public List subList(int fromIndex, int toIndex)
   {
      return handlers.subList(fromIndex, toIndex);
   }

   public ListIterator listIterator()
   {
      return handlers.listIterator();
   }

   public ListIterator listIterator(int index)
   {
      return handlers.listIterator(index);
   }

   public int size()
   {
      return handlers.size();
   }

   public void clear()
   {
      handlers.clear();
   }

   public boolean isEmpty()
   {
      return handlers.isEmpty();
   }

   public Object[] toArray()
   {
      return handlers.toArray();
   }

   public Object get(int index)
   {
      return handlers.get(index);
   }

   public Object remove(int index)
   {
      return handlers.remove(index);
   }

   public void add(int index, Object element)
   {
      handlers.add(index, element);
   }

   public int indexOf(Object elem)
   {
      return handlers.indexOf(elem);
   }

   public int lastIndexOf(Object elem)
   {
      return handlers.lastIndexOf(elem);
   }

   public boolean add(Object o)
   {
      return handlers.add(o);
   }

   public boolean contains(Object elem)
   {
      return handlers.contains(elem);
   }

   public boolean addAll(int index, Collection c)
   {
      return handlers.addAll(index, c);
   }

   public boolean addAll(Collection c)
   {
      return handlers.addAll(c);
   }

   public Object set(int index, Object element)
   {
      return handlers.set(index, element);
   }

   public Object[] toArray(Object[] a)
   {
      return handlers.toArray(a);
   }
}