package org.jboss.ejb.plugins.local;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.lang.reflect.Constructor;
import java.rmi.AccessException;
import java.rmi.NoSuchObjectException;
import java.security.Principal;
import java.security.PrivilegedAction;
import java.security.AccessController;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import javax.ejb.AccessLocalException;
import javax.ejb.EJBLocalHome;
import javax.ejb.EJBLocalObject;
import javax.ejb.NoSuchObjectLocalException;
import javax.ejb.TransactionRequiredLocalException;
import javax.ejb.TransactionRolledbackLocalException;
import javax.naming.Context;
import javax.naming.InitialContext;
import javax.transaction.Transaction;
import javax.transaction.TransactionManager;
import javax.transaction.TransactionRequiredException;
import javax.transaction.TransactionRolledbackException;
import org.jboss.ejb.Container;
import org.jboss.ejb.EJBProxyFactoryContainer;
import org.jboss.ejb.LocalProxyFactory;
import org.jboss.invocation.InvocationType;
import org.jboss.invocation.MarshalledInvocation;
import org.jboss.invocation.LocalEJBInvocation;
import org.jboss.logging.Logger;
import org.jboss.metadata.BeanMetaData;
import org.jboss.naming.Util;
import org.jboss.security.SecurityAssociation;
import org.jboss.util.NestedRuntimeException;
import org.jboss.tm.TransactionLocal;
public class BaseLocalProxyFactory implements LocalProxyFactory
{
/** Logger obtained for this class. */
protected static Logger log = Logger.getLogger(BaseLocalProxyFactory.class);
/**
 * Synchronized registry of started factories. {@link #start()} stores this
 * factory under its local JNDI name and {@link #stop()} removes that entry.
 */
protected static Map invokerMap = Collections.synchronizedMap(new HashMap());
/** The container served by this factory; set by {@link #setContainer} and cleared by {@link #destroy()}. */
protected Container container;
/** Local JNDI name read from the bean metadata in {@link #create()}. */
protected String localJndiName;
/** Transaction manager looked up under "java:/TransactionManager" in {@link #start()}; null before that. */
protected TransactionManager transactionManager;
/** Local home proxy, created on first call to {@link #getEJBLocalHome()}. */
protected EJBLocalHome home;
/** Shared local object proxy, created on first call to {@link #getStatelessSessionEJBLocalObject()}. */
protected EJBLocalObject statelessObject;
/** Maps a {@code Long} method hash to the corresponding local interface {@code Method}; built in {@link #start()}. */
protected Map beanMethodInvokerMap;
/** Maps a {@code Long} method hash to the corresponding local home interface {@code Method}; built in {@link #start()}. */
protected Map homeMethodInvokerMap;
/** Local home interface class obtained from the container in {@link #start()}. */
protected Class localHomeClass;
/** Local interface class obtained from the container in {@link #start()}. */
protected Class localClass;
/** Constructor of the dynamic proxy class for {@link #localClass}, taking a single {@code InvocationHandler}. */
protected Constructor proxyClassConstructor;
/**
 * Cache consulted by {@link #getEntityEJBLocalObject(Object, boolean)} to
 * store entity proxies keyed by id. Its initial value is a fresh
 * {@code HashMap}.
 * NOTE(review): presumably TransactionLocal keeps one value per
 * transaction, giving one proxy map per transaction - confirm against
 * org.jboss.tm.TransactionLocal.
 */
private final TransactionLocal cache = new TransactionLocal()
{
// Supplies the empty id -> proxy map used when no value has been stored yet.
protected Object initialValue()
{
return new HashMap();
}
};
/**
 * Associates this proxy factory with the container it serves.
 *
 * @param con the container whose metadata, class loader and invocation
 *            entry point are used by the other methods of this factory
 */
public void setContainer(Container con)
{
   container = con;
}
/**
 * Reads the local JNDI name from the container's bean metadata and keeps
 * it for later binding.
 *
 * @throws Exception declared for the lifecycle contract
 */
public void create() throws Exception
{
   localJndiName = container.getBeanMetaData().getLocalJndiName();
}
/**
 * Starts the factory: resolves the local interfaces, prepares the dynamic
 * proxy constructor, looks up the transaction manager, builds the
 * hash-to-method maps and binds the local home proxy into JNDI.
 * <p>
 * If the container exposes no local home or no local interface, a debug
 * message is logged and nothing is bound.
 *
 * @throws Exception if proxy class creation, a JNDI operation or method
 *                   hashing fails
 */
public void start()
   throws Exception
{
   final BeanMetaData beanMetaData = container.getBeanMetaData();
   final EJBProxyFactoryContainer proxyContainer = (EJBProxyFactoryContainer) container;

   localHomeClass = proxyContainer.getLocalHomeClass();
   localClass = proxyContainer.getLocalClass();

   // Both local interfaces are required; otherwise there is nothing to bind.
   if(localHomeClass == null || localClass == null)
   {
      log.debug(beanMetaData.getEjbName()
         + " cannot be Bound, doesn't have local and local home interfaces");
      return;
   }

   // Resolve the proxy class once so per-instance proxies only need a constructor call.
   final Class[] proxiedInterfaces = new Class[]{localClass};
   final Class generatedProxyClass =
      Proxy.getProxyClass(ClassLoaderAction.UTIL.get(localClass), proxiedInterfaces);
   final Class[] handlerSignature = new Class[]{InvocationHandler.class};
   proxyClassConstructor = generatedProxyClass.getConstructor(handlerSignature);

   final Context namingContext = new InitialContext();
   final String ejbName = beanMetaData.getEjbName();

   transactionManager =
      (TransactionManager) namingContext.lookup("java:/TransactionManager");

   // Index the local interface methods by their marshalled-invocation hash.
   final Method[] localMethods = localClass.getMethods();
   beanMethodInvokerMap = new HashMap();
   for(int b = 0; b < localMethods.length; b++)
   {
      final Method localMethod = localMethods[b];
      final long localHash = MarshalledInvocation.calculateHash(localMethod);
      beanMethodInvokerMap.put(new Long(localHash), localMethod);
   }

   // Same index for the local home interface methods.
   final Method[] localHomeMethods = localHomeClass.getMethods();
   homeMethodInvokerMap = new HashMap();
   for(int h = 0; h < localHomeMethods.length; h++)
   {
      final Method homeMethod = localHomeMethods[h];
      final long homeHash = MarshalledInvocation.calculateHash(homeMethod);
      homeMethodInvokerMap.put(new Long(homeHash), homeMethod);
   }

   // Publish the home proxy in JNDI, then register this factory under the same name.
   Util.rebind(namingContext, localJndiName, getEJBLocalHome());
   invokerMap.put(localJndiName, this);

   log.debug("Bound EJBLocalHome of " + ejbName + " to " + localJndiName);
}
/**
 * Stops the factory: if this instance is the one registered in
 * {@link #invokerMap} under its local JNDI name, the registration is
 * removed and the name is unbound from JNDI.
 * <p>
 * A registration that belongs to a different factory instance is left
 * untouched. Failures are not propagated (stopping is best effort) but are
 * logged at debug level so they remain diagnosable.
 */
public void stop()
{
   try
   {
      boolean registered;
      // Collections.synchronizedMap uses the returned map as its mutex, so
      // holding it makes the ownership check and the removal atomic.
      synchronized(invokerMap)
      {
         registered = invokerMap.get(localJndiName) == this;
         if(registered)
         {
            invokerMap.remove(localJndiName);
         }
      }
      if(registered)
      {
         InitialContext ctx = new InitialContext();
         ctx.unbind(localJndiName);
      }
   }
   catch(Exception e)
   {
      // Deliberately best effort: report the problem without failing the stop.
      log.debug("Failed to unbind EJBLocalHome from " + localJndiName, e);
   }
}
/**
 * Releases what {@link #start()} built: empties both hash-to-method maps
 * (when they exist), drops the cached method hashes for the local home and
 * local interface classes, and forgets the container.
 */
public void destroy()
{
   final Map beanMethods = beanMethodInvokerMap;
   if(beanMethods != null)
   {
      beanMethods.clear();
   }

   final Map homeMethods = homeMethodInvokerMap;
   if(homeMethods != null)
   {
      homeMethods.clear();
   }

   MarshalledInvocation.removeHashes(localHomeClass);
   MarshalledInvocation.removeHashes(localClass);

   container = null;
}
/**
 * Returns the constructor of the dynamic proxy class for the local
 * interface.
 *
 * @return the constructor prepared by {@link #start()}, or null if it has
 *         not been created
 */
public Constructor getProxyClassConstructor()
{
   return proxyClassConstructor;
}
/**
 * Returns the local home proxy, creating and caching it on first use.
 * The proxy implements the container's local home interface and delegates
 * to a {@code LocalHomeProxy} handler bound to this factory.
 *
 * @return the cached local home proxy
 */
public EJBLocalHome getEJBLocalHome()
{
   if(home != null)
   {
      return home;
   }

   final EJBProxyFactoryContainer proxyContainer = (EJBProxyFactoryContainer) container;
   final InvocationHandler homeHandler = new LocalHomeProxy(localJndiName, this);
   final ClassLoader homeLoader = ClassLoaderAction.UTIL.get(proxyContainer.getLocalHomeClass());
   final Class[] homeInterfaces = new Class[]{proxyContainer.getLocalHomeClass()};

   home = (EJBLocalHome) Proxy.newProxyInstance(homeLoader, homeInterfaces, homeHandler);
   return home;
}
/**
 * Returns the local object proxy for a stateless session bean, creating
 * and caching it on first use. The proxy implements the container's local
 * interface and delegates to a {@code StatelessSessionProxy} handler.
 *
 * @return the cached stateless local object proxy
 */
public EJBLocalObject getStatelessSessionEJBLocalObject()
{
   if(statelessObject != null)
   {
      return statelessObject;
   }

   final EJBProxyFactoryContainer proxyContainer = (EJBProxyFactoryContainer) container;
   final InvocationHandler sessionHandler = new StatelessSessionProxy(localJndiName, this);
   final ClassLoader localLoader = ClassLoaderAction.UTIL.get(proxyContainer.getLocalClass());
   final Class[] localInterfaces = new Class[]{proxyContainer.getLocalClass()};

   statelessObject =
      (EJBLocalObject) Proxy.newProxyInstance(localLoader, localInterfaces, sessionHandler);
   return statelessObject;
}
/**
 * Creates a new local object proxy for the stateful session identified by
 * the given id. No caching is performed; every call yields a new proxy.
 *
 * @param id the session identity handed to the {@code StatefulSessionProxy}
 * @return a new proxy instance built with the prepared proxy constructor
 * @throws NestedRuntimeException wrapping any exception raised while
 *         instantiating the proxy
 */
public EJBLocalObject getStatefulSessionEJBLocalObject(Object id)
{
   final InvocationHandler sessionHandler =
      new StatefulSessionProxy(localJndiName, id, this);
   final Object[] constructorArgs = new Object[]{sessionHandler};
   try
   {
      final Object proxy = proxyClassConstructor.newInstance(constructorArgs);
      return (EJBLocalObject) proxy;
   }
   catch(Exception failure)
   {
      throw new NestedRuntimeException(failure);
   }
}
/**
 * Returns the entity proxy for the given id, typed as a plain object.
 * Delegates to {@link #getEntityEJBLocalObject(Object)}.
 *
 * @param id the entity identity; may be null
 * @return the local object proxy, or null when id is null
 */
public Object getEntityEJBObject(Object id)
{
   final EJBLocalObject localObject = getEntityEJBLocalObject(id);
   return localObject;
}
/**
 * Returns a local object proxy for the entity with the given id.
 * <ul>
 * <li>A null id yields null.</li>
 * <li>When the cache reports no transaction, a new proxy is created and
 *     not cached.</li>
 * <li>Otherwise the map held by the cache for that transaction is used:
 *     with {@code create} set a new proxy always replaces the cached one,
 *     without it a cached proxy is reused and only a miss creates (and
 *     stores) a new one.</li>
 * </ul>
 *
 * @param id     the entity identity; may be null
 * @param create true to force creation of a new proxy even if one is cached
 * @return the proxy, or null when id is null
 */
public EJBLocalObject getEntityEJBLocalObject(Object id, boolean create)
{
   if(id == null)
   {
      return null;
   }

   final Transaction currentTx = cache.getTransaction();
   if(currentTx == null)
   {
      // No transaction to scope the cache to: hand out an uncached proxy.
      return createEJBLocalObject(id);
   }

   final Map proxiesById = (Map) cache.get(currentTx);

   // A forced create skips the lookup entirely; otherwise try the cache first.
   EJBLocalObject proxy = create ? null : (EJBLocalObject) proxiesById.get(id);
   if(proxy == null)
   {
      proxy = createEJBLocalObject(id);
      proxiesById.put(id, proxy);
   }
   return proxy;
}
/**
 * Returns the entity proxy for the given id, reusing a cached proxy when
 * one is available (equivalent to passing {@code false} for the create
 * flag of the two-argument overload).
 *
 * @param id the entity identity; may be null
 * @return the proxy, or null when id is null
 */
public EJBLocalObject getEntityEJBLocalObject(Object id)
{
   final boolean forceCreate = false;
   return getEntityEJBLocalObject(id, forceCreate);
}
/**
 * Converts a collection of entity ids into a list of local object proxies,
 * preserving the iteration order of the input.
 *
 * @param ids the entity identities; must not be null
 * @return a new list holding the result of
 *         {@link #getEntityEJBLocalObject(Object)} for each id
 */
public Collection getEntityLocalCollection(Collection ids)
{
   final ArrayList proxies = new ArrayList(ids.size());
   for(Iterator idIterator = ids.iterator(); idIterator.hasNext();)
   {
      proxies.add(getEntityEJBLocalObject(idIterator.next()));
   }
   return proxies;
}
/**
 * Invokes a local home method through the container.
 * <p>
 * The container's class loader is installed as the thread context class
 * loader for the duration of the call when it differs from the current
 * one, and the previous loader is always restored afterwards. The
 * invocation carries the current transaction plus the caller's principal
 * and credential, and is typed {@code InvocationType.LOCALHOME}.
 * <p>
 * Remote-style exceptions are translated to their local counterparts:
 * {@code AccessException}, {@code NoSuchObjectException},
 * {@code TransactionRequiredException} and
 * {@code TransactionRolledbackException} become the matching
 * {@code javax.ejb.*LocalException}, each keeping the original as its cause.
 *
 * @param m    the local home interface method to invoke
 * @param args the method arguments
 * @return whatever the container returns for the invocation
 * @throws Exception any other exception raised by the container
 */
public Object invokeHome(Method m, Object[] args) throws Exception
{
   final ClassLoader oldCl = TCLAction.UTIL.getContextClassLoader();
   final ClassLoader containerCl = container.getClassLoader();
   // The thread context class loader may legitimately be null, so the
   // comparison must not dereference it unconditionally.
   final boolean setCl = oldCl == null ? containerCl != null : !oldCl.equals(containerCl);
   if(setCl)
   {
      TCLAction.UTIL.setContextClassLoader(containerCl);
   }
   try
   {
      // Inside the try so the class loader is restored even if this lookup fails.
      SecurityActions sa = SecurityActions.UTIL.getSecurityActions();
      LocalEJBInvocation invocation = new LocalEJBInvocation(null,
         m,
         args,
         getTransaction(),
         sa.getPrincipal(),
         sa.getCredential());
      invocation.setType(InvocationType.LOCALHOME);
      return container.invoke(invocation);
   }
   catch(AccessException ae)
   {
      throw new AccessLocalException(ae.getMessage(), ae);
   }
   catch(NoSuchObjectException nsoe)
   {
      throw new NoSuchObjectLocalException(nsoe.getMessage(), nsoe);
   }
   catch(TransactionRequiredException tre)
   {
      // This exception type has no (message, cause) constructor, so the
      // cause is attached explicitly instead of being dropped.
      TransactionRequiredLocalException local =
         new TransactionRequiredLocalException(tre.getMessage());
      local.initCause(tre);
      throw local;
   }
   catch(TransactionRolledbackException trbe)
   {
      throw new TransactionRolledbackLocalException(trbe.getMessage(), trbe);
   }
   finally
   {
      if(setCl)
      {
         TCLAction.UTIL.setContextClassLoader(oldCl);
      }
   }
}
/**
 * Returns the local JNDI name this factory binds its home proxy under.
 *
 * @return the name read in {@link #create()}, or null before that
 */
public String getJndiName()
{
   return this.localJndiName;
}
/**
 * Asks the transaction manager for its current transaction.
 *
 * @return the transaction reported by the transaction manager, or null
 *         when no transaction manager has been looked up yet
 * @throws javax.transaction.SystemException if the transaction manager
 *         fails to report the transaction
 */
Transaction getTransaction() throws javax.transaction.SystemException
{
   return transactionManager == null
      ? null
      : transactionManager.getTransaction();
}
/**
 * Invokes a local interface method on the bean identified by the given id.
 * <p>
 * The container's class loader is installed as the thread context class
 * loader for the duration of the call when it differs from the current
 * one, and the previous loader is always restored afterwards. The
 * invocation carries the current transaction plus the caller's principal
 * and credential, and is typed {@code InvocationType.LOCAL}.
 * <p>
 * Remote-style exceptions are translated to their local counterparts:
 * {@code AccessException}, {@code NoSuchObjectException},
 * {@code TransactionRequiredException} and
 * {@code TransactionRolledbackException} become the matching
 * {@code javax.ejb.*LocalException}, each keeping the original as its cause.
 *
 * @param id   the bean identity passed to the invocation
 * @param m    the local interface method to invoke
 * @param args the method arguments
 * @return whatever the container returns for the invocation
 * @throws Exception any other exception raised by the container
 */
public Object invoke(Object id, Method m, Object[] args)
   throws Exception
{
   final ClassLoader oldCl = TCLAction.UTIL.getContextClassLoader();
   final ClassLoader containerCl = container.getClassLoader();
   // The thread context class loader may legitimately be null, so the
   // comparison must not dereference it unconditionally.
   final boolean setCl = oldCl == null ? containerCl != null : !oldCl.equals(containerCl);
   if(setCl)
   {
      TCLAction.UTIL.setContextClassLoader(containerCl);
   }
   try
   {
      // Inside the try so the class loader is restored even if this lookup fails.
      SecurityActions sa = SecurityActions.UTIL.getSecurityActions();
      LocalEJBInvocation invocation = new LocalEJBInvocation(id,
         m,
         args,
         getTransaction(),
         sa.getPrincipal(),
         sa.getCredential());
      invocation.setType(InvocationType.LOCAL);
      return container.invoke(invocation);
   }
   catch(AccessException ae)
   {
      throw new AccessLocalException(ae.getMessage(), ae);
   }
   catch(NoSuchObjectException nsoe)
   {
      throw new NoSuchObjectLocalException(nsoe.getMessage(), nsoe);
   }
   catch(TransactionRequiredException tre)
   {
      // This exception type has no (message, cause) constructor, so the
      // cause is attached explicitly instead of being dropped.
      TransactionRequiredLocalException local =
         new TransactionRequiredLocalException(tre.getMessage());
      local.initCause(tre);
      throw local;
   }
   catch(TransactionRolledbackException trbe)
   {
      throw new TransactionRolledbackLocalException(trbe.getMessage(), trbe);
   }
   finally
   {
      if(setCl)
      {
         TCLAction.UTIL.setContextClassLoader(oldCl);
      }
   }
}
/**
 * Builds a new entity proxy for the given id using the prepared proxy
 * constructor and an {@code EntityProxy} invocation handler.
 *
 * @param id the entity identity handed to the handler
 * @return a new local object proxy
 * @throws NestedRuntimeException wrapping any exception raised while
 *         instantiating the proxy
 */
private EJBLocalObject createEJBLocalObject(Object id)
{
   final InvocationHandler entityHandler = new EntityProxy(localJndiName, id, this);
   final Object[] constructorArgs = new Object[]{entityHandler};
   try
   {
      final Object proxy = proxyClassConstructor.newInstance(constructorArgs);
      return (EJBLocalObject) proxy;
   }
   catch(Exception failure)
   {
      throw new NestedRuntimeException(failure);
   }
}
/**
 * Strategy for reading a class's defining class loader, with one variant
 * that performs the read directly and one that wraps it in
 * {@code AccessController.doPrivileged}.
 */
interface ClassLoaderAction
{
   /**
    * Returns the class loader of the given class.
    *
    * @param clazz the class to inspect
    * @return the result of {@code clazz.getClassLoader()}
    */
   ClassLoader get(Class clazz);

