package org.jboss.remoting.transport.socket;
import org.jboss.remoting.CannotConnectException;
import org.jboss.remoting.ConnectionFailedException;
import org.jboss.remoting.InvokerLocator;
import org.jboss.remoting.RemoteClientInvoker;
import org.jboss.remoting.loading.ObjectInputStreamWithClassLoader;
import org.jboss.remoting.marshal.Marshaller;
import org.jboss.remoting.marshal.UnMarshaller;
import org.jboss.remoting.marshal.serializable.SerializableMarshaller;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.rmi.ConnectException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
/**
 * Client invoker that carries remote invocations over plain TCP sockets.
 * <p>
 * Each call to {@link #transport} obtains a {@link ClientSocket}, either reused from an idle pool
 * or newly opened. It writes the invocation with the supplied {@link Marshaller} and reads the reply
 * with the supplied {@link UnMarshaller}. The socket is then put back in the pool, or closed if the
 * pool already holds {@link #maxPoolSize} sockets. Pools are kept in the static
 * {@link #connectionPools} map keyed by {@link ServerAddress}. All invokers targeting the same
 * address therefore share one pool.
 * <p>
 * Locator parameters read by {@link #configureParameters()}: {@value #TCP_NODELAY_FLAG},
 * {@value #MAX_POOL_SIZE_FLAG} and {@value #SO_TIMEOUT_FLAG}.
 */
public class SocketClientInvoker extends RemoteClientInvoker
{
   private InetAddress addr;
   private int port;

   /** Locator parameter ("true"/"false") controlling TCP_NODELAY on client sockets. */
   public static final String TCP_NODELAY_FLAG = "enableTcpNoDelay";
   /** Locator parameter giving the maximum number of idle sockets kept in a pool. */
   public static final String MAX_POOL_SIZE_FLAG = "clientMaxPoolSize";
   /** Locator parameter giving the socket read timeout in milliseconds (Socket.setSoTimeout). */
   public static final String SO_TIMEOUT_FLAG = "socketTimeout";
   /** Default socket read timeout in milliseconds. */
   public static final int SO_TIMEOUT_DEFAULT = 60000;
   /** TCP_NODELAY is off unless requested through the locator. */
   public static final boolean TCP_NODELAY_DEFAULT = false;

   // Cumulative statistics shared by every instance (milliseconds). They are updated without
   // synchronization, so values are approximate under concurrent use.
   // serializeTime and deserializeTime are not updated by this class.
   public static long getSocketTime = 0;
   public static long readTime = 0;
   public static long writeTime = 0;
   public static long serializeTime = 0;
   public static long deserializeTime = 0;

   protected boolean enableTcpNoDelay = TCP_NODELAY_DEFAULT;
   protected int timeout = SO_TIMEOUT_DEFAULT;

   /** Default number of attempts made when opening a new socket. */
   public static final int MAX_RETRIES = 10;
   /** Count of connections served from the pool rather than freshly opened (unsynchronized). */
   public static long usedPooled = 0;

   protected int numberOfRetries = MAX_RETRIES;
   /** Idle-socket pool for {@link #address}; assigned by {@link #initPool()}. */
   protected LinkedList pool = null;
   protected ServerAddress address;
   /** Idle-socket pools keyed by ServerAddress; guarded by its own monitor. */
   protected static HashMap connectionPools = new HashMap();
   protected int maxPoolSize = 10;

   /**
    * Creates an invoker for the given locator and runs {@link #setup()}.
    *
    * @param locator the remote endpoint description
    * @throws IOException declared for compatibility; setup failures are rethrown unchecked
    * @throws RuntimeException if {@link #setup()} fails; the original exception is the cause
    */
   public SocketClientInvoker(InvokerLocator locator)
         throws IOException
   {
      super(locator);
      try
      {
         setup();
      }
      catch(Exception ex)
      {
         // Keep the original message, but chain the cause so the real failure is not lost.
         throw new RuntimeException(ex.getMessage(), ex);
      }
   }

