package org.jboss.ejb.plugins;
import org.jboss.invocation.Invocation;
import org.jboss.invocation.InvocationType;
import org.jboss.metadata.BeanMetaData;
import org.jboss.metadata.MetaData;
import org.jboss.metadata.XmlLoadable;
import org.jboss.tm.JBossTransactionRolledbackException;
import org.jboss.tm.JBossTransactionRolledbackLocalException;
import org.jboss.tm.TransactionTimeoutConfiguration;
import org.jboss.util.NestedException;
import org.jboss.util.deadlock.ApplicationDeadlockException;
import org.w3c.dom.Element;
import javax.ejb.EJBException;
import javax.ejb.TransactionRequiredLocalException;
import javax.transaction.HeuristicMixedException;
import javax.transaction.HeuristicRollbackException;
import javax.transaction.RollbackException;
import javax.transaction.Status;
import javax.transaction.SystemException;
import javax.transaction.Transaction;
import javax.transaction.TransactionRequiredException;
import javax.transaction.TransactionRolledbackException;
import java.lang.reflect.Method;
import java.rmi.RemoteException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Iterator;
import java.util.ArrayList;
/**
 * Container-managed transaction (CMT) interceptor.
 * <p>
 * For each home or component invocation it looks up the method's transaction
 * attribute in the bean metadata and demarcates transactions accordingly
 * (see {@link #runWithTransactions(Invocation)}). An invocation that arrives
 * without a caller transaction may be retried, up to {@link #MAX_RETRIES}
 * attempts in total, when it fails with a retryable
 * {@link ApplicationDeadlockException} or with an exception accepted by one of
 * the configured {@link TxRetryExceptionHandler}s.
 */
public class TxInterceptorCMT extends AbstractTxInterceptor implements XmlLoadable
{
   /**
    * Maximum number of attempts (the first call included) for an invocation
    * that arrived without a caller transaction. Values below 1 still allow
    * exactly one attempt.
    */
   public static int MAX_RETRIES = 5;

   /** Source of the randomized back-off delay between retry attempts. */
   public static Random random = new Random();

   /**
    * When {@code true}, a successful invocation whose transaction is no longer
    * active is turned into a transaction-rolled-back exception
    * (see {@link #checkTransactionStatus}). Resolved in {@link #create()}.
    */
   private boolean exceptionRollback = true;

   /**
    * Optional handlers, configured through {@link #importXml(Element)}, that
    * decide whether a non-deadlock exception may be retried; {@code null} if none.
    */
   private TxRetryExceptionHandler[] retryHandlers = null;

   /**
    * Searches an exception's wrapping chain for an {@link ApplicationDeadlockException}.
    * <p>
    * Unwraps {@link RemoteException#detail} and
    * {@link EJBException#getCausedByException()}; any other exception type ends
    * the search.
    *
    * @param t the exception to inspect, may be {@code null}
    * @return the deadlock exception found, or {@code null} if there is none
    */
   public static ApplicationDeadlockException isADE(Throwable t)
   {
      while (t != null)
      {
         if (t instanceof ApplicationDeadlockException)
         {
            return (ApplicationDeadlockException) t;
         }
         else if (t instanceof RemoteException)
         {
            t = ((RemoteException) t).detail;
         }
         else if (t instanceof EJBException)
         {
            t = ((EJBException) t).getCausedByException();
         }
         else
         {
            return null;
         }
      }
      return null;
   }

   /**
    * Reads the optional {@code retry-handlers} configuration.
    * <p>
    * Each {@code handler} child names a {@link TxRetryExceptionHandler} class.
    * The class is loaded through the thread context class loader and
    * instantiated with its no-arg constructor. If anything fails, a warning is
    * logged and the previously configured handlers (if any) are kept.
    *
    * @param ielement the interceptor's configuration element
    */
   public void importXml(Element ielement)
   {
      try
      {
         Element element = MetaData.getOptionalChild(ielement, "retry-handlers");
         if (element == null) return;
         ArrayList list = new ArrayList();
         Iterator handlers = MetaData.getChildrenByTagName(element, "handler");
         while (handlers.hasNext())
         {
            Element handler = (Element) handlers.next();
            String className = MetaData.getElementContent(handler).trim();
            Class clazz = SecurityActions.getContextClassLoader().loadClass(className);
            list.add(clazz.newInstance());
         }
         // Only publish the handlers once all of them were created successfully.
         retryHandlers = (TxRetryExceptionHandler[]) list.toArray(new TxRetryExceptionHandler[list.size()]);
      }
      catch (Exception ex)
      {
         log.warn("Unable to importXml for the TxInterceptorCMT", ex);
      }
   }

