/***************************************
 *                                     *
 *  JBoss: The OpenSource J2EE WebOS   *
 *                                     *
 *  Distributable under LGPL license.  *
 *  See terms of license at gnu.org.   *
 *                                     *
 ***************************************/

package org.jboss.tm.iiop;

import javax.naming.Context;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.transaction.TransactionManager;
import org.jboss.iiop.CorbaORB;
import org.jboss.tm.TxUtils;
import org.jboss.util.NestedRuntimeException;
import org.omg.CORBA.Any;
import org.omg.CORBA.LocalObject;
import org.omg.CORBA_2_3.ORB;
import org.omg.CosTransactions.PropagationContext;
import org.omg.CosTransactions.PropagationContextHelper;
import org.omg.CosTransactions.TransIdentity;
import org.omg.CosTransactions.otid_t;
import org.omg.IOP.Codec;
import org.omg.IOP.ServiceContext;
import org.omg.IOP.TransactionService;
import org.omg.IOP.CodecPackage.InvalidTypeForEncoding;
import org.omg.PortableInterceptor.ClientRequestInfo;
import org.omg.PortableInterceptor.ClientRequestInterceptor;

/**
 * This implementation of 
 * <code>org.omg.PortableInterceptor.ClientRequestInterceptor</code>
 * adds a transaction service context to outgoing requests.
 *
 * <p>The service context is only added when
 * <code>TxUtils.isCompleted</code> returns <code>false</code> for the
 * JBoss transaction manager bound at
 * <code>java:/TransactionManager</code>. The payload is always the same
 * placeholder <code>PropagationContext</code> (see
 * {@link #getEmptyPropagationContext()}); no data of the actual
 * transaction is copied into it by this class.
 *
 * <p>{@link #init(Codec)} must have been called before the first request
 * is intercepted, otherwise <code>send_request</code> fails with a
 * <code>NullPointerException</code> when it tries to encode the context.
 *
 * @author  <a href="mailto:adrian@jboss.com">Adrian Brock</a>
 * @version $Revision: 1.2.4.1 $
 */
public class TxServerClientInterceptor extends LocalObject implements ClientRequestInterceptor 
{
   /** @since 4.0.1 */
   static final long serialVersionUID = 4716203472714459196L;

   // Static fields -------------------------------------------------

   /** Service context id under which the propagation context is sent. */
   private static final int txContextId = TransactionService.value;

   /** Codec used to encode the propagation context; set by {@link #init(Codec)}. */
   private static Codec codec;

   /** Lazily looked up transaction manager, cached after the first successful lookup. */
   private static TransactionManager tm;

   /**
    * Lazily created placeholder propagation context shared by all requests.
    * Volatile so that a thread that sees a non-null reference also sees
    * the fully populated structure.
    */
   private static volatile PropagationContext emptyPC;

   // Static methods ------------------------------------------------

   /**
    * Sets the codec used to encode the service context data.
    *
    * @param codec the codec to use for all subsequent requests
    */
   static void init(Codec codec)
   {
      TxServerClientInterceptor.codec = codec;
   }
   
   /**
    * Returns the transaction manager, looking it up in JNDI under
    * <code>java:/TransactionManager</code> on first use and caching it.
    *
    * @return the transaction manager
    * @throws NestedRuntimeException if the JNDI lookup fails
    */
   static TransactionManager getTransactionManager()
   {
      if (tm == null)
      {
         try
         {
            Context ctx = new InitialContext();
            tm = (TransactionManager)ctx.lookup("java:/TransactionManager");
         }
         catch (NamingException e)
         {
            throw new NestedRuntimeException("java:/TransactionManager lookup failed", e);
         }
      }
      return tm;
   }

   /**
    * Returns the shared placeholder propagation context, creating it on
    * first use.
    *
    * <p>The context is built completely in a local variable and only then
    * stored in the static field, so a concurrent caller can never pick up
    * a partially filled structure.
    *
    * @return the placeholder propagation context
    */
   static PropagationContext getEmptyPropagationContext()
   {
      PropagationContext result = emptyPC;
      if (result == null)
      {
         // According to the spec, this should all be ignored
         // But we get NPEs if it doesn't contain some content
         PropagationContext pc = new PropagationContext();
         pc.parents = new TransIdentity[0];
         pc.current = new TransIdentity();
         pc.current.otid = new otid_t();
         pc.current.otid.formatID = 666;
         pc.current.otid.bqual_length = 1;
         pc.current.otid.tid = new byte[] { (byte) 1 };
         pc.implementation_specific_data = ORB.init().create_any();
         pc.implementation_specific_data.insert_boolean(false);

         // Publish only after every field has been set. Two threads may
         // both build an instance; the contents are identical, so it does
         // not matter which one ends up cached.
         emptyPC = pc;
         result = pc;
      }
      return result;
   }
   
   // Constructor ---------------------------------------------------

   public TxServerClientInterceptor() 
   {
      // do nothing
   }

   // org.omg.PortableInterceptor.Interceptor operations ------------

   /**
    * @return the fixed interceptor name <code>"TxServerClientInterceptor"</code>
    */
   public String name()
   {
      return "TxServerClientInterceptor";
   }

   public void destroy()
   {
      // do nothing
   }    

   // ClientRequestInterceptor operations ---------------------------

   /**
    * Adds the transaction service context to the outgoing request when
    * {@link #getTransactionPropagationContextAny()} supplies one,
    * replacing any service context already present under the same id.
    *
    * @param ri the request about to be sent
    * @throws NestedRuntimeException if the propagation context cannot be
    *         obtained or encoded
    */
   public void send_request(ClientRequestInfo ri)
   {
      try
      {
         Any any = getTransactionPropagationContextAny();
         if (any != null)
         {
            ServiceContext sc = new ServiceContext(txContextId, codec.encode_value(any));
            ri.add_request_service_context(sc, true /*replace existing context*/);
         }
      }
      catch (InvalidTypeForEncoding e)
      {
         throw new NestedRuntimeException(e);
      }
   }

   public void send_poll(ClientRequestInfo ri) 
   {
      // do nothing
   }
   
   public void receive_reply(ClientRequestInfo ri) 
   {
      // do nothing
   }
   
   public void receive_exception(ClientRequestInfo ri) 
   {
      // do nothing
   }
   
   public void receive_other(ClientRequestInfo ri) 
   {
      // do nothing
   }
   
   /**
    * Builds the <code>Any</code> holding the propagation context to send.
    *
    * @return an <code>Any</code> containing the placeholder propagation
    *         context, or <code>null</code> when
    *         <code>TxUtils.isCompleted</code> returns <code>true</code>
    *         for the transaction manager (presumably meaning there is no
    *         live transaction to propagate - confirm against TxUtils)
    * @throws NestedRuntimeException wrapping any exception raised while
    *         building the <code>Any</code>
    */
   protected Any getTransactionPropagationContextAny()
   {
      try
      {
         // Local deliberately not named "tm" to avoid shadowing the static field
         TransactionManager manager = getTransactionManager();
         if (TxUtils.isCompleted(manager))
            return null;
         PropagationContext pc = getEmptyPropagationContext();
         Any any = CorbaORB.getInstance().create_any();
         PropagationContextHelper.insert(any, pc);
         return any;
      }
      catch (Exception e)
      {
         throw new NestedRuntimeException("Error getting tpc", e);
      }
   }
}