   /**
    * Resolves the locator host and port and reads the locator parameters.
    * It then builds the {@link ServerAddress} used as the pool key.
    *
    * @throws Exception if the host cannot be resolved
    */
   protected void setup()
         throws Exception
   {
      this.addr = InetAddress.getByName(locator.getHost());
      this.port = locator.getPort();
      configureParameters();
      address = new ServerAddress(addr.getHostAddress(), port, enableTcpNoDelay, timeout);
   }

   /**
    * Applies the optional locator parameters to the instance fields.
    * The fields are {@link #enableTcpNoDelay}, {@link #maxPoolSize} and {@link #timeout}.
    * Values that cannot be converted are logged at warn level and the current value is kept.
    */
   private void configureParameters()
   {
      Map params = locator.getParameters();
      if(params != null)
      {
         Object val = params.get(TCP_NODELAY_FLAG);
         if(val != null)
         {
            try
            {
               boolean bVal = Boolean.valueOf((String) val).booleanValue();
               enableTcpNoDelay = bVal;
               log.debug("Setting SocketClientInvoker::enableTcpNoDelay to: " + enableTcpNoDelay);
            }
            catch(Exception e)
            {
               log.warn("Could not convert " + TCP_NODELAY_FLAG + " value of " + val + " to a boolean value.");
            }
         }
         val = params.get(MAX_POOL_SIZE_FLAG);
         if(val != null)
         {
            try
            {
               int nVal = Integer.valueOf((String) val).intValue();
               maxPoolSize = nVal;
               log.debug("Setting SocketClientInvoker::maxPoolSize to: " + maxPoolSize);
            }
            catch(Exception e)
            {
               log.warn("Could not convert " + MAX_POOL_SIZE_FLAG + " value of " + val + " to a int value.");
            }
         }
         val = params.get(SO_TIMEOUT_FLAG);
         if(val != null)
         {
            try
            {
               int nVal = Integer.valueOf((String) val).intValue();
               timeout = nVal;
               log.debug("Setting SocketClientInvoker::timeout to: " + timeout);
            }
            catch(Exception e)
            {
               log.warn("Could not convert " + SO_TIMEOUT_FLAG + " value of " + val + " to a int value.");
            }
         }
      }
   }

   /** Calls {@code disconnect()} when this invoker is garbage collected. */
   protected void finalize() throws Throwable
   {
      disconnect();
      super.finalize();
   }

   /** Connection hook: binds this invoker to the shared pool for its server address. */
   protected synchronized void handleConnect()
         throws ConnectionFailedException
   {
      initPool();
   }

   /**
    * Disconnection hook: closes the idle sockets in the pools.
    * NOTE(review): this calls clearPools(), which empties the pools of EVERY server address,
    * not just this invoker's. Confirm whether clearPool(address) was intended.
    */
   protected synchronized void handleDisconnect()
   {
      clearPools();
   }

   /** @return the serializable marshaller data type used when none is specified */
   protected String getDefaultDataType()
   {
      return SerializableMarshaller.DATATYPE;
   }

