package org.jboss.invocation;
import java.io.DataOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Method;
import java.security.DigestOutputStream;
import java.security.MessageDigest;
import java.security.Principal;
import java.security.PrivilegedAction;
import java.security.AccessController;
import java.util.Map;
import java.util.Iterator;
import java.util.HashMap;
import java.util.WeakHashMap;
import javax.transaction.Transaction;
/**
 * An {@link Invocation} that serializes itself through {@link java.io.Externalizable}.
 *
 * <p>The target {@link Method} is not written to the stream. A 64-bit hash of
 * it is written instead (see {@link #calculateHash(Method)}). The reading side
 * must install a hash-to-method map with {@link #setMethodMap(Map)} (for
 * example one produced by {@link #methodToHashesMap(Class)}) before
 * {@link #getMethod()} can resolve the hash back to a {@code Method}.
 *
 * <p>On write, the arguments and every value of the regular payload are wrapped
 * in a {@code MarshalledValue}. On read, they are kept wrapped and only
 * unwrapped by {@link #getArguments()}, {@link #getValue(Object)} and
 * {@link #getPayloadValue(Object)}. Entries of the "as-is" payload are written
 * and read directly, without wrapping.
 */
public class MarshalledInvocation
    extends Invocation
    implements java.io.Externalizable
{
    static final long serialVersionUID = -718723094688127810L;

    /**
     * Selects the hashing scheme used by {@link #calculateHash(Method)}:
     * {@code true} hashes {@code Method.toString()}, {@code false} hashes the
     * method name plus JVM-style type descriptors.
     */
    static boolean useFullHashMode = true;

    /**
     * Cache from declaring {@code Class} to a map of {@code Method.toString()}
     * to {@code Long} hash. Every access in this class synchronizes on this map,
     * because {@link WeakHashMap} itself is not thread-safe.
     */
    static Map hashMap = new WeakHashMap();

    /** Transaction propagation context; written and read as a plain object. */
    protected Object tpc;

    /** Map from {@code Long} hash to {@code Method}, used to resolve {@link #methodHash}. */
    protected transient Map methodMap;

    /** Method hash read from the stream by {@link #readExternal}. */
    protected transient long methodHash = 0;

    /** Arguments as read from the stream; unwrapped on the first {@link #getArguments()} call. */
    protected transient MarshalledValue marshalledArgs = null;

    /**
     * Returns whether hashes are computed from the full {@code Method.toString()} form.
     *
     * @return the current hash mode flag
     */
    public static boolean getUseFullHashMode()
    {
        return useFullHashMode;
    }

    /**
     * Sets whether hashes are computed from the full {@code Method.toString()} form.
     * Hashes already cached in {@link #hashMap} are not recomputed; use
     * {@link #removeHashes(Class)} to drop them.
     *
     * @param flag {@code true} for full mode, {@code false} for name + type descriptors
     */
    public static void setUseFullHashMode(boolean flag)
    {
        useFullHashMode = flag;
    }

    /**
     * Computes hashes for the methods declared by {@code intf}. Each hash is
     * derived from the method name followed by the JVM type descriptors of its
     * parameters and return type, e.g. {@code foo(ILjava/lang/String;)V}.
     *
     * @param intf the class or interface whose declared methods are hashed
     * @return a map from {@code Method.toString()} to {@code Long} hash; a method
     *         whose hash could not be computed is omitted (the error is printed)
     */
    public static Map getInterfaceHashes(Class intf)
    {
        Method[] methods = getDeclaredMethodsOf(intf);
        HashMap map = new HashMap();
        for (int i = 0; i < methods.length; i++)
        {
            Method method = methods[i];
            String methodDesc = getShortMethodDescription(method);
            try
            {
                long hash = hashMethodDescription(methodDesc);
                map.put(method.toString(), new Long(hash));
            }
            catch (Exception e)
            {
                e.printStackTrace();
            }
        }
        return map;
    }

    /**
     * Computes hashes for the methods declared by {@code intf}, using
     * {@code Method.toString()} as the hashed description.
     *
     * @param intf the class or interface whose declared methods are hashed
     * @return a map from {@code Method.toString()} to {@code Long} hash; a method
     *         whose hash could not be computed is omitted (the error is printed)
     */
    public static Map getFullInterfaceHashes(Class intf)
    {
        Method[] methods = getDeclaredMethodsOf(intf);
        HashMap map = new HashMap();
        for (int i = 0; i < methods.length; i++)
        {
            Method method = methods[i];
            String methodDesc = method.toString();
            try
            {
                long hash = hashMethodDescription(methodDesc);
                map.put(methodDesc, new Long(hash));
            }
            catch (Exception e)
            {
                e.printStackTrace();
            }
        }
        return map;
    }

    /**
     * Builds the reverse lookup used on the receiving side: {@code Long} hash to
     * {@code Method}, using the same full-mode hash as
     * {@link #getFullInterfaceHashes(Class)}.
     *
     * @param c the class whose declared methods are indexed
     * @return a map from {@code Long} hash to {@code Method}; a method whose hash
     *         could not be computed is omitted (the error is printed)
     */
    public static Map methodToHashesMap(Class c)
    {
        Method[] methods = getDeclaredMethodsOf(c);
        HashMap map = new HashMap();
        for (int i = 0; i < methods.length; i++)
        {
            Method method = methods[i];
            try
            {
                long hash = hashMethodDescription(method.toString());
                map.put(new Long(hash), method);
            }
            catch (Exception e)
            {
                e.printStackTrace();
            }
        }
        return map;
    }

    /**
     * Returns the declared methods of {@code c}. When a security manager is
     * installed, the lookup runs inside {@code AccessController.doPrivileged}.
     */
    private static Method[] getDeclaredMethodsOf(Class c)
    {
        if (System.getSecurityManager() != null)
        {
            return (Method[]) AccessController.doPrivileged(new DeclaredMethodsAction(c));
        }
        return c.getDeclaredMethods();
    }

    /**
     * Builds {@code name(<param descriptors>)<return descriptor>} for a method.
     */
    private static String getShortMethodDescription(Method method)
    {
        StringBuffer desc = new StringBuffer(method.getName()).append('(');
        Class[] parameterTypes = method.getParameterTypes();
        for (int j = 0; j < parameterTypes.length; j++)
        {
            desc.append(getTypeString(parameterTypes[j]));
        }
        desc.append(')').append(getTypeString(method.getReturnType()));
        return desc.toString();
    }

    /**
     * Hashes a method description to a {@code long}. The description is written
     * with {@link DataOutputStream#writeUTF(String)} through a SHA digest. The
     * first 8 digest bytes are then folded little-endian into the result. This
     * exact byte sequence defines the wire hash, so it must not change.
     */
    private static long hashMethodDescription(String methodDesc)
        throws IOException, java.security.NoSuchAlgorithmException
    {
        MessageDigest messageDigest = MessageDigest.getInstance("SHA");
        DataOutputStream dataOut = new DataOutputStream(
            new DigestOutputStream(new ByteArrayOutputStream(512), messageDigest));
        dataOut.writeUTF(methodDesc);
        dataOut.flush();
        byte[] digest = messageDigest.digest();
        long hash = 0;
        for (int j = 0; j < Math.min(8, digest.length); j++)
        {
            hash += (long) (digest[j] & 0xff) << j * 8;
        }
        return hash;
    }

    /**
     * Returns the JVM field descriptor for {@code cl}: a single letter for
     * primitives and {@code void}, {@code [} + component descriptor for arrays,
     * and {@code Lpkg/Name;} for reference types.
     *
     * @param cl the type to describe
     * @return the descriptor string
     */
    static String getTypeString(Class cl)
    {
        if (cl == Byte.TYPE)
        {
            return "B";
        }
        else if (cl == Character.TYPE)
        {
            return "C";
        }
        else if (cl == Double.TYPE)
        {
            return "D";
        }
        else if (cl == Float.TYPE)
        {
            return "F";
        }
        else if (cl == Integer.TYPE)
        {
            return "I";
        }
        else if (cl == Long.TYPE)
        {
            return "J";
        }
        else if (cl == Short.TYPE)
        {
            return "S";
        }
        else if (cl == Boolean.TYPE)
        {
            return "Z";
        }
        else if (cl == Void.TYPE)
        {
            return "V";
        }
        else if (cl.isArray())
        {
            return "[" + getTypeString(cl.getComponentType());
        }
        else
        {
            return "L" + cl.getName().replace('.', '/') + ";";
        }
    }

    /**
     * Returns the hash of {@code method}. On first use for a declaring class,
     * the hashes of all of that class's declared methods are computed (in the
     * mode chosen by {@link #useFullHashMode}) and cached.
     *
     * @param method the method to hash
     * @return the 64-bit method hash
     * @throws IllegalStateException if no hash could be computed for the method
     */
    public static long calculateHash(Method method)
    {
        Class declaringClass = method.getDeclaringClass();
        Map methodHashes;
        // WeakHashMap is not thread-safe: reads must hold the same lock as writes.
        synchronized (hashMap)
        {
            methodHashes = (Map) hashMap.get(declaringClass);
        }
        if (methodHashes == null)
        {
            // Computed outside the lock to avoid holding it during reflection;
            // concurrent callers may each compute the map, and the last put wins.
            methodHashes = useFullHashMode
                ? getFullInterfaceHashes(declaringClass)
                : getInterfaceHashes(declaringClass);
            synchronized (hashMap)
            {
                hashMap.put(declaringClass, methodHashes);
            }
        }
        Long hash = (Long) methodHashes.get(method.toString());
        if (hash == null)
        {
            throw new IllegalStateException("No hash computed for method: " + method);
        }
        return hash.longValue();
    }

    /**
     * Drops the cached method hashes for {@code declaringClass}.
     *
     * @param declaringClass the class whose cache entry is removed
     */
    public static void removeHashes(Class declaringClass)
    {
        synchronized (hashMap)
        {
            hashMap.remove(declaringClass);
        }
    }

    /** No-arg constructor required by {@link java.io.Externalizable}. */
    public MarshalledInvocation()
    {
    }

    /**
     * Copies the payloads, method, object name, arguments and invocation type
     * from an existing invocation.
     *
     * @param invocation the invocation to copy
     */
    public MarshalledInvocation(Invocation invocation)
    {
        this.payload = invocation.payload;
        this.as_is_payload = invocation.as_is_payload;
        this.method = invocation.getMethod();
        this.objectName = invocation.getObjectName();
        this.args = invocation.getArguments();
        this.invocationType = invocation.getType();
    }

    /**
     * Creates an invocation by delegating all values to the matching
     * {@code Invocation} constructor.
     */
    public MarshalledInvocation(
        Object id,
        Method m,
        Object[] args,
        Transaction tx,
        Principal identity,
        Object credential)
    {
        super(id, m, args, tx, identity, credential);
    }

    /**
     * Returns the target method, resolving the deserialized hash through the
     * map installed with {@link #setMethodMap(Map)} on first use.
     *
     * @return the target method
     * @throws IllegalStateException if no method map is set or the hash is not in it
     */
    public Method getMethod()
    {
        if (this.method != null)
            return this.method;
        if (methodMap == null)
        {
            throw new IllegalStateException("Failed to find method for hash:" + methodHash
                + ", no method map has been set");
        }
        this.method = (Method) methodMap.get(new Long(methodHash));
        if (this.method == null)
        {
            throw new IllegalStateException("Failed to find method for hash:" + methodHash + " available=" + methodMap);
        }
        return this.method;
    }

    /**
     * Installs the {@code Long} hash to {@code Method} map used by {@link #getMethod()}.
     *
     * @param methods the lookup map, e.g. from {@link #methodToHashesMap(Class)}
     */
    public void setMethodMap(Map methods)
    {
        methodMap = methods;
    }

    /** @param tpc the transaction propagation context to send with this invocation */
    public void setTransactionPropagationContext(Object tpc)
    {
        this.tpc = tpc;
    }

    /** @return the transaction propagation context, or {@code null} if none */
    public Object getTransactionPropagationContext()
    {
        return tpc;
    }

    /**
     * Returns the value for {@code key} as looked up by {@code Invocation.getValue},
     * unwrapping it if it is a {@code MarshalledValue}.
     *
     * @param key the lookup key
     * @return the (unwrapped) value, or {@code null} if unwrapping failed
     */
    public Object getValue(Object key)
    {
        return unmarshal(super.getValue(key));
    }

    /**
     * Returns the regular-payload value for {@code key}, unwrapping it if it is
     * a {@code MarshalledValue}.
     *
     * @param key the payload key
     * @return the (unwrapped) value, or {@code null} if unwrapping failed
     */
    public Object getPayloadValue(Object key)
    {
        return unmarshal(getPayload().get(key));
    }

    /**
     * Unwraps {@code value} if it is a {@code MarshalledValue}. Failures are
     * printed and mapped to {@code null} (best-effort, matching prior behavior).
     * Other values are returned unchanged.
     */
    private static Object unmarshal(Object value)
    {
        if (value instanceof MarshalledValue)
        {
            try
            {
                return ((MarshalledValue) value).get();
            }
            catch (Exception e)
            {
                e.printStackTrace();
                return null;
            }
        }
        return value;
    }

    /**
     * Returns the invocation arguments. After deserialization, they are unwrapped
     * from {@link #marshalledArgs} on first access.
     *
     * @return the arguments; {@code null} if none were set or unwrapping failed
     */
    public Object[] getArguments()
    {
        // Guard on marshalledArgs: a locally built invocation may have no args at all.
        if (this.args == null && marshalledArgs != null)
        {
            try
            {
                this.args = (Object[]) marshalledArgs.get();
            }
            catch (Exception e)
            {
                e.printStackTrace();
            }
        }
        return args;
    }

    /**
     * Writes this invocation in the following order:
     * <ol>
     * <li>tpc</li>
     * <li>method hash</li>
     * <li>object name</li>
     * <li>args wrapped in a {@code MarshalledValue}</li>
     * <li>the payload count and entries, with values wrapped in {@code MarshalledValue}</li>
     * <li>the as-is payload count and entries, written unwrapped</li>
     * </ol>
     * As a side effect, the invocation type is first stored in the as-is payload.
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     */
    public void writeExternal(java.io.ObjectOutput out)
        throws IOException
    {
        getAsIsPayload().put(InvocationKey.TYPE, invocationType);
        out.writeObject(tpc);
        long hash = calculateHash(this.method);
        out.writeLong(hash);
        out.writeObject(this.objectName);
        out.writeObject(new MarshalledValue(this.args));
        if (payload == null)
        {
            out.writeInt(0);
        }
        else
        {
            out.writeInt(payload.size());
            Iterator keys = payload.keySet().iterator();
            while (keys.hasNext())
            {
                Object currentKey = keys.next();
                out.writeObject(currentKey);
                out.writeObject(new MarshalledValue(payload.get(currentKey)));
            }
        }
        if (as_is_payload == null)
        {
            out.writeInt(0);
        }
        else
        {
            out.writeInt(as_is_payload.size());
            Iterator keys = as_is_payload.keySet().iterator();
            while (keys.hasNext())
            {
                Object currentKey = keys.next();
                out.writeObject(currentKey);
                out.writeObject(as_is_payload.get(currentKey));
            }
        }
    }

    /**
     * Reads the layout written by {@link #writeExternal}. The method hash and
     * the marshalled arguments are stored, not resolved; see {@link #getMethod()}
     * and {@link #getArguments()}. The invocation type is restored from the
     * as-is payload.
     *
     * @param in the stream to read from
     * @throws IOException if reading fails
     * @throws ClassNotFoundException if a class in the stream cannot be loaded
     */
    public void readExternal(java.io.ObjectInput in)
        throws IOException, ClassNotFoundException
    {
        tpc = in.readObject();
        this.methodHash = in.readLong();
        this.objectName = in.readObject();
        marshalledArgs = (MarshalledValue) in.readObject();
        int payloadSize = in.readInt();
        if (payloadSize > 0)
        {
            payload = new HashMap();
            for (int i = 0; i < payloadSize; i++)
            {
                Object key = in.readObject();
                Object value = in.readObject();
                payload.put(key, value);
            }
        }
        int as_is_payloadSize = in.readInt();
        if (as_is_payloadSize > 0)
        {
            as_is_payload = new HashMap();
            for (int i = 0; i < as_is_payloadSize; i++)
            {
                Object key = in.readObject();
                Object value = in.readObject();
                as_is_payload.put(key, value);
            }
        }
        invocationType = (InvocationType) getAsIsValue(InvocationKey.TYPE);
    }

    /**
     * Privileged action returning {@code c.getDeclaredMethods()}. The class
     * reference is cleared after the single use.
     */
    private static class DeclaredMethodsAction implements PrivilegedAction
    {
        Class c;

        DeclaredMethodsAction(Class c)
        {
            this.c = c;
        }

        public Object run()
        {
            Method[] methods = c.getDeclaredMethods();
            c = null;
            return methods;
        }
    }
}