package org.jboss.remoting.transport.socket;
import org.jboss.remoting.InvokerLocator;
import org.jboss.remoting.ServerInvoker;
import org.jboss.remoting.marshal.serializable.SerializableMarshaller;
import org.jboss.remoting.transport.PortUtil;
import org.jboss.util.propertyeditor.PropertyEditors;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.LinkedList;
import java.util.Map;
import java.util.Properties;
/**
 * Server-side invoker for the socket transport.
 * <p>
 * On {@link #start()} it binds a {@link ServerSocket} and starts {@code numAcceptThreads} accept
 * threads. Each accept thread runs {@link #run()}. Every accepted connection is handed to a
 * {@code ServerThread}. The thread is reused from {@code threadpool} when one is idle. Otherwise a new
 * one is created while {@code clientpool} holds fewer than {@code maxPoolSize} entries.
 */
public class SocketServerInvoker extends ServerInvoker implements Runnable, SocketServerInvokerMBean
{
   private InetAddress addr;
   private int port;

   // NOTE(review): not referenced in this class; kept package-visible in case other classes in this package use it.
   static int clientCount = 0;

   private static final int BACKLOG_DEFAULT = 200;
   private static final int MAX_POOL_SIZE_DEFAULT = 300;

   protected ServerSocket serverSocket = null;

   /**
    * True while the invoker is started. Volatile because the accept threads poll it in {@link #run()}
    * without holding this invoker's monitor, while {@link #stop()} clears it from another thread.
    */
   protected volatile boolean running = false;

   protected int backlog = BACKLOG_DEFAULT;
   protected Thread[] acceptThreads;
   protected int numAcceptThreads = 1;
   protected int maxPoolSize = MAX_POOL_SIZE_DEFAULT;
   protected LRUPool clientpool;
   protected LinkedList threadpool;
   protected int timeout = 60000;
   protected boolean trace = false;
   protected int clientConnectPort = 0;
   protected String clientConnectAddress = null;
   protected String serverBindAddress = null;
   protected int serverBindPort = 0;

   /**
    * Creates an invoker for the given locator and runs {@link #setup()}.
    *
    * @param locator the locator this invoker serves
    * @throws RuntimeException if {@link #setup()} fails; the original exception is kept as the cause
    */
   public SocketServerInvoker(InvokerLocator locator)
   {
      super(locator);
      try
      {
         setup();
      }
      catch(Exception ex)
      {
         // Keep the cause so the original stack trace is not lost.
         throw new RuntimeException(ex.getMessage(), ex);
      }
   }

   /**
    * Creates an invoker for the given locator and configuration, then runs {@link #setup()}.
    *
    * @param locator       the locator this invoker serves
    * @param configuration configuration entries, mapped onto this bean's properties by {@link #setup()}
    * @throws RuntimeException if {@link #setup()} fails; the original exception is kept as the cause
    */
   public SocketServerInvoker(InvokerLocator locator, Map configuration)
   {
      super(locator, configuration);
      try
      {
         setup();
      }
      catch(Exception ex)
      {
         // Keep the cause so the original stack trace is not lost.
         throw new RuntimeException(ex.getMessage(), ex);
      }
   }

   /**
    * @return the address resolved from the locator's host in {@link #setup()}
    */
   public InetAddress getAddress()
   {
      return addr;
   }

   /**
    * @return the locator port, or the free port chosen in {@link #setup()} if the locator had none
    */
   public int getPort()
   {
      return port;
   }

   /**
    * Applies the configuration and works out the server bind address and port.
    * <p>
    * The steps are:
    * <ul>
    * <li>The configuration is copied into a {@link Properties} and mapped onto this object's JavaBean
    * properties through {@code PropertyEditors.mapJavaBeanProperties}.</li>
    * <li>If the locator port is {@code <= 0}, a free port is chosen and the locator is rebuilt with it.</li>
    * <li>{@code serverBindAddress} is taken, in order of preference, from the explicit property. Failing
    * that, it is the local host address when {@code clientConnectAddress} is configured. Otherwise it is
    * the locator's host address.</li>
    * <li>{@code serverBindPort} is taken, in order of preference, from the explicit property. Failing
    * that, it is a free port when {@code clientConnectPort} is configured. Otherwise it is the locator
    * port.</li>
    * </ul>
    * Note: this is called from the constructors, so an override runs before subclass field initializers.
    *
    * @throws Exception if the host cannot be resolved, a free port cannot be found, or
    *                   {@code serverBindPort} is not an integer
    */
   protected void setup()
      throws Exception
   {
      Properties props = new Properties();
      props.putAll(getConfiguration());
      PropertyEditors.mapJavaBeanProperties(this, props, false);

      this.addr = InetAddress.getByName(locator.getHost());
      this.port = locator.getPort();
      if(this.port <= 0)
      {
         this.port = PortUtil.findFreePort();
         this.locator = new InvokerLocator(locator.getProtocol(), locator.getHost(), this.port, locator.getPath(), locator.getParameters());
      }

      if(props.getProperty("serverBindAddress") != null)
      {
         serverBindAddress = props.getProperty("serverBindAddress");
      }
      else if(props.getProperty("clientConnectAddress") != null)
      {
         // Clients are told to connect elsewhere (e.g. through NAT), so bind to this host's own address.
         serverBindAddress = InetAddress.getLocalHost().getHostAddress();
      }
      else
      {
         serverBindAddress = addr.getHostAddress();
      }

      if(props.getProperty("serverBindPort") != null)
      {
         serverBindPort = Integer.parseInt(props.getProperty("serverBindPort"));
      }
      else if(props.getProperty("clientConnectPort") != null)
      {
         serverBindPort = PortUtil.findFreePort();
      }
      else
      {
         serverBindPort = port;
      }
   }