   /**
    * Sends one invocation over a pooled or new socket and reads the response.
    * <p>
    * On success the socket goes back to the pool, or is closed if the pool is full. On any
    * marshalling failure the socket is closed and not reused.
    *
    * @return the unmarshalled response
    * @throws ClassNotFoundException if a class in the response cannot be loaded
    * @throws IOException a {@link ConnectException} wrapping any other marshalling failure
    */
   protected Object transport(String sessionId, Object invocation, Map metadata,
                              Marshaller marshaller, UnMarshaller unmarshaller)
         throws IOException, ConnectionFailedException, ClassNotFoundException
   {
      Object response = null;
      long start = System.currentTimeMillis();
      ClientSocket socket = null;
      try
      {
         socket = getConnection();
      }
      catch(Exception e)
      {
         throw new CannotConnectException("Can not get connection to server. Problem establishing socket connection.", e);
      }
      long end = System.currentTimeMillis() - start;
      getSocketTime += end;
      try
      {
         // Restart the clock so writeTime measures only the write, not connection acquisition.
         start = System.currentTimeMillis();
         marshaller.write(invocation, socket.oout);
         end = System.currentTimeMillis() - start;
         writeTime += end;
         start = System.currentTimeMillis();
         response = unmarshaller.read(socket.oin, null);
         end = System.currentTimeMillis() - start;
         readTime += end;
      }
      catch(Exception ex)
      {
         // The stream state is unknown after a failure, so never return this socket to the pool.
         try
         {
            socket.socket.close();
         }
         catch(Exception ignored)
         {
            // best-effort close
         }
         log.error("Got marshalling exception, exiting", ex);
         if(ex instanceof ClassNotFoundException)
         {
            log.error("Error loading classes from remote call result.", ex);
            throw (ClassNotFoundException) ex;
         }
         throw new ConnectException("Failed to communicate. Problem during marshalling/unmarshalling", ex);
      }
      synchronized(pool)
      {
         if(pool.size() < maxPoolSize)
         {
            pool.add(socket);
         }
         else
         {
            try
            {
               socket.socket.close();
            }
            catch(Exception ignored)
            {
               // best-effort close of a surplus socket
            }
         }
      }
      if(log.isDebugEnabled())
      {
         log.debug("Response: " + response);
      }
      return response;
   }

   /**
    * Closes and removes every idle socket in the pool registered for {@code sa}.
    * Does nothing if no pool is registered. This is best-effort: close failures are ignored.
    *
    * @param sa the server address whose pool should be emptied
    */
   public static void clearPool(ServerAddress sa)
   {
      try
      {
         LinkedList thepool;
         // connectionPools is a plain HashMap written under this monitor by initPool(), so read
         // it under the same lock (reentrant when called from clearPools()).
         synchronized(connectionPools)
         {
            thepool = (LinkedList) connectionPools.get(sa);
         }
         if(thepool == null)
         {
            return;
         }
         synchronized(thepool)
         {
            int size = thepool.size();
            for(int i = 0; i < size; i++)
            {
               ClientSocket socket = (ClientSocket) thepool.removeFirst();
               try
               {
                  socket.socket.close();
                  socket.socket = null;
               }
               catch(Exception ignored)
               {
                  // best-effort close
               }
            }
         }
      }
      catch(Exception ex)
      {
         // Deliberately swallowed: clearing pools is best-effort cleanup.
      }
   }

   /** Empties every pool in {@link #connectionPools} (the pools stay registered). */
   public static void clearPools()
   {
      synchronized(connectionPools)
      {
         Iterator it = connectionPools.keySet().iterator();
         while(it.hasNext())
         {
            ServerAddress sa = (ServerAddress) it.next();
            clearPool(sa);
         }
      }
   }

   /**
    * Looks up the shared pool for {@link #address}, creating and registering it if absent.
    * The result is stored in {@link #pool}.
    */
   protected void initPool()
   {
      synchronized(connectionPools)
      {
         pool = (LinkedList) connectionPools.get(address);
         if(pool == null)
         {
            pool = new LinkedList();
            connectionPools.put(address, pool);
         }
      }
   }

   /**
    * Sets how many attempts {@link #getConnection()} makes to open a new socket.
    *
    * @param numberOfRetries attempt count; values below 1 reset it to {@link #MAX_RETRIES}
    */
   public void setNumberOfRetries(int numberOfRetries)
   {
      if(numberOfRetries < 1)
      {
         this.numberOfRetries = MAX_RETRIES;
      }
      else
      {
         this.numberOfRetries = numberOfRetries;
      }
   }

   /** @return the number of attempts made to open a new socket */
   public int getNumberOfRetries()
   {
      return numberOfRetries;
   }