   /**
    * Resolves the {@code exceptionRollback} flag. The bean-level setting is used
    * when it is {@code true}; otherwise the application-level setting applies.
    *
    * @throws Exception if the superclass initialization fails
    */
   public void create() throws Exception
   {
      super.create();
      BeanMetaData bmd = getContainer().getBeanMetaData();
      exceptionRollback = bmd.getExceptionRollback();
      if (exceptionRollback == false)
         exceptionRollback = bmd.getApplicationMetaData().getExceptionRollback();
   }

   /**
    * Runs a home invocation under the method's transaction attribute, retrying
    * as described in {@link #checkRetryable}.
    *
    * @param invocation the home invocation
    * @return the result of the next interceptor
    * @throws Exception whatever the invocation or transaction completion throws
    *         once no further retry is allowed
    */
   public Object invokeHome(Invocation invocation) throws Exception
   {
      return invokeWithRetry(invocation);
   }

   /**
    * Runs a component invocation under the method's transaction attribute,
    * retrying as described in {@link #checkRetryable}.
    *
    * @param invocation the component invocation
    * @return the result of the next interceptor
    * @throws Exception whatever the invocation or transaction completion throws
    *         once no further retry is allowed
    */
   public Object invoke(Invocation invocation) throws Exception
   {
      return invokeWithRetry(invocation);
   }

   /**
    * Shared retry loop for {@link #invoke} and {@link #invokeHome}.
    * <p>
    * The loop has no explicit bound. {@link #checkRetryable} rethrows on the
    * last permitted attempt, so the loop always ends. Unlike a loop bounded by
    * {@code MAX_RETRIES}, it also guarantees at least one attempt when
    * {@code MAX_RETRIES} has been set to 0 or less.
    */
   private Object invokeWithRetry(Invocation invocation) throws Exception
   {
      Transaction oldTransaction = invocation.getTransaction();
      for (int attempt = 0; ; attempt++)
      {
         try
         {
            return runWithTransactions(invocation);
         }
         catch (Exception ex)
         {
            checkRetryable(attempt, ex, oldTransaction);
         }
      }
   }

   /**
    * Decides whether a failed attempt may be retried.
    * <p>
    * Returns normally, after a short randomized sleep, only when all of these
    * hold:
    * <ul>
    * <li>more attempts are left;</li>
    * <li>the invocation arrived without a caller transaction;</li>
    * <li>the exception is a retryable application deadlock, or is accepted by
    * one of the configured retry handlers.</li>
    * </ul>
    * Otherwise the exception is rethrown. For a non-retryable deadlock, the
    * unwrapped {@link ApplicationDeadlockException} is thrown instead.
    *
    * @param attempt zero-based index of the attempt that just failed
    * @param ex the exception thrown by that attempt
    * @param oldTransaction the transaction the invocation arrived with, or {@code null}
    * @throws Exception {@code ex} (or the unwrapped deadlock) when no retry is allowed
    */
   private void checkRetryable(int attempt, Exception ex, Transaction oldTransaction) throws Exception
   {
      // A caller-owned transaction cannot be retried here, and the last attempt is final.
      if (attempt + 1 >= MAX_RETRIES || oldTransaction != null) throw ex;
      ApplicationDeadlockException deadlock = isADE(ex);
      if (deadlock != null)
      {
         if (!deadlock.retryable()) throw deadlock;
         log.debug(deadlock.getMessage() + " retrying tx " + (attempt + 1));
      }
      else if (retryHandlers != null)
      {
         boolean retryable = false;
         for (int j = 0; j < retryHandlers.length; j++)
         {
            retryable = retryHandlers[j].retry(ex);
            if (retryable) break;
         }
         if (!retryable) throw ex;
         log.debug(ex.getMessage() + " retrying tx " + (attempt + 1));
      }
      else
      {
         throw ex;
      }
      // Randomized back-off: 0..attempt milliseconds plus 0..999 nanoseconds.
      Thread.sleep(random.nextInt(1 + attempt), random.nextInt(1000));
   }