   /** Reads the class loader directly. */
   ClassLoaderAction NON_PRIVILEGED = new ClassLoaderAction()
   {
      public ClassLoader get(Class clazz)
      {
         return clazz.getClassLoader();
      }
   };

   /** Reads the class loader inside a privileged action. */
   ClassLoaderAction PRIVILEGED = new ClassLoaderAction()
   {
      public ClassLoader get(final Class clazz)
      {
         final PrivilegedAction lookup = new PrivilegedAction()
         {
            public Object run()
            {
               return clazz.getClassLoader();
            }
         };
         return (ClassLoader) AccessController.doPrivileged(lookup);
      }
   };

   /** Static entry points that select the variant on each call. */
   class UTIL
   {
      /**
       * Picks the privileged variant when a security manager is installed
       * and the direct variant otherwise.
       */
      static ClassLoaderAction getClassLoaderAction()
      {
         if(System.getSecurityManager() == null)
         {
            return NON_PRIVILEGED;
         }
         return PRIVILEGED;
      }

      /** Returns the class loader of the given class via the selected variant. */
      static ClassLoader get(Class clazz)
      {
         final ClassLoaderAction action = getClassLoaderAction();
         return action.get(clazz);
      }
   }
}
/**
 * Strategy for reading the caller's principal and credential from
 * {@code SecurityAssociation}, with one variant that reads them directly
 * and one that wraps each read in {@code AccessController.doPrivileged}.
 */
interface SecurityActions
{
   /** Returns {@code SecurityAssociation.getPrincipal()}. */
   Principal getPrincipal();

