/**
 * JBoss, the OpenSource J2EE webOS
 *
 * Distributable under LGPL license.
 * See terms of license at gnu.org.
 */
package org.jboss.webservice.handler;

// $Id: HandlerChainBaseImpl.java,v 1.8.2.2 2005/04/22 11:28:31 tdiesler Exp $

import org.jboss.axis.AxisFault;
import org.jboss.axis.Constants;
import org.jboss.axis.message.SOAPElementAxisImpl;
import org.jboss.logging.Logger;

import javax.xml.namespace.QName;
import javax.xml.rpc.JAXRPCException;
import javax.xml.rpc.handler.Handler;
import javax.xml.rpc.handler.HandlerChain;
import javax.xml.rpc.handler.HandlerInfo;
import javax.xml.rpc.handler.MessageContext;
import javax.xml.rpc.handler.soap.SOAPMessageContext;
import javax.xml.soap.SOAPEnvelope;
import javax.xml.soap.SOAPException;
import javax.xml.soap.SOAPHeader;
import javax.xml.soap.SOAPHeaderElement;
import javax.xml.soap.SOAPPart;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;

/**
 * Base implementation of a JAX-RPC {@link HandlerChain}.
 * <p/>
 * Every handler is instantiated from its {@link HandlerInfo}, wrapped in a
 * {@link HandlerWrapper} and kept, together with that info, in an internal
 * entry list. The class abstracts the policy and mechanism for the invocation
 * of the registered handlers.
 * <p/>
 * Lifecycle as implemented here: the constructor moves the chain to
 * {@link #STATE_CREATED}, {@link #init(Map)} to {@link #STATE_READY} and
 * {@link #destroy()} to {@link #STATE_DESTROYED}.
 * <p/>
 * NOTE(review): {@link #falseIndex} is mutable per-chain state shared by the
 * request, response and fault flows, so a chain instance presumably must not
 * process two messages concurrently - confirm against the callers.
 *
 * @author Thomas.Diesler@jboss.org
 * @since 06-May-2004
 */
public abstract class HandlerChainBaseImpl implements HandlerChain
{
   private static final Logger log = Logger.getLogger(HandlerChainBaseImpl.class);

   public static final int STATE_DOES_NOT_EXIST = 0;
   public static final int STATE_CREATED = 1;
   public static final int STATE_READY = 2;
   public static final int STATE_DESTROYED = 3;

   // The List<Entry> objects, one per configured handler, in invocation order
   protected ArrayList handlers = new ArrayList();
   // The roles associated with the handler chain
   protected HashSet roles = new HashSet();
   // The index of the handler at which processing stopped (returned false or threw), -1 if none
   protected int falseIndex = -1;
   // The state of this handler chain, one of the STATE_* constants
   protected int state;

   /**
    * Constructs a handler chain with the given handlers infos
    *
    * @param infos list of {@link HandlerInfo} objects, may be null
    * @param roles the SOAP actor roles of this chain, may be null
    * @throws javax.xml.rpc.JAXRPCException If a handler cannot be instantiated
    */
   public HandlerChainBaseImpl(List infos, Set roles)
   {
      log.debug("Create a handler chain for roles: " + roles);
      addHandlersToChain(infos, roles);
   }

   /**
    * Populates the chain: instantiates one handler per info, wraps it in a
    * {@link HandlerWrapper} and records the given roles. On success the chain
    * is moved to {@link #STATE_CREATED}.
    *
    * @param infos   list of {@link HandlerInfo} objects, may be null
    * @param roleSet the roles to register, may be null
    * @throws javax.xml.rpc.JAXRPCException If any error during initialization
    */
   private void addHandlersToChain(List infos, Set roleSet)
   {
      try
      {
         if (infos != null)
         {
            for (int i = 0; i < infos.size(); i++)
            {
               HandlerInfo info = (HandlerInfo)infos.get(i);
               Handler instance = (Handler)info.getHandlerClass().newInstance();
               handlers.add(new Entry(new HandlerWrapper(instance), info));
            }
         }
         if (roleSet != null)
         {
            roles.addAll(roleSet);
         }
      }
      catch (Exception e)
      {
         throw new JAXRPCException("Cannot initialize handler chain", e);
      }

      // set state to created
      state = STATE_CREATED;
   }

   /**
    * Get the state of this handler chain
    *
    * @return one of the STATE_* constants
    */
   public int getState()
   {
      return state;
   }