   /**
    * Logs, at trace level, the transaction attribute chosen for a method.
    * For REQUIRED / REQUIRES_NEW methods the configured timeout is logged too.
    *
    * @param m the invoked method, may be {@code null}
    * @param type the {@code MetaData.TX_*} attribute
    */
   private void printMethod(Method m, byte type)
   {
      String txName;
      switch(type)
      {
         case MetaData.TX_MANDATORY:
            txName = "TX_MANDATORY";
            break;
         case MetaData.TX_NEVER:
            txName = "TX_NEVER";
            break;
         case MetaData.TX_NOT_SUPPORTED:
            txName = "TX_NOT_SUPPORTED";
            break;
         case MetaData.TX_REQUIRED:
            txName = "TX_REQUIRED";
            break;
         case MetaData.TX_REQUIRES_NEW:
            txName = "TX_REQUIRES_NEW";
            break;
         case MetaData.TX_SUPPORTS:
            txName = "TX_SUPPORTS";
            break;
         default:
            txName = "TX_UNKNOWN";
      }
      String methodName;
      if(m != null)
         methodName = m.getName();
      else
         methodName ="<no method>";
      if (log.isTraceEnabled())
      {
         if (m != null && (type == MetaData.TX_REQUIRED || type == MetaData.TX_REQUIRES_NEW))
            log.trace(txName + " for " + methodName + " timeout=" + container.getBeanMetaData().getTransactionTimeout(methodName));
         else
            log.trace(txName + " for " + methodName);
      }
   }

