/***************************************
 *                                     *
 *  JBoss: The OpenSource J2EE WebOS   *
 *                                     *
 *  Distributable under LGPL license.  *
 *  See terms of license at gnu.org.   *
 *                                     *
 ***************************************/

package org.jboss.tm.iiop;

import org.omg.CORBA.Any;
import org.omg.CORBA.LocalObject;
import org.omg.CORBA.TCKind;
import org.omg.CosTransactions.PropagationContext;
import org.omg.CosTransactions.PropagationContextHelper;
import org.omg.IOP.Codec;
import org.omg.IOP.CodecPackage.InvalidTypeForEncoding;
import org.omg.IOP.ServiceContext;
import org.omg.PortableInterceptor.ClientRequestInfo;
import org.omg.PortableInterceptor.ClientRequestInterceptor;
import org.omg.PortableInterceptor.InvalidSlot;

import org.jboss.iiop.CorbaORB;
import org.jboss.logging.Logger;

/**
 * This implementation of 
 * <code>org.omg.PortableInterceptor.ClientRequestInterceptor</code>
 * inserts the transactional context into outgoing requests.
 * <p>
 * The propagation context to be sent is kept in a PortableInterceptor
 * <code>Current</code> slot. Application code stores it there with
 * {@link #setOutgoingPropagationContext(PropagationContext)} and clears it
 * with {@link #unsetOutgoingPropagationContext()}. On every outgoing request,
 * {@link #send_request(ClientRequestInfo)} encodes a non-empty slot value with
 * the configured <code>Codec</code>. It then attaches the encoded value to the
 * request as a <code>TransactionService</code> service context.
 * <p>
 * The slot id, codec and PI current are static and are supplied through
 * {@link #init(int, Codec, org.omg.PortableInterceptor.Current)}. That method
 * must be called before any of the static context methods or interceptor
 * operations of this class are used.
 *
 * @author  <a href="mailto:reverbel@ime.usp.br">Francisco Reverbel</a>
 * @version $Revision: 1.2.4.3 $
 */
