package org.jboss.ejb.plugins.inflow;
import java.lang.reflect.Method;
import javax.resource.ResourceException;
import javax.transaction.Status;
import javax.transaction.Transaction;
import javax.transaction.TransactionManager;
import javax.transaction.xa.XAResource;
import org.jboss.ejb.MessageDrivenContainer;
import org.jboss.invocation.Invocation;
import org.jboss.logging.Logger;
import org.jboss.proxy.Interceptor;
import EDU.oswego.cs.dl.util.concurrent.SynchronizedBoolean;
public class MessageEndpointInterceptor extends Interceptor
{
private static final Logger log = Logger.getLogger(MessageEndpointInterceptor.class);
public static final String MESSAGE_ENDPOINT_FACTORY = "MessageEndpoint.Factory";
public static final String MESSAGE_ENDPOINT_XARESOURCE = "MessageEndpoint.XAResource";
private boolean trace = log.isTraceEnabled();
private String cachedProxyString = null;
protected SynchronizedBoolean released = new SynchronizedBoolean(false);
protected boolean delivered = false;
protected Thread inUseThread = null;
protected ClassLoader oldClassLoader = null;
protected Transaction transaction = null;
protected Transaction suspended = null;
private JBossMessageEndpointFactory endpointFactory;
public MessageEndpointInterceptor()
{
}
public Object invoke(Invocation mi) throws Throwable
{
if (released.get())
throw new IllegalStateException("This message endpoint + " + getProxyString(mi) + " has been released");
Thread currentThread = Thread.currentThread();
if (inUseThread != null && inUseThread.equals(currentThread) == false)
throw new IllegalStateException("This message endpoint + " + getProxyString(mi) + " is already in use by another thread " + inUseThread);
inUseThread = currentThread;
String method = mi.getMethod().getName();
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " in use by " + method + " " + inUseThread);
if (method.equals("release"))
{
release(mi);
return null;
}
else if (method.equals("beforeDelivery"))
{
before(mi);
return null;
}
else if (method.equals("afterDelivery"))
{
after(mi);
return null;
}
else
return delivery(mi);
}
protected void release(Invocation mi) throws Throwable
{
released.set(true);
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " released");
if (oldClassLoader != null)
{
try
{
finish("release", mi, false);
}
catch (Throwable t)
{
log.warn("Error in release ", t);
}
}
}
protected void before(Invocation mi) throws Throwable
{
if (oldClassLoader != null)
throw new IllegalStateException("Missing afterDelivery from the previous beforeDelivery for message endpoint " + getProxyString(mi));
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " released");
MessageDrivenContainer container = getContainer(mi);
oldClassLoader = GetTCLAction.getContextClassLoader(inUseThread);
SetTCLAction.setContextClassLoader(inUseThread, container.getClassLoader());
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " set context classloader to " + container.getClassLoader());
try
{
startTransaction("beforeDelivery", mi, container);
}
catch (Throwable t)
{
resetContextClassLoader(mi);
throw new ResourceException(t);
}
}
protected void after(Invocation mi) throws Throwable
{
if (oldClassLoader == null)
throw new IllegalStateException("afterDelivery without a previous beforeDelivery for message endpoint " + getProxyString(mi));
try
{
finish("afterDelivery", mi, true);
}
catch (Throwable t)
{
throw new ResourceException(t);
}
}
protected Object delivery(Invocation mi) throws Throwable
{
if (delivered)
throw new IllegalStateException("Multiple message delivery between before and after delivery is not allowed for message endpoint " + getProxyString(mi));
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " delivering");
if (oldClassLoader != null)
delivered = true;
MessageDrivenContainer container = getContainer(mi);
boolean commit = true;
try
{
if (oldClassLoader == null)
startTransaction("delivery", mi, container);
return getNext().invoke(mi);
}
catch (Throwable t)
{
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " delivery error", t);
if (t instanceof Error || t instanceof RuntimeException)
{
if (transaction != null)
transaction.setRollbackOnly();
commit = false;
}
throw t;
}
finally
{
if (oldClassLoader == null)
{
try
{
endTransaction(mi, commit);
}
finally
{
releaseThreadLock(mi);
}
}
}
}
protected void finish(String context, Invocation mi, boolean commit) throws Throwable
{
try
{
endTransaction(mi, commit);
}
finally
{
delivered = false;
resetContextClassLoader(mi);
releaseThreadLock(mi);
}
}
protected void startTransaction(String context, Invocation mi, MessageDrivenContainer container) throws Throwable
{
XAResource resource = (XAResource) mi.getInvocationContext().getValue(MESSAGE_ENDPOINT_XARESOURCE);
Method method = null;
if ("delivery".equals(context))
method = mi.getMethod();
else
method = (Method) mi.getArguments()[0];
boolean isTransacted = getMessageEndpointFactory(mi).isDeliveryTransacted(method);
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " " + context + " method=" + method + " xaResource=" + resource + " transacted=" + isTransacted);
TransactionManager tm = container.getTransactionManager();
suspended = tm.suspend();
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " " + context + " currentTx=" + suspended);
if (isTransacted)
{
if (suspended == null)
{
tm.begin();
transaction = tm.getTransaction();
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " started transaction=" + transaction);
if (resource != null)
{
transaction.enlistResource(resource);
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " enlisted=" + resource);
}
}
else
{
try
{
tm.resume(suspended);
}
finally
{
suspended = null;
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " transaction=" + suspended + " already active, IGNORED=" + resource);
}
}
}
}
protected void endTransaction(Invocation mi, boolean commit) throws Throwable
{
TransactionManager tm = null;
Transaction currentTx = null;
try
{
if (transaction != null)
{
tm = getContainer(mi).getTransactionManager();
currentTx = tm.getTransaction();
if (currentTx != null && currentTx.equals(transaction) == false)
{
log.warn("Current transaction " + currentTx + " is not the expected transaction.");
tm.suspend();
tm.resume(transaction);
}
else
{
currentTx = null;
}
if (commit == false || transaction.getStatus() == Status.STATUS_MARKED_ROLLBACK)
{
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " rollback");
tm.rollback();
}
else
{
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " commit");
tm.commit();
}
}
if (suspended != null)
{
try
{
tm = getContainer(mi).getTransactionManager();
tm.resume(suspended);
}
finally
{
suspended = null;
}
}
}
finally
{
if (currentTx != null)
{
try
{
tm.resume(currentTx);
}
catch (Throwable t)
{
log.warn("MessageEndpoint " + getProxyString(mi) + " failed to resume old transaction " + currentTx);
}
}
}
}
protected void resetContextClassLoader(Invocation mi)
{
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " reset classloader " + oldClassLoader);
SetTCLAction.setContextClassLoader(inUseThread, oldClassLoader);
oldClassLoader = null;
}
protected void releaseThreadLock(Invocation mi)
{
if (trace)
log.trace("MessageEndpoint " + getProxyString(mi) + " no longer in use by " + inUseThread);
inUseThread = null;
}
protected String getProxyString(Invocation mi)
{
if (cachedProxyString == null)
cachedProxyString = mi.getInvocationContext().getCacheId().toString();
return cachedProxyString;
}
protected JBossMessageEndpointFactory getMessageEndpointFactory(Invocation mi)
{
if (endpointFactory == null)
endpointFactory = (JBossMessageEndpointFactory) mi.getInvocationContext().getValue(MESSAGE_ENDPOINT_FACTORY);
return endpointFactory;
}
protected MessageDrivenContainer getContainer(Invocation mi)
{
return getMessageEndpointFactory(mi).getContainer();
}
}