   /**
    * Performs one attempt of the invocation under the method's transaction
    * attribute.
    * <p>
    * The thread's current transaction is suspended on entry and resumed on
    * exit. In between, the thread is associated only with the transaction the
    * attribute calls for:
    * <ul>
    * <li>NOT_SUPPORTED: no transaction.</li>
    * <li>REQUIRED: the caller's transaction, or a new one if there is none.</li>
    * <li>SUPPORTS: the caller's transaction, if any.</li>
    * <li>REQUIRES_NEW: always a new transaction.</li>
    * <li>MANDATORY: the caller's transaction; fails if there is none.</li>
    * <li>NEVER: fails if the caller has a transaction.</li>
    * </ul>
    * Transactions started here are completed by {@link #endTransaction}.
    *
    * @param invocation the invocation to run
    * @return the result of the next interceptor, or {@code null} for an
    *         unknown transaction attribute (which is logged as an error)
    * @throws Exception from the invocation, or a transaction-required /
    *         transaction-rolled-back / {@link EJBException} as described above
    */
   private Object runWithTransactions(Invocation invocation) throws Exception
   {
      // Transaction propagated with the invocation (the caller's), may be null.
      Transaction oldTransaction = invocation.getTransaction();
      // Transaction started by this interceptor, if any.
      Transaction newTransaction = null;
      boolean trace = log.isTraceEnabled();
      if( trace )
         log.trace("Current transaction in MI is " + oldTransaction);
      InvocationType type = invocation.getType();
      byte transType = container.getBeanMetaData().getTransactionMethod(invocation.getMethod(), type);
      if ( trace )
         printMethod(invocation.getMethod(), transType);
      // Detach whatever the thread is associated with; re-attached in the outer finally.
      Transaction threadTx = tm.suspend();
      if( trace )
         log.trace("Thread came in with tx " + threadTx);
      try
      {
         switch (transType)
         {
            case MetaData.TX_NOT_SUPPORTED:
            {
               // Run with no transaction, neither on the thread nor in the invocation.
               try
               {
                  invocation.setTransaction(null);
                  return invokeNext(invocation, false);
               }
               finally
               {
                  invocation.setTransaction(oldTransaction);
               }
            }
            case MetaData.TX_REQUIRED:
            {
               // -1 is the "nothing to restore" sentinel used by start/endTransaction;
               // it is overwritten whenever a transaction is started here.
               int oldTimeout = -1;
               Transaction theTransaction = oldTransaction;
               if (oldTransaction == null)
               {
                  oldTimeout = startTransaction(invocation);
                  newTransaction = tm.getTransaction();
                  if( trace )
                     log.trace("Starting new tx " + newTransaction);
                  invocation.setTransaction(newTransaction);
                  theTransaction = newTransaction;
               }
               else
               {
                  tm.resume(oldTransaction);
               }
               try
               {
                  Object result = invokeNext(invocation, oldTransaction != null);
                  checkTransactionStatus(theTransaction, type);
                  return result;
               }
               finally
               {
                  if( trace )
                     log.trace("TxInterceptorCMT: In finally");
                  // Complete our own transaction; only detach the caller's.
                  if (newTransaction != null)
                     endTransaction(invocation, newTransaction, oldTransaction, oldTimeout);
                  else
                     tm.suspend();
               }
            }
            case MetaData.TX_SUPPORTS:
            {
               if (oldTransaction != null)
               {
                  tm.resume(oldTransaction);
               }
               try
               {
                  Object result = invokeNext(invocation, oldTransaction != null);
                  if (oldTransaction != null)
                     checkTransactionStatus(oldTransaction, type);
                  return result;
               }
               finally
               {
                  tm.suspend();
               }
            }
            case MetaData.TX_REQUIRES_NEW:
            {
               int oldTimeout = startTransaction(invocation);
               newTransaction = tm.getTransaction();
               invocation.setTransaction(newTransaction);
               try
               {
                  Object result = invokeNext(invocation, false);
                  checkTransactionStatus(newTransaction, type);
                  return result;
               }
               finally
               {
                  endTransaction(invocation, newTransaction, oldTransaction, oldTimeout);
               }
            }
            case MetaData.TX_MANDATORY:
            {
               if (oldTransaction == null)
               {
                  if (type == InvocationType.LOCAL ||
                        type == InvocationType.LOCALHOME)
                  {
                     throw new TransactionRequiredLocalException(
                           "Transaction Required");
                  }
                  else
                  {
                     throw new TransactionRequiredException(
                           "Transaction Required");
                  }
               }
               tm.resume(oldTransaction);
               try
               {
                  Object result = invokeNext(invocation, true);
                  checkTransactionStatus(oldTransaction, type);
                  return result;
               }
               finally
               {
                  tm.suspend();
               }
            }
            case MetaData.TX_NEVER:
            {
               if (oldTransaction != null)
               {
                  throw new EJBException("Transaction not allowed");
               }
               return invokeNext(invocation, false);
            }
            default:
                log.error("Unknown TX attribute "+transType+" for method "+invocation.getMethod());
         }
      }
      finally
      {
         // Restore the thread's original association.
         if (threadTx != null)
            tm.resume(threadTx);
      }
      return null;
   }

   /**
    * Begins a new transaction on the current thread, using the method's
    * configured timeout when the transaction manager supports reading the
    * current one.
    *
    * @param invocation the invocation whose method determines the timeout
    * @return the previous timeout to restore later, or -1 when the timeout was not changed
    * @throws Exception if the transaction cannot be started; the previous
    *         timeout has already been restored in that case
    */
   private int startTransaction(final Invocation invocation) throws Exception
   {
      int oldTimeout = -1;
      if (tm instanceof TransactionTimeoutConfiguration)
      {
         oldTimeout = ((TransactionTimeoutConfiguration) tm).getTransactionTimeout();
         int newTimeout = container.getBeanMetaData().getTransactionTimeout(invocation.getMethod());
         tm.setTransactionTimeout(newTimeout);
      }
      try
      {
         tm.begin();
      }
      catch (Exception e)
      {
         // endTransaction() will never run for this call. Restore the timeout here
         // so it does not leak into the next transaction started on this thread.
         if (oldTimeout != -1)
            tm.setTransactionTimeout(oldTimeout);
         throw e;
      }
      return oldTimeout;
   }