   /**
    * Returns a live pooled connection if one exists; otherwise opens a new socket.
    * A new socket gets up to {@link #numberOfRetries} attempts, with the pool re-checked
    * before each attempt.
    *
    * @return a ready-to-use client socket
    * @throws Exception the last connect failure once all attempts are used
    */
   protected ClientSocket getConnection() throws Exception
   {
      Socket socket = null;
      for(int i = 0; i < numberOfRetries; i++)
      {
         synchronized(pool)
         {
            if(pool.size() > 0)
            {
               ClientSocket pooled = getPooledConnection();
               if(pooled != null)
               {
                  usedPooled++;
                  return pooled;
               }
            }
         }
         try
         {
            socket = new Socket(address.address, address.port);
            break;
         }
         catch(Exception ex)
         {
            // Bound retries by the configured count rather than the MAX_RETRIES default.
            // Otherwise a smaller count left the loop with a null socket (NPE below), and a
            // larger count was silently capped.
            if(i + 1 < numberOfRetries)
            {
               Thread.sleep(1);
               continue;
            }
            throw ex;
         }
      }
      if(socket == null)
      {
         // Only reachable if numberOfRetries was set below 1 directly on the field.
         throw new IOException("Could not open socket to " + address.address + ":" + address.port
               + " (numberOfRetries=" + numberOfRetries + ")");
      }
      try
      {
         socket.setTcpNoDelay(address.enableTcpNoDelay);
         return new ClientSocket(socket, address.timeout);
      }
      catch(Exception ex)
      {
         // Don't leak the freshly opened socket if configuring or wrapping it fails.
         try
         {
            socket.close();
         }
         catch(Exception ignored)
         {
            // best-effort close
         }
         throw ex;
      }
   }

   /**
    * Removes idle sockets from {@link #pool} until one passes a liveness check.
    * The check writes one ACK byte and reads one byte back. Sockets that fail are closed
    * and discarded.
    * <p>
    * Does no locking itself. {@link #getConnection()} calls it while holding {@code pool}'s
    * monitor. Note that the liveness read blocks (up to the socket's SO_TIMEOUT) with that
    * lock held.
    *
    * @return a live pooled socket, or {@code null} if the pool emptied without finding one
    */
   protected ClientSocket getPooledConnection()
   {
      ClientSocket socket = null;
      while(pool.size() > 0)
      {
         socket = (ClientSocket) pool.removeFirst();
         try
         {
            final byte ACK = 1;
            socket.oout.reset();
            socket.oout.writeByte(ACK);
            socket.oout.flush();
            socket.oin.readByte();
            return socket;
         }
         catch(Exception ex)
         {
            try
            {
               socket.socket.close();
            }
            catch(Exception ignored)
            {
               // best-effort close of a dead socket
            }
         }
      }
      return null;
   }

   /**
    * A connected socket plus the buffered object streams layered on it.
    * The read timeout is applied at construction.
    */
   protected static class ClientSocket
   {
      public BufferedOutputStream out;
      public BufferedInputStream in;
      public ObjectOutputStream oout;
      public ObjectInputStream oin;
      public Socket socket;
      public int timeout;

      /**
       * Wraps {@code socket}, sets its SO_TIMEOUT and creates the object streams.
       *
       * @param socket  a connected socket
       * @param timeout read timeout in milliseconds
       * @throws Exception if the socket streams cannot be obtained or initialized
       */
      public ClientSocket(Socket socket, int timeout) throws Exception
      {
         this.socket = socket;
         socket.setSoTimeout(timeout);
         this.timeout = timeout;
         out = new BufferedOutputStream(socket.getOutputStream());
         in = new BufferedInputStream(socket.getInputStream());
         oout = new ObjectOutputStream(out);
         oin = new ObjectInputStreamWithClassLoader(in, null);
      }

      /** Closes the underlying socket, if still set, when this wrapper is garbage collected. */
      protected void finalize()
      {
         if(socket != null)
         {
            try
            {
               socket.close();
            }
            catch(Exception ignored)
            {
               // best-effort close
            }
         }
      }
   }

   /** @return the resolved host address string this invoker connects to */
   public String getServerHostName() throws Exception
   {
      return address.address;
   }
}