   /**
    * Initializes every handler in the chain with its own {@link HandlerInfo}
    * and moves the chain to {@link #STATE_READY}.
    *
    * @param config Configuration for the initialization of this handler chain;
    *               only logged by this implementation
    * @throws javax.xml.rpc.JAXRPCException If any error during initialization
    */
   public void init(Map config)
   {
      log.debug("init: [config=" + config + "]");
      for (int i = 0; i < handlers.size(); i++)
      {
         Entry entry = (Entry)handlers.get(i);
         entry.handler.init(entry.info);
      }

      // set state to ready
      state = STATE_READY;
   }

   /**
    * Indicates the end of lifecycle for a HandlerChain. Destroys every handler,
    * empties the chain and moves it to {@link #STATE_DESTROYED}.
    *
    * @throws javax.xml.rpc.JAXRPCException If any error during destroy
    */
   public void destroy()
   {
      log.debug("destroy");
      for (int i = 0; i < handlers.size(); i++)
      {
         Entry entry = (Entry)handlers.get(i);
         entry.handler.destroy();
      }
      handlers.clear();

      // set state to destroyed
      state = STATE_DESTROYED;
   }

   /**
    * Gets SOAP actor roles registered for this HandlerChain at this SOAP node.
    * <p/>
    * The array is a snapshot of the roles handed to the constructor or to
    * {@link #setRoles(String[])}; this method does not add the special SOAP
    * actor next on its own.
    *
    * @return SOAP Actor roles as URIs
    */
   public String[] getRoles()
   {
      return (String[])roles.toArray(new String[roles.size()]);
   }

   /**
    * Sets SOAP Actor roles for this HandlerChain. This specifies the set of roles in which this HandlerChain is to act
    * for the SOAP message processing at this SOAP node. These roles assumed by a HandlerChain must be invariant during
    * the processing of an individual SOAP message through the HandlerChain.
    * <p/>
    * A HandlerChain always acts in the role of the special SOAP actor next. Refer to the SOAP specification for the
    * URI name for this special SOAP actor. There is no need to set this special role using this method.
    * <p/>
    * Any previously registered roles are discarded.
    *
    * @param soapActorNames URIs for SOAP actor name, must not be null
    * @throws NullPointerException if soapActorNames is null
    */
   public void setRoles(String[] soapActorNames)
   {
      List newRoles = Arrays.asList(soapActorNames);
      log.debug("setRoles: " + newRoles);

      roles.clear();
      roles.addAll(newRoles);
   }

   /**
    * Initiates the request processing for this handler chain.
    * <p/>
    * Handlers are invoked in list order until one returns false or throws.
    * In that case the index of that very handler is remembered in
    * {@link #falseIndex}, so that response or fault processing can start there.
    *
    * @param msgContext MessageContext parameter provides access to the request SOAP message.
    * @return Returns true if all handlers in chain have been processed. Returns false if a handler in the chain returned false from its handleRequest method.
    * @throws javax.xml.rpc.JAXRPCException if any processing error happens
    */
   public boolean handleRequest(MessageContext msgContext)
   {
      boolean doNext = true;

      log.debug("Enter: doHandleRequest");

      // A new request starts here, forget where a previous message exchange stopped
      falseIndex = -1;

      // Replace handlers that did not survive the previous call
      replaceDirtyHandlers();

      int handlerIndex = 0;
      try
      {
         // Kept across iterations so that an unchanged message is traced only once
         String lastMessageTrace = null;
         while (doNext && handlerIndex < handlers.size())
         {
            if (log.isTraceEnabled())
               lastMessageTrace = traceMessage(msgContext, true, lastMessageTrace);

            Handler currHandler = ((Entry)handlers.get(handlerIndex)).getHandler();
            log.debug("Handle request: " + currHandler);
            doNext = currHandler.handleRequest(msgContext);

            if (log.isTraceEnabled())
               lastMessageTrace = traceMessage(msgContext, true, lastMessageTrace);

            // Advance only on success, so that the index keeps pointing at the handler that returned false
            if (doNext)
               handlerIndex++;
         }
      }
      catch (RuntimeException e)
      {
         log.error("RuntimeException in request handler", e);
         doNext = false;
         throw e;
      }
      finally
      {
         // we start at this index in the response chain
         if (doNext == false)
            falseIndex = handlerIndex;

         log.debug("Exit: doHandleRequest with status: " + doNext);
      }

      return doNext;
   }