   /** Returns {@code SecurityAssociation.getCredential()}. */
   Object getCredential();

   /** Reads both values inside privileged actions; the actions are created once and reused. */
   SecurityActions PRIVILEGED = new SecurityActions()
   {
      private final PrivilegedAction principalReader = new PrivilegedAction()
      {
         public Object run()
         {
            return SecurityAssociation.getPrincipal();
         }
      };

      private final PrivilegedAction credentialReader = new PrivilegedAction()
      {
         public Object run()
         {
            return SecurityAssociation.getCredential();
         }
      };

      public Principal getPrincipal()
      {
         final Object principal = AccessController.doPrivileged(principalReader);
         return (Principal) principal;
      }

      public Object getCredential()
      {
         return AccessController.doPrivileged(credentialReader);
      }
   };

   /** Reads both values directly. */
   SecurityActions NON_PRIVILEGED = new SecurityActions()
   {
      public Principal getPrincipal()
      {
         return SecurityAssociation.getPrincipal();
      }

      public Object getCredential()
      {
         return SecurityAssociation.getCredential();
      }
   };

   /** Static entry point that selects the variant on each call. */
   class UTIL
   {
      /**
       * Picks the privileged variant when a security manager is installed
       * and the direct variant otherwise.
       */
      static SecurityActions getSecurityActions()
      {
         if(System.getSecurityManager() == null)
         {
            return NON_PRIVILEGED;
         }
         return PRIVILEGED;
      }
   }
}
/**
 * Strategy for reading and replacing thread context class loaders, with
 * one variant that touches the thread directly and one that wraps every
 * access in {@code AccessController.doPrivileged}.
 */
interface TCLAction
{
   /** Returns the context class loader of the current thread. */
   ClassLoader getContextClassLoader();

