package org.jboss.resource.adapter.jdbc.remote;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.UndeclaredThrowableException;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;
import javax.management.MBeanServer;
import javax.management.ObjectName;
import javax.naming.BinaryRefAddr;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.naming.Reference;
import javax.naming.StringRefAddr;
import javax.resource.Referenceable;
import javax.sql.DataSource;
import org.jboss.deployment.DeploymentException;
import org.jboss.invocation.Invocation;
import org.jboss.invocation.Invoker;
import org.jboss.invocation.InvokerInterceptor;
import org.jboss.invocation.MarshalledInvocation;
import org.jboss.logging.Logger;
import org.jboss.naming.NonSerializableFactory;
import org.jboss.naming.Util;
import org.jboss.proxy.ClientMethodInterceptor;
import org.jboss.proxy.GenericProxyFactory;
import org.jboss.resource.connectionmanager.ConnectionFactoryBindingService;
import org.jboss.system.Registry;
import org.jboss.system.ServiceMBeanSupport;
/**
 * Exposes a JCA-managed JDBC {@link DataSource} to remote clients.
 *
 * <p>On start the service creates the connection factory, builds a {@link DataSource}
 * proxy that routes calls through the configured JMX invoker, and binds the factory into
 * JNDI with a {@link Reference} carrying the serialized proxy. Remote calls arrive through
 * {@link #invoke(Invocation)} and are dispatched to the real {@link Connection},
 * {@link Statement}, {@link ResultSet} or {@link DatabaseMetaData} object. Those objects stay
 * on the server in maps keyed by {@code new Integer(obj.hashCode())}; clients receive
 * invoker-backed proxies carrying that key as their id.
 */