   /**
    * Stops the invoker if it is still running when collected. {@code super.finalize()} always runs.
    */
   protected void finalize() throws Throwable
   {
      try
      {
         stop();
      }
      finally
      {
         super.finalize();
      }
   }

   /**
    * Binds the server socket and starts the accept threads, unless already running. Then delegates to
    * {@code super.start()}.
    * <p>
    * {@code running} is set only after the socket is bound. If binding or address resolution fails,
    * the invoker stays stopped: {@link #stop()} remains safe to call and {@code start()} can be retried.
    *
    * @throws IOException if the bind address cannot be resolved or the server socket cannot be bound
    */
   public synchronized void start() throws IOException
   {
      trace = log.isTraceEnabled();
      if(!running)
      {
         InetAddress bindAddress =
               (serverBindAddress == null || serverBindAddress.length() == 0)
               ? null
               : InetAddress.getByName(serverBindAddress);
         clientConnectAddress =
               (clientConnectAddress == null || clientConnectAddress.length() == 0)
               ? InetAddress.getLocalHost().getHostName()
               : clientConnectAddress;
         if(maxPoolSize <= 0)
         {
            maxPoolSize = MAX_POOL_SIZE_DEFAULT;
         }
         // NOTE(review): meaning of the first LRUPool argument (2) is defined by LRUPool; confirm there.
         clientpool = new LRUPool(2, maxPoolSize);
         clientpool.create();
         threadpool = new LinkedList();
         try
         {
            serverSocket = new ServerSocket(serverBindPort, backlog, bindAddress);
         }
         catch(IOException e)
         {
            log.error("Error starting ServerSocket. Bind port: " + serverBindPort + ", bind address: " + bindAddress);
            throw e;
         }
         // Port 0 means "any free port"; record the port actually bound.
         serverBindPort = serverSocket.getLocalPort();
         clientConnectPort = (clientConnectPort == 0) ? serverSocket.getLocalPort() : clientConnectPort;

         // Mark running only now that the socket is bound, and before the accept threads check the flag.
         running = true;
         acceptThreads = new Thread[numAcceptThreads];
         for(int i = 0; i < numAcceptThreads; i++)
         {
            String name = "SocketServerInvoker#" + i + "-" + serverBindPort;
            acceptThreads[i] = new Thread(this, name);
            acceptThreads[i].start();
         }
      }
      super.start();
   }

   /**
    * Stops the invoker if running, then delegates to {@code super.stop()}.
    * <p>
    * The shutdown proceeds in four steps:
    * <ol>
    * <li>Clears {@code running} and interrupts the accept threads.</li>
    * <li>Flushes {@code clientpool}.</li>
    * <li>Shuts down every idle {@code ServerThread} in {@code threadpool}.</li>
    * <li>Closes the server socket.</li>
    * </ol>
    */
   public synchronized void stop()
   {
      if(running)
      {
         running = false;
         // Prevents run() from creating new ServerThreads while shutting down.
         // NOTE(review): this also discards a configured maxPoolSize; start() then falls back to
         // MAX_POOL_SIZE_DEFAULT on restart.
         maxPoolSize = 0;
         for(int i = 0; i < acceptThreads.length; i++)
         {
            try
            {
               // Wakes a thread blocked in clientpool.wait(); a thread blocked in accept() is released
               // by closing the server socket below.
               acceptThreads[i].interrupt();
            }
            catch(Exception ignored)
            {
               // Best effort: failing to interrupt one thread must not prevent the rest of shutdown.
            }
         }
         clientpool.flush();

         // Drain every idle thread under the same lock run() uses. The old index loop skipped half of them.
         // shutdown() is called outside the lock so it cannot contend with threads returning to the pool.
         LinkedList idleThreads;
         synchronized(threadpool)
         {
            idleThreads = new LinkedList(threadpool);
            threadpool.clear();
         }
         while(!idleThreads.isEmpty())
         {
            ServerThread thread = (ServerThread) idleThreads.removeFirst();
            thread.shutdown();
         }

         try
         {
            serverSocket.close();
         }
         catch(Exception ignored)
         {
            // Best effort: the invoker is already marked stopped.
         }
      }
      super.stop();
   }