   /** Returns the context class loader of the given thread. */
   ClassLoader getContextClassLoader(Thread thread);

   /** Sets the context class loader of the current thread. */
   void setContextClassLoader(ClassLoader cl);

   /** Sets the context class loader of the given thread. */
   void setContextClassLoader(Thread thread, ClassLoader cl);

   /** Performs every access inside a privileged action. */
   TCLAction PRIVILEGED = new TCLAction()
   {
      // Stateless, so a single instance serves every current-thread read.
      private final PrivilegedAction currentThreadReader = new PrivilegedAction()
      {
         public Object run()
         {
            return Thread.currentThread().getContextClassLoader();
         }
      };

      public ClassLoader getContextClassLoader()
      {
         final Object loader = AccessController.doPrivileged(currentThreadReader);
         return (ClassLoader) loader;
      }

      public ClassLoader getContextClassLoader(final Thread thread)
      {
         final PrivilegedAction threadReader = new PrivilegedAction()
         {
            public Object run()
            {
               return thread.getContextClassLoader();
            }
         };
         return (ClassLoader) AccessController.doPrivileged(threadReader);
      }

      public void setContextClassLoader(final ClassLoader cl)
      {
         final PrivilegedAction currentThreadWriter = new PrivilegedAction()
         {
            public Object run()
            {
               Thread.currentThread().setContextClassLoader(cl);
               return null;
            }
         };
         AccessController.doPrivileged(currentThreadWriter);
      }

      public void setContextClassLoader(final Thread thread, final ClassLoader cl)
      {
         final PrivilegedAction threadWriter = new PrivilegedAction()
         {
            public Object run()
            {
               thread.setContextClassLoader(cl);
               return null;
            }
         };
         AccessController.doPrivileged(threadWriter);
      }
   };