   /**
    * Completes a transaction started by this interceptor.
    * <p>
    * The transaction is rolled back if it is marked rollback-only and committed
    * otherwise. Afterwards the invocation's transaction is reset to
    * {@code oldTx}, the thread is disassociated, and the previous timeout is
    * restored.
    *
    * @param invocation the invocation being completed
    * @param tx the transaction started by this interceptor
    * @param oldTx the transaction the invocation arrived with, may be {@code null}
    * @param oldTimeout the timeout to restore, or -1 for none
    * @throws IllegalStateException if {@code tx} is {@code null} or is not the
    *         thread's current transaction; this is thrown before any cleanup
    * @throws TransactionRolledbackException if commit/rollback fails (a local
    *         variant is used for local invocations)
    * @throws SystemException if the transaction manager fails
    */
   private void endTransaction(final Invocation invocation, final Transaction tx, final Transaction oldTx, final int oldTimeout) 
      throws TransactionRolledbackException, SystemException
   {
      Transaction current = tm.getTransaction();
      // A null tx is always a wrong association. Checking it first also avoids
      // calling tx.equals() on null.
      if (tx == null || tx.equals(current) == false)
         throw new IllegalStateException("Wrong transaction association: expected " + tx + " was " + current);
      try
      {
         if (tx.getStatus() == Status.STATUS_MARKED_ROLLBACK)
         {
            tx.rollback();
         }
         else
         {
            tx.commit();
         }
      }
      catch (RollbackException e)
      {
         throwJBossException(e, invocation.getType());
      }
      catch (HeuristicMixedException e)
      {
         throwJBossException(e, invocation.getType());
      }
      catch (HeuristicRollbackException e)
      {
         throwJBossException(e, invocation.getType());
      }
      catch (SystemException e)
      {
         throwJBossException(e, invocation.getType());
      }
      finally
      {
         invocation.setTransaction(oldTx);
         tm.suspend();
         if (oldTimeout != -1)
            tm.setTransactionTimeout(oldTimeout);
      }
   }

   /**
    * Always throws a transaction-rolled-back exception wrapping {@code e}.
    * <p>
    * If {@code e} is a {@link NestedException} whose cause is an
    * {@link Exception}, that cause is wrapped instead.
    *
    * @param e the failure to report
    * @param type the invocation type; local types get the local exception variant
    * @throws TransactionRolledbackException for remote invocation types
    */
   protected void throwJBossException(Exception e, InvocationType type)
      throws TransactionRolledbackException
   {
      if (e instanceof NestedException)
      {
         NestedException rollback = (NestedException) e;
         if(rollback.getCause() instanceof Exception)
         {
            e = (Exception) rollback.getCause();
         }
      }
      if (type == InvocationType.LOCAL
          || type == InvocationType.LOCALHOME)
      {
         throw new JBossTransactionRolledbackLocalException(e);
      }
      else
      {
         throw new JBossTransactionRolledbackException(e);
      }
   }

   /**
    * Checks the transaction after an invocation that completed without throwing.
    * <p>
    * This only runs when {@code exceptionRollback} is enabled. If the status is
    * not {@code STATUS_ACTIVE} (including when it cannot be read at all), a
    * transaction-rolled-back exception is thrown.
    *
    * @param tx the transaction the invocation ran in
    * @param type the invocation type, used to pick the local/remote exception
    * @throws TransactionRolledbackException if the transaction is no longer active
    */
   protected void checkTransactionStatus(Transaction tx, InvocationType type)
      throws TransactionRolledbackException
   {
      if (exceptionRollback)
      {
         if (log.isTraceEnabled())
            log.trace("No exception from ejb, checking transaction status: " + tx);
         int status = Status.STATUS_UNKNOWN;
         try
         {
            status = tx.getStatus();
         }
         catch (Throwable t)
         {
            // Deliberately best-effort: an unreadable status is treated as "not active" below.
            log.debug("Ignored error trying to retrieve transaction status", t);
         }
         if (status != Status.STATUS_ACTIVE)
         {
            Exception e = new Exception("Transaction cannot be committed (probably transaction timeout): " + tx);
            throwJBossException(e, type);
         }
      }
   }

   /** No-op: this interceptor does not collect statistics. */
   public void sample(Object s)
   {
   }

   /**
    * This interceptor collects no statistics.
    *
    * @return always {@code null}
    */
   public Map retrieveStatistic()
   {
      return null;
   }

   /** No-op: there are no statistics to reset. */
   public void resetStatistic()
   {
   }
}