   /**
    * Initiates the response processing for this handler chain.
    * <p/>
    * In this implementation, the response handler chain starts processing from the same Handler
    * instance (that returned false) and goes backward in the execution sequence. Without a
    * remembered handler, processing starts at the last handler of the chain.
    *
    * @param msgContext MessageContext parameter provides access to the response SOAP message.
    * @return Returns true if all handlers in chain have been processed.
    *         Returns false if a handler in the chain returned false from its handleResponse method.
    * @throws javax.xml.rpc.JAXRPCException if any processing error happens
    */
   public boolean handleResponse(MessageContext msgContext)
   {
      boolean doNext = true;

      log.debug("Enter: handleResponse");

      // Never start beyond the end of the list, the remembered index may be stale
      int handlerIndex = handlers.size() - 1;
      if (falseIndex != -1)
         handlerIndex = Math.min(falseIndex, handlerIndex);

      try
      {
         // Kept across iterations so that an unchanged message is traced only once
         String lastMessageTrace = null;
         while (doNext && handlerIndex >= 0)
         {
            if (log.isTraceEnabled())
               lastMessageTrace = traceMessage(msgContext, false, lastMessageTrace);

            Handler currHandler = ((Entry)handlers.get(handlerIndex)).getHandler();
            log.debug("Handle response: " + currHandler);
            doNext = currHandler.handleResponse(msgContext);

            if (log.isTraceEnabled())
               lastMessageTrace = traceMessage(msgContext, false, lastMessageTrace);

            // Step back only on success, so that the index keeps pointing at the handler that returned false
            if (doNext)
               handlerIndex--;
         }
      }
      catch (RuntimeException rte)
      {
         log.error("RuntimeException in response handler", rte);
         doNext = false;
         throw rte;
      }
      finally
      {
         // we start at this index in the fault chain
         if (doNext == false)
            falseIndex = handlerIndex;

         log.debug("Exit: handleResponse with status: " + doNext);
      }

      return doNext;
   }

   /**
    * Initiates the SOAP fault processing for this handler chain.
    * <p/>
    * In this implementation, the fault handler chain starts processing from the same Handler
    * instance (that returned false) and goes backward in the execution sequence. Without a
    * remembered handler, processing starts at the last handler of the chain.
    *
    * @param msgContext MessageContext parameter provides access to the fault SOAP message.
    * @return Returns true if all handlers in chain have been processed.
    *         Returns false if a handler in the chain returned false from its handleFault method.
    * @throws javax.xml.rpc.JAXRPCException if any processing error happens
    */
   public boolean handleFault(MessageContext msgContext)
   {
      boolean doNext = true;

      // Never start beyond the end of the list, the remembered index may be stale
      int handlerIndex = handlers.size() - 1;
      if (falseIndex != -1)
         handlerIndex = Math.min(falseIndex, handlerIndex);

      for (; doNext && handlerIndex >= 0; handlerIndex--)
      {
         Handler currHandler = ((Entry)handlers.get(handlerIndex)).getHandler();
         log.debug("Handle fault: " + currHandler);
         doNext = currHandler.handleFault(msgContext);
      }

      return doNext;
   }

   /**
    * Trace the request or the response message of the given context.
    *
    * @param msgContext       the message context, cast to the Axis message context
    * @param request          true to trace the request message, false for the response message
    * @param lastMessageTrace the string traced last, may be null
    * @return the value returned by {@link #traceSOAPPart(SOAPPart, String)}
    */
   private String traceMessage(MessageContext msgContext, boolean request, String lastMessageTrace)
   {
      org.jboss.axis.MessageContext msgCtx = (org.jboss.axis.MessageContext)msgContext;
      SOAPPart soapPart;
      if (request)
         soapPart = msgCtx.getRequestMessage().getSOAPPart();
      else
         soapPart = msgCtx.getResponseMessage().getSOAPPart();

      return traceSOAPPart(soapPart, lastMessageTrace);
   }

   /**
    * Trace the SOAPPart, do nothing if the String representation is equal to the last one.
    *
    * @param soapPart         the part whose envelope is traced
    * @param lastMessageTrace the string traced last, may be null
    * @return the string that is now the last one traced, or null if the envelope could not be obtained
    */
   protected String traceSOAPPart(SOAPPart soapPart, String lastMessageTrace)
   {
      try
      {
         SOAPEnvelope env = soapPart.getEnvelope();
         String envAsString = ((SOAPElementAxisImpl)env).getAsStringFromInternal();
         if (envAsString.equals(lastMessageTrace) == false)
         {
            log.trace(envAsString);
            lastMessageTrace = envAsString;
         }
         return lastMessageTrace;
      }
      catch (SOAPException e)
      {
         log.error("Cannot get SOAPEnvelope", e);
         return null;
      }
   }