   /** Performs every access directly on the thread. */
   TCLAction NON_PRIVILEGED = new TCLAction()
   {
      public ClassLoader getContextClassLoader()
      {
         final Thread current = Thread.currentThread();
         return current.getContextClassLoader();
      }

      public ClassLoader getContextClassLoader(Thread thread)
      {
         return thread.getContextClassLoader();
      }

      public void setContextClassLoader(ClassLoader cl)
      {
         final Thread current = Thread.currentThread();
         current.setContextClassLoader(cl);
      }

      public void setContextClassLoader(Thread thread, ClassLoader cl)
      {
         thread.setContextClassLoader(cl);
      }
   };

   /** Static entry points that select the variant on each call. */
   class UTIL
   {
      /**
       * Picks the privileged variant when a security manager is installed
       * and the direct variant otherwise.
       */
      static TCLAction getTCLAction()
      {
         if(System.getSecurityManager() == null)
         {
            return NON_PRIVILEGED;
         }
         return PRIVILEGED;
      }

      static ClassLoader getContextClassLoader()
      {
         final TCLAction action = getTCLAction();
         return action.getContextClassLoader();
      }

      static ClassLoader getContextClassLoader(Thread thread)
      {
         final TCLAction action = getTCLAction();
         return action.getContextClassLoader(thread);
      }

      static void setContextClassLoader(ClassLoader cl)
      {
         final TCLAction action = getTCLAction();
         action.setContextClassLoader(cl);
      }

      static void setContextClassLoader(Thread thread, ClassLoader cl)
      {
         final TCLAction action = getTCLAction();
         action.setContextClassLoader(thread, cl);
      }
   }
}
}