   public String getMBeanObjectName()
   {
      return "jboss.remoting:service=invoker,transport=socket";
   }

   /**
    * @return the socket timeout passed to each ServerThread
    */
   public int getSocketTimeout()
   {
      return timeout;
   }

   public void setSocketTimeout(int time)
   {
      this.timeout = time;
   }

   /**
    * @return the number of idle ServerThreads, or 0 before {@link #start()} has created the pool
    */
   public int getCurrentThreadPoolSize()
   {
      LinkedList pool = threadpool;
      return (pool == null) ? 0 : pool.size();
   }

   /**
    * @return the number of entries in the client pool, or 0 before {@link #start()} has created it
    */
   public int getCurrentClientPoolSize()
   {
      LRUPool pool = clientpool;
      return (pool == null) ? 0 : pool.size();
   }

   public String getClientConnectAddress()
   {
      return clientConnectAddress;
   }

   public void setClientConnectAddress(String clientConnectAddress)
   {
      this.clientConnectAddress = clientConnectAddress;
   }

   public int getNumAcceptThreads()
   {
      return numAcceptThreads;
   }

   public void setNumAcceptThreads(int size)
   {
      this.numAcceptThreads = size;
   }

   public int getMaxPoolSize()
   {
      return maxPoolSize;
   }

   public void setMaxPoolSize(int maxPoolSize)
   {
      this.maxPoolSize = maxPoolSize;
   }

   public String getServerBindAddress()
   {
      return serverBindAddress;
   }

   public void setServerBindAddress(String serverBindAddress)
   {
      this.serverBindAddress = serverBindAddress;
   }

   public int getServerBindPort()
   {
      return serverBindPort;
   }

   public void setServerBindPort(int serverBindPort)
   {
      this.serverBindPort = serverBindPort;
   }

   public int getBacklog()
   {
      return backlog;
   }

   /**
    * Sets the server socket backlog. Negative values reset it to the default of 200.
    */
   public void setBacklog(int backlog)
   {
      if(backlog < 0)
      {
         this.backlog = BACKLOG_DEFAULT;
      }
      else
      {
         this.backlog = backlog;
      }
   }

   /**
    * Accept loop executed by each accept thread while {@code running}.
    * <p>
    * Each accepted socket is given to an idle ServerThread from {@code threadpool} when one is
    * available. Otherwise a new ServerThread is created while {@code clientpool} has room. Otherwise
    * the loop evicts from {@code clientpool} and waits on it until notified. If anything fails before
    * the socket reaches a ServerThread, the socket is closed so it does not leak.
    */
   public void run()
   {
      while(running)
      {
         Socket socket = null;
         // Becomes true once the socket is passed on for processing; until then this loop owns it.
         boolean handedOff = false;
         try
         {
            socket = serverSocket.accept();
            if(trace)
            {
               log.trace("Accepted: " + socket);
            }
            ServerThread thread = null;
            boolean newThread = false;
            while(thread == null)
            {
               synchronized(threadpool)
               {
                  if(threadpool.size() > 0)
                  {
                     thread = (ServerThread) threadpool.removeFirst();
                  }
               }
               if(thread == null)
               {
                  synchronized(clientpool)
                  {
                     if(clientpool.size() < maxPoolSize)
                     {
                        thread = new ServerThread(socket, this, clientpool, threadpool, timeout);
                        newThread = true;
                     }
                     if(thread == null)
                     {
                        clientpool.evict();
                        if(trace)
                        {
                           log.trace("Waiting for a thread...");
                        }
                        clientpool.wait();
                        if(trace)
                        {
                           log.trace("Notified of available thread");
                        }
                     }
                  }
               }
            }
            synchronized(clientpool)
            {
               clientpool.insert(thread, thread);
            }
            handedOff = true;
            if(newThread)
            {
               if(trace)
               {
                  log.trace("Created a new thread, t=" + thread);
               }
               thread.start();
            }
            else
            {
               if(trace)
               {
                  log.trace("Reusing thread t=" + thread);
               }
               thread.wakeup(socket, timeout);
            }
         }
         catch(Throwable ex)
         {
            // After stop(), failures from the closed socket or interrupted wait are expected.
            if(running)
            {
               log.error("Failed to accept socket connection", ex);
            }
         }
         finally
         {
            if(!handedOff && socket != null)
            {
               try
               {
                  socket.close();
               }
               catch(IOException ignored)
               {
                  // Nothing more can be done for this connection.
               }
            }
         }
      }
   }

   public boolean isTransportBiDirectional()
   {
      return true;
   }

   protected String getDefaultDataType()
   {
      return SerializableMarshaller.DATATYPE;
   }
}