   /**
    * Replace handlers that did not survive the previous call.
    * <p/>
    * An entry whose wrapper reports {@link HandlerWrapper#DOES_NOT_EXIST} gets a freshly
    * instantiated and initialized handler. A failure is logged with its cause and
    * the remaining entries are still processed.
    */
   protected void replaceDirtyHandlers()
   {
      for (int i = 0; i < handlers.size(); i++)
      {
         Entry entry = (Entry)handlers.get(i);
         if (entry.handler.getState() == HandlerWrapper.DOES_NOT_EXIST)
         {
            log.debug("Replacing dirty handler: " + entry.handler);
            try
            {
               HandlerWrapper handler = new HandlerWrapper((Handler)entry.info.getHandlerClass().newInstance());
               entry.handler = handler;
               handler.init(entry.info);
            }
            catch (Exception e)
            {
               // Keep the cause, otherwise the reason for the failure is lost
               log.error("Cannot create handler instance for: " + entry.info, e);
            }
         }
      }
   }

   /**
    * Get the handler at the requested position
    *
    * @param pos zero based position in the chain
    * @return the wrapped handler at that position
    * @throws IllegalArgumentException if there is no handler at that position
    */
   protected Handler getHandlerAt(int pos)
   {
      if (pos < 0 || handlers.size() <= pos)
         throw new IllegalArgumentException("No handler at position: " + pos);

      Entry entry = (Entry)handlers.get(pos);
      return entry.handler;
   }

   /**
    * Check if there are mustUnderstand headers that were not processed.
    * <p/>
    * A header element that still reports mustUnderstand is an error when it is
    * targeted at no actor or at the SOAP 1.1 next actor, or when it is targeted
    * at one of the roles of this chain and a handler declares its name in getHeaders().
    * Presumably a processed header has been removed or had its mustUnderstand flag
    * cleared by the time this is called - confirm against the callers.
    * <p/>
    * A SOAPException while reading the headers is logged and not propagated.
    *
    * @param msgContext the message context, must be a {@link SOAPMessageContext}
    * @throws javax.xml.rpc.JAXRPCException wrapping an AxisFault with the
    *                                       mustUnderstand fault code, if such a header is found
    */
   protected void checkMustUnderstand(MessageContext msgContext)
   {
      String errorMsg = null;

      try
      {
         SOAPMessageContext msgCtx = (SOAPMessageContext)msgContext;
         SOAPPart soapPart = msgCtx.getMessage().getSOAPPart();
         SOAPHeader soapHeader = soapPart.getEnvelope().getHeader();
         if (soapHeader != null)
         {
            Iterator it = soapHeader.examineAllHeaderElements();
            while (errorMsg == null && it.hasNext())
            {
               SOAPHeaderElement headerElement = (SOAPHeaderElement)it.next();
               if (headerElement.getMustUnderstand() == false)
                  continue;

               QName headerName = new QName(headerElement.getNamespaceURI(), headerElement.getLocalName());

               String actor = headerElement.getActor();
               if (actor == null || Constants.URI_SOAP11_NEXT_ACTOR.equals(actor))
               {
                  errorMsg = "Unprocessed mustUnderstand header " + headerName;
               }
               else if (roles.contains(actor) && isDeclaredByHandler(headerName))
               {
                  errorMsg = "Unprocessed mustUnderstand header " + headerName;
               }
            }
         }
      }
      catch (SOAPException e)
      {
         log.error("Cannot check mustUnderstand for headers", e);
      }

      if (errorMsg != null)
      {
         AxisFault fault = new AxisFault(errorMsg);
         fault.setFaultCode(Constants.FAULT_MUSTUNDERSTAND);
         throw new JAXRPCException(fault);
      }
   }