public class TxClientInterceptor 
      extends LocalObject
      implements ClientRequestInterceptor 
{
   /** @since 4.0.1 */
   static final long serialVersionUID = -8021933521763745659L;

   // Static fields -------------------------------------------------

   /** Service context id used for the encoded propagation context. */
   private static final int txContextId = org.omg.IOP.TransactionService.value;
   /** PI slot holding the outgoing propagation context; set by init(). */
   private static int slotId;
   /** Codec used to encode the slot value; set by init(). */
   private static Codec codec;
   /** PICurrent through which the slot is written; set by init(). */
   private static org.omg.PortableInterceptor.Current piCurrent;
   /** Lazily created empty Any used to clear the slot. */
   private static Any emptyAny = null;
   private static final Logger log = 
      Logger.getLogger(TxClientInterceptor.class);
   private static final boolean traceEnabled = log.isTraceEnabled();

   // Static methods ------------------------------------------------

   /**
    * Supplies the static state used by this interceptor. This must be called
    * before any other static method or interceptor operation is used;
    * otherwise <code>piCurrent</code> and <code>codec</code> are still null.
    *
    * @param slotId    the PI slot that holds the outgoing propagation context
    * @param codec     the codec used to encode the slot value
    * @param piCurrent the PICurrent through which the slot is written
    */
   static void init(int slotId, Codec codec, 
                    org.omg.PortableInterceptor.Current piCurrent)
   {
      TxClientInterceptor.slotId = slotId;
      TxClientInterceptor.codec = codec;
      TxClientInterceptor.piCurrent = piCurrent;
   }

   /**
    * Sets the transaction propagation context to be sent out with the IIOP
    * requests generated by the current thread.
    *
    * @param pc the propagation context to store in the PI slot
    * @throws RuntimeException if the configured slot id is invalid; the
    *         original <code>InvalidSlot</code> is attached as the cause
    */
   public static void setOutgoingPropagationContext(PropagationContext pc) 
   {
      Any any = CorbaORB.getInstance().create_any();
      PropagationContextHelper.insert(any, pc);
      try
      {
         piCurrent.set_slot(slotId, any);
      }
      catch (InvalidSlot e) 
      {
         throw new RuntimeException("Exception setting propagation context: " 
                                    + e, e);
      }
   }

   /**
    * Unsets the transaction propagation context associated with the current
    * thread by storing an empty Any in the PI slot.
    *
    * @throws RuntimeException if the configured slot id is invalid; the
    *         original <code>InvalidSlot</code> is attached as the cause
    */
   public static void unsetOutgoingPropagationContext() 
   {
      try 
      {
         piCurrent.set_slot(slotId, getEmptyAny());
      } 
      catch (InvalidSlot e) 
      {
         throw new RuntimeException("Exception unsetting propagation context: "
                                    + e, e);
      }
   }

   /**
    * Auxiliary method that returns an empty Any, creating it on first use.
    * <p>
    * The lazy initialization is not synchronized. Concurrent first calls may
    * each create an Any, and the last one written wins. Every instance is a
    * freshly created empty Any, so this race does not affect the result.
    */
   private static Any getEmptyAny()
   {
      if (emptyAny == null)
         emptyAny = CorbaORB.getInstance().create_any();
      return emptyAny;
   }         
   
   // Constructor ---------------------------------------------------

   /**
    * Creates the interceptor. All state is static and supplied through
    * <code>init</code>.
    */
   public TxClientInterceptor() 
   {
      // do nothing
   }

   // org.omg.PortableInterceptor.Interceptor operations ------------

   /**
    * Returns the name under which this interceptor is registered.
    */
   public String name()
   {
      return "TxClientInterceptor";
   }

   /**
    * Does nothing: this interceptor holds no per-instance resources.
    */
   public void destroy()
   {
      // do nothing
   }    

   // ClientRequestInterceptor operations ---------------------------

   /**
    * Attaches the propagation context stored in the PI slot, if any, to the
    * outgoing request as a <code>TransactionService</code> service context.
    * Any service context with the same id already on the request is replaced.
    *
    * @param ri information about the outgoing request
    * @throws RuntimeException if the slot id is invalid or the slot value
    *         cannot be encoded; the original exception is attached as the
    *         cause
    */
   public void send_request(ClientRequestInfo ri)
   {
      if (traceEnabled)
         log.trace("send_request: " + ri.operation());
      try
      {
         Any any = ri.get_slot(slotId);
         // An Any of kind tk_null (for example, the empty Any installed by
         // unsetOutgoingPropagationContext) means there is no context to send.
         if (any.type().kind().value() != TCKind._tk_null)
         {
            ServiceContext sc = new ServiceContext(txContextId, 
                                                   codec.encode_value(any));
            ri.add_request_service_context(sc,
                                           true /*replace existing context*/);
         }
      }
      catch (InvalidSlot e)
      {
         throw new RuntimeException("Exception getting slot in " +
                                    "TxClientInterceptor: " + e, e);
      }
      catch (InvalidTypeForEncoding e)
      {
         throw new RuntimeException(e);
      }
   }

   /**
    * Intentionally empty: this interceptor only acts on outgoing requests.
    */
   public void send_poll(ClientRequestInfo ri) 
   {
      // do nothing
   }
   
   /**
    * Intentionally empty: this interceptor only acts on outgoing requests.
    */
   public void receive_reply(ClientRequestInfo ri) 
   {
      // do nothing
   }
   
   /**
    * Intentionally empty: this interceptor only acts on outgoing requests.
    */
   public void receive_exception(ClientRequestInfo ri) 
   {
      // do nothing
   }
   
   /**
    * Intentionally empty: this interceptor only acts on outgoing requests.
    */
   public void receive_other(ClientRequestInfo ri) 
   {
      // do nothing
   }
   
}