package org.jboss.ejb.plugins;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.rmi.RemoteException;
import java.rmi.ServerException;
import java.rmi.NoSuchObjectException;
import java.lang.reflect.Method;
import javax.transaction.TransactionManager;
import javax.transaction.TransactionRolledbackException;
import javax.transaction.SystemException;
import javax.transaction.Transaction;
import javax.transaction.Synchronization;
import javax.transaction.RollbackException;
import javax.ejb.EJBException;
import javax.ejb.NoSuchEntityException;
import javax.ejb.NoSuchObjectLocalException;
import javax.ejb.TransactionRolledbackLocalException;
import javax.ejb.TimedObject;
import javax.ejb.Timer;
import org.jboss.invocation.Invocation;
import org.jboss.invocation.InvocationType;
abstract class AbstractTxInterceptor
extends AbstractInterceptor
{
protected static final Method ejbTimeout;
static
{
try
{
ejbTimeout = TimedObject.class.getMethod("ejbTimeout", new Class[]{Timer.class});
}
catch (Exception e)
{
throw new ExceptionInInitializerError(e);
}
}
protected TransactionManager tm;
public void create() throws Exception
{
super.create();
tm = getContainer().getTransactionManager();
}
protected Object invokeNext(Invocation invocation, boolean inheritedTx)
throws Exception
{
InvocationType type = invocation.getType();
try
{
if (type == InvocationType.REMOTE || type == InvocationType.LOCAL || type == InvocationType.SERVICE_ENDPOINT)
{
if (ejbTimeout.equals(invocation.getMethod()))
registerTimer(invocation);
return getNext().invoke(invocation);
}
else
{
return getNext().invokeHome(invocation);
}
}
catch (Throwable e)
{
if (e instanceof Exception &&
!(e instanceof RuntimeException || e instanceof RemoteException))
{
throw (Exception) e;
}
Transaction tx = invocation.getTransaction();
if (tx != null)
{
try
{
tx.setRollbackOnly();
}
catch (SystemException ex)
{
log.error("SystemException while setting transaction " +
"for rollback only", ex);
}
catch (IllegalStateException ex)
{
log.error("IllegalStateException while setting transaction " +
"for rollback only", ex);
}
}
boolean isLocal =
type == InvocationType.LOCAL ||
type == InvocationType.LOCALHOME;
if (!inheritedTx)
{
if (e instanceof Exception)
{
throw (Exception) e;
}
if (e instanceof Error)
{
throw (Error) e;
}
if (isLocal)
{
String msg = formatException("Unexpected Throwable", e);
throw new EJBException(msg);
}
else
{
ServerException ex = new ServerException("Unexpected Throwable");
ex.detail = e;
throw ex;
}
}
Throwable cause;
if (e instanceof NoSuchEntityException)
{
NoSuchEntityException nsee = (NoSuchEntityException) e;
if (isLocal)
{
cause = new NoSuchObjectLocalException(nsee.getMessage(),
nsee.getCausedByException());
}
else
{
cause = new NoSuchObjectException(nsee.getMessage());
((NoSuchObjectException) cause).detail =
nsee.getCausedByException();
}
}
else
{
if (isLocal)
{
if (e instanceof Exception)
{
cause = e;
}
else if (e instanceof Error)
{
String msg = formatException("Unexpected Error", e);
cause = new EJBException(msg);
}
else
{
String msg = formatException("Unexpected Throwable", e);
cause = new EJBException(msg);
}
}
else
{
cause = e;
}
}
if (isLocal)
{
if (cause instanceof TransactionRolledbackLocalException)
{
throw (TransactionRolledbackLocalException) cause;
}
else
{
throw new TransactionRolledbackLocalException(cause.getMessage(),
(Exception) cause);
}
}
else
{
if (cause instanceof TransactionRolledbackException)
{
throw (TransactionRolledbackException) cause;
}
else
{
TransactionRolledbackException ex =
new TransactionRolledbackException(cause.getMessage());
ex.detail = cause;
throw ex;
}
}
}
}
private void registerTimer(Invocation invocation)
throws RollbackException, SystemException
{
Timer timer = (Timer) invocation.getArguments()[0];
Transaction transaction = invocation.getTransaction();
if (transaction != null && timer instanceof Synchronization)
transaction.registerSynchronization((Synchronization) timer);
}
private String formatException(String msg, Throwable t)
{
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
if (msg != null)
{
pw.println(msg);
}
t.printStackTrace(pw);
return sw.toString();
}
}