   /**
    * Check if a handler in this chain should have processed the given header element.
    *
    * @param headerName the qualified name of the header element
    * @return true if a handler lists the name in getHeaders(); a handler
    *         that returns null from getHeaders() declares nothing
    */
   private boolean isDeclaredByHandler(QName headerName)
   {
      Iterator itHandlers = handlers.iterator();
      while (itHandlers.hasNext())
      {
         Entry entry = (Entry)itHandlers.next();
         QName[] headers = entry.getHandler().getHeaders();
         if (headers != null && Arrays.asList(headers).contains(headerName))
            return true;
      }
      return false;
   }

   /**
    * An entry in the handler list, pairs a wrapped handler with the info it was created from.
    * Static, because it needs no access to the enclosing chain.
    */
   private static class Entry
   {
      // Not final, replaceDirtyHandlers() swaps in a new wrapper
      private HandlerWrapper handler;
      private final HandlerInfo info;

      public Entry(HandlerWrapper handler, HandlerInfo info)
      {
         this.handler = handler;
         this.info = info;
      }

      public Handler getHandler()
      {
         return handler;
      }

      public HandlerInfo getInfo()
      {
         return info;
      }
   }

   // java.util.List interface ****************************************************************************************
   //
   // All methods below delegate to the backing list. Note that the chain itself fills
   // that list with its internal entry objects, not with bare Handler instances, and
   // that init(), destroy() and the handleXXX methods cast every element to that type.

   /** Delegates to the backing list. */
   public boolean remove(Object o)
   {
      return handlers.remove(o);
   }

   /** Delegates to the backing list. */
   public boolean containsAll(Collection c)
   {
      return handlers.containsAll(c);
   }

   /** Delegates to the backing list. */
   public boolean removeAll(Collection c)
   {
      return handlers.removeAll(c);
   }

   /** Delegates to the backing list. */
   public boolean retainAll(Collection c)
   {
      return handlers.retainAll(c);
   }

   /** Returns the hash code of the backing list. */
   public int hashCode()
   {
      return handlers.hashCode();
   }

   /** Compares the backing list with the given object. */
   public boolean equals(Object o)
   {
      return handlers.equals(o);
   }

   /** Delegates to the backing list. */
   public Iterator iterator()
   {
      return handlers.iterator();
   }

   /** Delegates to the backing list. */
   public List subList(int fromIndex, int toIndex)
   {
      return handlers.subList(fromIndex, toIndex);
   }

   /** Delegates to the backing list. */
   public ListIterator listIterator()
   {
      return handlers.listIterator();
   }

   /** Delegates to the backing list. */
   public ListIterator listIterator(int index)
   {
      return handlers.listIterator(index);
   }

   /** Returns the number of elements in the backing list. */
   public int size()
   {
      return handlers.size();
   }

   /** Empties the backing list, the removed handlers are not destroyed. */
   public void clear()
   {
      handlers.clear();
   }

   /** Delegates to the backing list. */
   public boolean isEmpty()
   {
      return handlers.isEmpty();
   }

   /** Delegates to the backing list. */
   public Object[] toArray()
   {
      return handlers.toArray();
   }

   /** Returns the element of the backing list at the given index. */
   public Object get(int index)
   {
      return handlers.get(index);
   }

   /** Removes the element at the given index from the backing list, it is not destroyed. */
   public Object remove(int index)
   {
      return handlers.remove(index);
   }

   /** Inserts the element into the backing list as is, without wrapping or type check. */
   public void add(int index, Object element)
   {
      handlers.add(index, element);
   }

   /** Delegates to the backing list. */
   public int indexOf(Object elem)
   {
      return handlers.indexOf(elem);
   }

   /** Delegates to the backing list. */
   public int lastIndexOf(Object elem)
   {
      return handlers.lastIndexOf(elem);
   }

   /** Appends the element to the backing list as is, without wrapping or type check. */
   public boolean add(Object o)
   {
      return handlers.add(o);
   }

   /** Delegates to the backing list. */
   public boolean contains(Object elem)
   {
      return handlers.contains(elem);
   }

   /** Inserts the elements into the backing list as is, without wrapping or type check. */
   public boolean addAll(int index, Collection c)
   {
      return handlers.addAll(index, c);
   }

   /** Appends the elements to the backing list as is, without wrapping or type check. */
   public boolean addAll(Collection c)
   {
      return handlers.addAll(c);
   }

   /** Replaces the element of the backing list as is, without wrapping or type check. */
   public Object set(int index, Object element)
   {
      return handlers.set(index, element);
   }

   /** Delegates to the backing list. */
   public Object[] toArray(Object[] a)
   {
      return handlers.toArray(a);
   }
}