public class WrapperDataSourceService extends ConnectionFactoryBindingService
   implements WrapperDataSourceServiceMBean
{
   private static final Logger log = Logger.getLogger(WrapperDataSourceService.class);

   /** Name of the JMX invoker used as the transport for every proxy created here. */
   private ObjectName jmxInvokerName;
   /** Invoker looked up from the {@link Registry} under {@link #jmxInvokerName}. */
   private Invoker delegateInvoker;
   /** Proxy for the DataSource; serialized into the "ProxyData" JNDI reference address. */
   private Object theProxy;
   /** Method hash (Long) to {@link Method}; handed to incoming MarshalledInvocations. */
   private HashMap marshalledInvocationMapping = new HashMap();
   // Server-side JDBC objects backing client proxies, keyed by Integer(hashCode()).
   // NOTE(review): these HashMaps are not synchronized; if invoke() can run on several
   // threads concurrently, access to them must be made thread-safe -- confirm how the
   // invoker dispatches calls.
   private HashMap connectionMap = new HashMap();
   private HashMap statementMap = new HashMap();
   private HashMap resultSetMap = new HashMap();
   // NOTE(review): entries are added here but never removed by this class.
   private HashMap databaseMetaDataMap = new HashMap();
   /** Trace flag cached at construction; later log-level changes are not observed. */
   private boolean trace = log.isTraceEnabled();

   /**
    * Creates the connection factory and its remote proxy, builds the method-hash table,
    * and binds the factory into JNDI.
    */
   protected void startService() throws Exception
   {
      determineBindName();
      createConnectionFactory();
      createProxy();
      calculateMethodHases();
      bindConnectionFactory();
   }

   /** Unbinds the factory from JNDI and removes the Registry entry made by createProxy(). */
   protected void stopService() throws Exception
   {
      unbindConnectionFactory();
      destroyProxy();
   }

   /**
    * Binds the connection factory into JNDI under {@code bindName}.
    *
    * <p>The factory is first registered with {@link NonSerializableFactory}. It is then
    * given a {@link Reference} of type {@code javax.sql.DataSource}, whose object factory
    * is {@code DataSourceFactory} and which carries three addresses:
    * <ul>
    * <li>"ProxyData": the serialized remote proxy;</li>
    * <li>"VMID": the serialized {@code DataSourceFactory.vmID};</li>
    * <li>"JndiName": the bind name.</li>
    * </ul>
    *
    * @throws DeploymentException if JNDI binding fails; the NamingException is kept as cause
    * @throws Exception for other failures, e.g. while serializing the proxy
    */
   protected void bindConnectionFactory() throws Exception
   {
      InitialContext ctx = new InitialContext();
      try
      {
         log.debug("Binding object '" + cf + "' into JNDI at '" + bindName + "'");
         NonSerializableFactory.rebind(bindName, cf);
         Referenceable referenceable = (Referenceable) cf;

         ByteArrayOutputStream baos = new ByteArrayOutputStream();
         ObjectOutputStream oos = new ObjectOutputStream(baos);
         oos.writeObject(theProxy);
         oos.close();
         byte[] proxyBytes = baos.toByteArray();
         BinaryRefAddr dsAddr = new BinaryRefAddr("ProxyData", proxyBytes);
         String factory = DataSourceFactory.class.getName();
         Reference dsRef = new Reference("javax.sql.DataSource", dsAddr, factory, null);
         // dsRef is mutated below; the Referenceable holds the same instance.
         referenceable.setReference(dsRef);

         // Reuse the buffer to serialize the VM identifier.
         baos.reset();
         ObjectOutputStream oos2 = new ObjectOutputStream(baos);
         oos2.writeObject(DataSourceFactory.vmID);
         oos2.close();
         byte[] id = baos.toByteArray();
         BinaryRefAddr localAddr = new BinaryRefAddr("VMID", id);
         dsRef.add(localAddr);
         StringRefAddr jndiRef = new StringRefAddr("JndiName", bindName);
         dsRef.add(jndiRef);

         Util.rebind(ctx, bindName, cf);
         log.info("Bound connection factory for resource adapter for ConnectionManager '" + serviceName + " to JNDI name '" + bindName + "'");
      }
      catch (NamingException ne)
      {
         // Keep the NamingException as the cause so the real failure is not lost.
         throw new DeploymentException("Could not bind ConnectionFactory into jndi: " + bindName, ne);
      }
      finally
      {
         ctx.close();
      }
   }

   /** @return the name of the JMX invoker used as the proxy transport */
   public ObjectName getJMXInvokerName()
   {
      return jmxInvokerName;
   }

   /** @param jmxInvokerName the name of the JMX invoker used as the proxy transport */
   public void setJMXInvokerName(ObjectName jmxInvokerName)
   {
      this.jmxInvokerName = jmxInvokerName;
   }

   /**
    * Dispatches a remote invocation to the server-side JDBC object it targets.
    *
    * <p>The target is selected from the invoked method's declaring class. DataSource calls
    * go to the factory looked up in JNDI under {@code bindName}. Connection, Statement,
    * ResultSet and DatabaseMetaData calls go to the object stored under
    * {@code invocation.getId()} in the matching map.
    *
    * @param invocation the incoming invocation
    * @return the (possibly proxied or wrapped) result of the call
    * @throws IllegalAccessException if no Connection, ResultSet or DatabaseMetaData is
    *         registered under the id, or if the result is not Serializable
    * @throws SQLException if no Statement is registered under the id
    * @throws UnsupportedOperationException for methods of any other declaring class
    * @throws Exception the target's own exception, unwrapped from InvocationTargetException;
    *         non-Exception throwables are wrapped in UndeclaredThrowableException
    */
   public Object invoke(Invocation invocation) throws Exception
   {
      // Give a MarshalledInvocation the hash-to-Method table built in calculateMethodHases().
      if (invocation instanceof MarshalledInvocation)
      {
         MarshalledInvocation mi = (MarshalledInvocation) invocation;
         mi.setMethodMap(marshalledInvocationMapping);
      }
      Method method = invocation.getMethod();
      Class methodClass = method.getDeclaringClass();
      Object[] args = invocation.getArguments();
      Object value = null;
      try
      {
         if( methodClass.isAssignableFrom(DataSource.class) )
         {
            InitialContext ctx = new InitialContext();
            try
            {
               DataSource ds = (DataSource) ctx.lookup(bindName);
               value = doDataSourceMethod(ds, method, args);
            }
            finally
            {
               // Release the naming context created for this call.
               ctx.close();
            }
         }
         else if( methodClass.isAssignableFrom(Connection.class) )
         {
            Integer id = (Integer) invocation.getId();
            Connection conn = (Connection) connectionMap.get(id);
            if( conn == null )
            {
               throw new IllegalAccessException("Failed to find connection: "+id);
            }
            value = doConnectionMethod(conn, method, args);
         }
         else if( methodClass.isAssignableFrom(Statement.class) ||
            methodClass.isAssignableFrom(PreparedStatement.class) ||
            methodClass.isAssignableFrom(CallableStatement.class))
         {
            Integer id = (Integer) invocation.getId();
            Statement stmt = (Statement) statementMap.get(id);
            if( stmt == null )
            {
               throw new SQLException("Failed to find Statement: " + id);
            }
            value = doStatementMethod(stmt, method, args);
         }
         else if( methodClass.isAssignableFrom(ResultSet.class) )
         {
            Integer id = (Integer) invocation.getId();
            ResultSet results = (ResultSet) resultSetMap.get(id);
            if( results == null )
            {
               throw new IllegalAccessException("Failed to find ResultSet: "+id);
            }
            value = doResultSetMethod(results, method, args);
         }
         else if (methodClass.isAssignableFrom(DatabaseMetaData.class))
         {
            Integer id = (Integer) invocation.getId();
            DatabaseMetaData dbMetaData = (DatabaseMetaData) databaseMetaDataMap.get(id);
            if(dbMetaData == null)
            {
               throw new IllegalAccessException("Failed to find DatabaseMetaData: " + id);
            }
            value = doDatabaseMetaDataMethod(dbMetaData, method, args);
         }
         else
         {
            throw new UnsupportedOperationException("Do not know how to handle method="+method);
         }
      }
      catch (InvocationTargetException e)
      {
         // Surface the target's own exception rather than the reflection wrapper.
         Throwable t = e.getTargetException();
         if (t instanceof Exception)
            throw (Exception) t;
         else
            throw new UndeclaredThrowableException(t, method.toString());
      }
      return value;
   }

   /**
    * Looks up the delegate invoker and builds the DataSource proxy stored in theProxy.
    * Also binds the service name in the {@link Registry} under its hash code; destroyProxy()
    * removes that entry.
    */
   protected void createProxy() throws Exception
   {
      delegateInvoker = (Invoker) Registry.lookup(jmxInvokerName);
      log.debug("Using delegate: " + delegateInvoker
         + " for invoker=" + jmxInvokerName);
      ObjectName targetName = getServiceName();
      Integer nameHash = new Integer(targetName.hashCode());
      Registry.bind(nameHash, targetName);
      Object cacheID = null;
      String proxyBindingName = null;
      String jndiName = null;
      Class[] ifaces = {javax.sql.DataSource.class};
      ArrayList interceptorClasses = defaultInterceptorClasses();
      ClassLoader loader = Thread.currentThread().getContextClassLoader();
      GenericProxyFactory proxyFactory = new GenericProxyFactory();
      theProxy = proxyFactory.createProxy(cacheID, targetName,
         delegateInvoker, jndiName, proxyBindingName, interceptorClasses,
         loader, ifaces);
      log.debug("Created proxy for invoker=" + jmxInvokerName
         + ", targetName=" + targetName + ", nameHash=" + nameHash);
   }

   /** Removes the Registry entry that createProxy() bound under the service name's hash. */
   protected void destroyProxy() throws Exception
   {
      ObjectName name = getServiceName();
      Integer nameHash = new Integer(name.hashCode());
      Registry.unbind(nameHash);
   }

   /**
    * Fills marshalledInvocationMapping with method hashes for DataSource, Connection,
    * Statement, CallableStatement, PreparedStatement, ResultSet and DatabaseMetaData.
    * (The misspelled name is kept because the method is protected.)
    */
   protected void calculateMethodHases() throws Exception
   {
      Method[] methods = DataSource.class.getMethods();
      for(int m = 0; m < methods.length; m ++)
      {
         Method method = methods[m];
         Long hash = new Long(MarshalledInvocation.calculateHash(method));
         marshalledInvocationMapping.put(hash, method);
      }
      Map m = MarshalledInvocation.methodToHashesMap(Connection.class);
      displayHashes(m);
      marshalledInvocationMapping.putAll(m);
      m = MarshalledInvocation.methodToHashesMap(Statement.class);
      displayHashes(m);
      marshalledInvocationMapping.putAll(m);
      m = MarshalledInvocation.methodToHashesMap(CallableStatement.class);
      displayHashes(m);
      marshalledInvocationMapping.putAll(m);
      m = MarshalledInvocation.methodToHashesMap(PreparedStatement.class);
      displayHashes(m);
      marshalledInvocationMapping.putAll(m);
      m = MarshalledInvocation.methodToHashesMap(ResultSet.class);
      displayHashes(m);
      marshalledInvocationMapping.putAll(m);
      m = MarshalledInvocation.methodToHashesMap(DatabaseMetaData.class);
      displayHashes(m);
      marshalledInvocationMapping.putAll(m);
   }

   /**
    * Invokes a DataSource method. A returned Connection is replaced by a remote proxy;
    * any other non-null result must be Serializable.
    */
   private Object doDataSourceMethod(DataSource ds, Method method, Object[] args)
      throws InvocationTargetException, IllegalAccessException
   {
      Object value = method.invoke(ds, args);
      if( value instanceof Connection )
      {
         value = createConnectionProxy(value);
      }
      else if( value != null && (value instanceof Serializable) == false )
      {
         throw new IllegalAccessException("Method="+method+" does not return Serializable");
      }
      return value;
   }

   /**
    * Invokes a Connection method. A returned Statement or DatabaseMetaData is replaced by a
    * remote proxy; any other non-null result must be Serializable. On {@code close()} the
    * connection is removed from connectionMap so closed handles are not retained, mirroring
    * the Statement and ResultSet handling.
    */
   private Object doConnectionMethod(Connection conn, Method method, Object[] args)
      throws InvocationTargetException, IllegalAccessException, SQLException
   {
      if( trace )
      {
         log.trace("doConnectionMethod, conn="+conn+", method="+method);
      }
      if( method.getName().equals("close") )
      {
         // Same key as used by createConnectionProxy().
         Integer id = new Integer(conn.hashCode());
         connectionMap.remove(id);
         log.debug("Closed Connection="+id);
      }
      Object value = method.invoke(conn, args);
      if( value instanceof Statement )
      {
         value = createStatementProxy(value);
      }
      else if(value instanceof DatabaseMetaData)
      {
         value = createDatabaseMetaData(value);
      }
      else if( value != null && (value instanceof Serializable) == false )
      {
         throw new IllegalAccessException("Method="+method+" does not return Serializable");
      }
      return value;
   }

   /**
    * Invokes a Statement method. A returned ResultSet is proxied, and ResultSetMetaData or
    * ParameterMetaData is copied into a serializable snapshot. Any other non-null result
    * must be Serializable. On {@code close()} the statement is removed from statementMap
    * before the call is made.
    */
   private Object doStatementMethod(Statement stmt, Method method, Object[] args)
      throws InvocationTargetException, IllegalAccessException, SQLException
   {
      if( trace )
      {
         log.trace("doStatementMethod, stmt="+stmt+", method="+method);
      }
      if( method.getName().equals("close") )
      {
         Integer id = new Integer(stmt.hashCode());
         statementMap.remove(id);
         log.debug("Closed Statement="+id);
      }
      Object value = method.invoke(stmt, args);
      if( value instanceof ResultSet )
      {
         value = createResultSetProxy(value);
      }
      else if( value instanceof ResultSetMetaData )
      {
         ResultSetMetaData rmd = (ResultSetMetaData) value;
         value = new SerializableResultSetMetaData(rmd);
      }
      else if ( value instanceof ParameterMetaData )
      {
         ParameterMetaData pmd = (ParameterMetaData) value;
         value = new SerializableParameterMetaData(pmd);
      }
      else if( value != null && (value instanceof Serializable) == false )
      {
         throw new IllegalAccessException("Method="+method+" does not return Serializable");
      }
      return value;
   }

   /**
    * Invokes a ResultSet method. ResultSetMetaData is copied into a serializable snapshot,
    * and the InputStream returned by {@code getAsciiStream} is wrapped in a
    * SerializableInputStream. Any other non-null result must be Serializable. On
    * {@code close()} the result set is removed from resultSetMap before the call is made.
    */
   private Object doResultSetMethod(ResultSet results, Method method, Object[] args)
      throws InvocationTargetException, IllegalAccessException, SQLException, IOException
   {
      if( trace )
      {
         log.trace("doResultSetMethod, results="+results+", method="+method);
      }
      if( method.getName().equals("close") )
      {
         Integer id = new Integer(results.hashCode());
         resultSetMap.remove(id);
         log.debug("Closed ResultSet="+id);
      }
      Object value = method.invoke(results, args);
      if( value instanceof ResultSetMetaData )
      {
         ResultSetMetaData rmd = (ResultSetMetaData) value;
         value = new SerializableResultSetMetaData(rmd);
      }
      if("getAsciiStream".equals(method.getName()) && value instanceof InputStream)
      {
         InputStream ins = (InputStream)value;
         value = new SerializableInputStream(ins);
      }
      if( value != null && (value instanceof Serializable) == false )
      {
         throw new IllegalAccessException("Method="+method+" does not return Serializable");
      }
      return value;
   }

   /**
    * Invokes a DatabaseMetaData method. A returned ResultSet or Connection is replaced by a
    * remote proxy; any other non-null result must be Serializable.
    */
   private Object doDatabaseMetaDataMethod(DatabaseMetaData dbMetaData, Method method, Object[] args)
      throws InvocationTargetException, IllegalAccessException
   {
      if( trace )
      {
         log.trace("doDatabaseMetaDataMethod, dbMetaData="+dbMetaData+", method="+method);
      }
      Object value = method.invoke(dbMetaData, args);
      if( value instanceof ResultSet )
      {
         value = createResultSetProxy(value);
      }
      else if( value instanceof Connection )
      {
         value = createConnectionProxy(value);
      }
      if( value != null && (value instanceof Serializable) == false )
      {
         throw new IllegalAccessException("Method="+method+" does not return Serializable");
      }
      return value;
   }

   /** Proxies a Connection (interface java.sql.Connection) and registers it in connectionMap. */
   private Object createConnectionProxy(Object conn)
   {
      Class[] ifaces = {java.sql.Connection.class};
      return createRegisteredProxy(conn, ifaces, defaultInterceptorClasses(),
         connectionMap, "Connection");
   }

   /**
    * Proxies a Statement and registers it in statementMap. Only the statement class's
    * directly implemented interfaces whose names start with "java" are exposed (presumably
    * so clients need no driver-specific classes). StatementInterceptor runs first in the
    * client interceptor chain.
    */
   private Object createStatementProxy(Object stmt)
   {
      Class[] declared = stmt.getClass().getInterfaces();
      ArrayList javaIfaces = new ArrayList();
      for(int i = 0; i < declared.length; i ++)
      {
         if( declared[i].getName().startsWith("java") )
            javaIfaces.add(declared[i]);
      }
      Class[] ifaces = (Class[]) javaIfaces.toArray(new Class[javaIfaces.size()]);
      ArrayList interceptorClasses = defaultInterceptorClasses();
      interceptorClasses.add(0, StatementInterceptor.class);
      return createRegisteredProxy(stmt, ifaces, interceptorClasses,
         statementMap, "Statement");
   }

   /**
    * Proxies a ResultSet and registers it in resultSetMap. All interfaces the result set
    * class implements directly are exposed; unlike statements, they are not filtered.
    */
   private Object createResultSetProxy(Object results)
   {
      Class[] ifaces = results.getClass().getInterfaces();
      return createRegisteredProxy(results, ifaces, defaultInterceptorClasses(),
         resultSetMap, "ResultSet");
   }

   /** Proxies a DatabaseMetaData (interface java.sql.DatabaseMetaData) and registers it. */
   private Object createDatabaseMetaData(Object dbMetaData)
   {
      Class[] ifaces = {java.sql.DatabaseMetaData.class};
      return createRegisteredProxy(dbMetaData, ifaces, defaultInterceptorClasses(),
         databaseMetaDataMap, "DatabaseMetadata");
   }

   /**
    * Creates an invoker-backed proxy for a server-side JDBC object and registers the object,
    * so that later invocations through the proxy can be routed back to it by invoke().
    *
    * @param target the real JDBC object; {@code new Integer(target.hashCode())} is the proxy id
    * @param ifaces interfaces the proxy implements
    * @param interceptorClasses client-side interceptor chain for the proxy
    * @param registry map from proxy id to server-side object consulted by invoke()
    * @param kind label used in the debug log message
    * @return the proxy
    */
   private Object createRegisteredProxy(Object target, Class[] ifaces,
      ArrayList interceptorClasses, Map registry, String kind)
   {
      Object cacheID = new Integer(target.hashCode());
      ObjectName targetName = getServiceName();
      String proxyBindingName = null;
      String jndiName = null;
      ClassLoader loader = Thread.currentThread().getContextClassLoader();
      GenericProxyFactory proxyFactory = new GenericProxyFactory();
      Object proxy = proxyFactory.createProxy(cacheID, targetName,
         delegateInvoker, jndiName, proxyBindingName, interceptorClasses,
         loader, ifaces);
      registry.put(cacheID, target);
      log.debug("Created " + kind + " proxy for invoker=" + jmxInvokerName
         + ", targetName=" + targetName + ", cacheID=" + cacheID);
      return proxy;
   }

   /** Returns a new, mutable list: ClientMethodInterceptor followed by InvokerInterceptor. */
   private static ArrayList defaultInterceptorClasses()
   {
      ArrayList interceptorClasses = new ArrayList();
      interceptorClasses.add(ClientMethodInterceptor.class);
      interceptorClasses.add(InvokerInterceptor.class);
      return interceptorClasses;
   }

   /** Logs each hash=method entry of {@code m} at trace level; no-op when trace is off. */
   private void displayHashes(Map m)
   {
      if( trace == false )
         return;
      Iterator keys = m.keySet().iterator();
      while( keys.hasNext() )
      {
         Long key = (Long) keys.next();
         log.trace(key+"="+m.get(key));
      }
   }
}