package org.jboss.mq.il.uil2;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.util.Iterator;
import javax.jms.JMSException;
import org.jboss.logging.Logger;
import org.jboss.mq.il.uil2.msgs.BaseMsg;
import org.jboss.util.stream.NotifyingBufferedInputStream;
import org.jboss.util.stream.NotifyingBufferedOutputStream;
import EDU.oswego.cs.dl.util.concurrent.ConcurrentHashMap;
import EDU.oswego.cs.dl.util.concurrent.LinkedQueue;
import EDU.oswego.cs.dl.util.concurrent.PooledExecutor;
import EDU.oswego.cs.dl.util.concurrent.SynchronizedBoolean;
import EDU.oswego.cs.dl.util.concurrent.ThreadFactory;
/**
 * Manages a single UIL2 socket connection.
 *
 * <p>{@link #start(ThreadGroup)} creates two daemon threads: a {@link ReadTask} that reads
 * {@link BaseMsg}s from an {@link ObjectInputStream} over the socket, and a {@link WriteTask}
 * that drains {@link #sendQueue} onto an {@link ObjectOutputStream}. An inbound message whose
 * key matches an entry in {@link #replyMap} completes a pending round-trip send; any other
 * inbound message is submitted to {@link #pool}.
 *
 * <p>The read/write lifecycle fields ({@code readState}/{@code writeState}, taking the values
 * {@code STOPPED}, {@code STARTED}, {@code STOPPING}) are only changed while holding the
 * {@link #running} monitor. When either task stops, it marks the manager as not running,
 * interrupts the other task if that one is still started, and closes its stream and the socket.
 */
public class SocketManager
{
    private static final Logger log = Logger.getLogger(SocketManager.class);

    /** Lifecycle state: the task is not running. */
    private static final int STOPPED = 0;
    /** Lifecycle state: the task has been started. */
    private static final int STARTED = 1;
    /** Lifecycle state: the task has been asked to stop / is shutting down. */
    private static final int STOPPING = 2;

    /** Counter used only to number read/write thread names; access it via {@link #nextTaskID()}. */
    private static int taskID = 0;

    private final Socket socket;
    private ObjectInputStream in;
    NotifyingBufferedInputStream bufferedInput;
    private ObjectOutputStream out;
    NotifyingBufferedOutputStream bufferedOutput;
    private Thread writeThread;
    private Thread readThread;
    /**
     * Executor for inbound (non-reply) messages. {@link #stop()} shuts it down and sets it to
     * null; volatile so the read thread observes that without holding the {@code running} lock.
     */
    volatile PooledExecutor pool;
    /** Guarded by {@link #running}. */
    private int readState = STOPPED;
    /** Guarded by {@link #running}. */
    private int writeState = STOPPED;
    /** True while the connection is usable; its monitor also guards readState/writeState. */
    private final SynchronizedBoolean running = new SynchronizedBoolean(false);
    /** Messages waiting to be written by the WriteTask. */
    private final LinkedQueue sendQueue;
    /** Round-trip messages awaiting a reply, keyed by the message itself. */
    private final ConcurrentHashMap replyMap;
    private SocketManagerHandler handler;
    /** Passed to the Notifying*Stream constructors when the tasks start. */
    private int bufferSize = 1;
    /** Passed to the Notifying*Stream constructors when the tasks start. */
    private int chunkSize = 0x40000000;
    /** Trace-level flag, sampled once at construction. */
    private final boolean trace;

    /**
     * Creates a manager for an already-connected socket. No threads are started until
     * {@link #start(ThreadGroup)} is called.
     *
     * @param s the connected socket
     * @throws IOException declared for callers; not thrown by this constructor body
     */
    public SocketManager(Socket s) throws IOException
    {
        socket = s;
        sendQueue = new LinkedQueue();
        replyMap = new ConcurrentHashMap();
        trace = log.isTraceEnabled();
    }

    /**
     * Returns the next value of the static thread-numbering counter.
     * Synchronized because managers may be started concurrently from different threads.
     */
    private static synchronized int nextTaskID()
    {
        return ++taskID;
    }

    /**
     * Creates the message pool (if not already set) and starts the read and write threads.
     *
     * @param tg thread group for the read and write threads
     */
    public void start(ThreadGroup tg)
    {
        if (trace)
            log.trace("start called", new Exception("Start stack trace"));

        InetAddress inetAddr = socket.getInetAddress();
        String ipAddress = (inetAddr != null) ? inetAddr.getHostAddress() : "<unknown>";
        ipAddress += ":" + socket.getPort();

        if (pool == null)
        {
            // Configure fully before publishing through the volatile field.
            PooledExecutor executor = new PooledExecutor(5);
            executor.setMinimumPoolSize(1);
            executor.setKeepAliveTime(1000 * 60);
            executor.runWhenBlocked();
            String id = "SocketManager.MsgPool@" +
                Integer.toHexString(System.identityHashCode(this))
                + " client=" + ipAddress;
            executor.setThreadFactory(new UILThreadFactory(id));
            pool = executor;
        }

        ReadTask readTask = new ReadTask();
        readThread = new Thread(tg, readTask, "UIL2.SocketManager.ReadTask#" + nextTaskID() + " client=" + ipAddress);
        readThread.setDaemon(true);

        WriteTask writeTask = new WriteTask();
        writeThread = new Thread(tg, writeTask, "UIL2.SocketManager.WriteTask#" + nextTaskID() + " client=" + ipAddress);
        writeThread.setDaemon(true);

        synchronized (running)
        {
            readState = STARTED;
            writeState = STARTED;
            running.set(true);
        }
        readThread.start();
        writeThread.start();
    }

    /**
     * Requests shutdown. Interrupts whichever of the read/write threads is still started,
     * marks the manager as not running, and shuts down and discards the message pool.
     */
    public void stop()
    {
        synchronized (running)
        {
            if (readState == STARTED)
            {
                readState = STOPPING;
                readThread.interrupt();
            }
            if (writeState == STARTED)
            {
                writeState = STOPPING;
                writeThread.interrupt();
            }
            running.set(false);
            if (pool != null)
            {
                pool.shutdownNow();
                pool = null;
            }
        }
    }

    /**
     * Sets the message handler and registers it as the stream listener on any buffered
     * streams that already exist.
     */
    public void setHandler(SocketManagerHandler handler)
    {
        this.handler = handler;
        if (bufferedInput != null)
            bufferedInput.setStreamListener(handler);
        if (bufferedOutput != null)
            bufferedOutput.setStreamListener(handler);
    }

    /** Sets the buffer size used when the read/write streams are created in start(). */
    public void setBufferSize(int size)
    {
        this.bufferSize = size;
    }

    /** Sets the chunk size used when the read/write streams are created in start(). */
    public void setChunkSize(int size)
    {
        this.chunkSize = size;
    }

    /**
     * Sends a message and blocks until the matching reply has been read (or the connection
     * fails). If an error was recorded on the message, it is rethrown.
     *
     * @throws Exception the error recorded on {@code msg}, or an IOException if not connected
     */
    public void sendMessage(BaseMsg msg) throws Exception
    {
        internalSendMessage(msg, true);
        if (msg.error != null)
        {
            if (trace)
                log.trace("sendMessage will throw error", msg.error);
            throw msg.error;
        }
    }

    /**
     * Sends a reply one-way (no wait). Calls {@code msg.trimReply()} first — presumably to
     * drop request data not needed in the reply; verify in BaseMsg.
     */
    public void sendReply(BaseMsg msg) throws Exception
    {
        msg.trimReply();
        internalSendMessage(msg, false);
    }

    /**
     * Enqueues {@code msg} for the write thread and, if {@code waitOnReply} is set, blocks
     * until the read thread (or a failing task) notifies on {@code msg}.
     *
     * @throws IOException if the manager is not running
     * @throws InterruptedException if interrupted while queueing or waiting
     */
    private void internalSendMessage(BaseMsg msg, boolean waitOnReply) throws Exception
    {
        if (running.get() == false)
            throw new IOException("Client is not connected");

        if (waitOnReply)
        {
            // Holding msg's monitor from registration until wait() means the reader's
            // synchronized notify cannot run before this thread is waiting.
            synchronized (msg)
            {
                // Called for its side effect — presumably assigns a message ID; verify in BaseMsg.
                msg.getMsgID();
                if (trace)
                    log.trace("Begin internalSendMessage, round-trip msg=" + msg);
                replyMap.put(msg, msg);
                sendQueue.put(msg);
                // NOTE(review): wait() is not in a condition loop, so a spurious wakeup would
                // return early; BaseMsg exposes no visible "reply received" flag to loop on.
                // NOTE(review): if the read task has already run replyAll() between the
                // running check above and replyMap.put, nothing appears to notify this waiter.
                msg.wait();
            }
        }
        else
        {
            if (trace)
                log.trace("Begin internalSendMessage, one-way msg=" + msg);
            sendQueue.put(msg);
        }
        if (trace)
            log.trace("End internalSendMessage, msg=" + msg);
    }

    /**
     * Reads messages from the socket until an error or stop. A message matching a pending
     * round-trip send is read into that waiting message, which is then notified; any other
     * message is created from its type and submitted to the pool.
     */
    public class ReadTask implements Runnable
    {
        public void run()
        {
            int msgType = 0;
            log.debug("Begin ReadTask.run");
            try
            {
                bufferedInput = new NotifyingBufferedInputStream(socket.getInputStream(), bufferSize, chunkSize, handler);
                in = new ObjectInputStream(bufferedInput);
                log.debug("Created ObjectInputStream");
            }
            catch (IOException e)
            {
                handleStop("Failed to create ObjectInputStream", e);
                return;
            }

            while (true)
            {
                try
                {
                    msgType = in.readByte();
                    int msgID = in.readInt();
                    if (trace)
                        log.trace("Read msgType: " + BaseMsg.toString(msgType) + ", msgID: " + msgID);

                    // Look up (and claim) a pending round-trip send with the same key.
                    BaseMsg key = new BaseMsg(msgType, msgID);
                    BaseMsg msg = (BaseMsg) replyMap.remove(key);
                    if (msg == null)
                    {
                        msg = BaseMsg.createMsg(msgType);
                        msg.setMsgID(msgID);
                        msg.read(in);
                        if (trace)
                            log.trace("Read new msg: " + msg);
                        // Read the volatile field once: stop() may null it concurrently, and a
                        // second read after the check could otherwise throw an NPE.
                        PooledExecutor executor = pool;
                        if (executor == null)
                            break;
                        msg.setHandler(this);
                        executor.execute(msg);
                    }
                    else
                    {
                        if (trace)
                            log.trace("Found replyMap msg: " + msg);
                        msg.setMsgID(msgID);
                        try
                        {
                            msg.read(in);
                            if (trace)
                                log.trace("Read msg reply: " + msg);
                        }
                        catch (Throwable e)
                        {
                            msg.setError(e);
                            throw e;
                        }
                        finally
                        {
                            // Always release the sender blocked in internalSendMessage.
                            synchronized (msg)
                            {
                                msg.notify();
                            }
                        }
                    }
                }
                catch (ClassNotFoundException e)
                {
                    handleStop("Failed to read msgType:" + msgType, e);
                    break;
                }
                catch (IOException e)
                {
                    handleStop("Exiting on IOE", e);
                    break;
                }
                catch (InterruptedException e)
                {
                    handleStop("Exiting on interrupt", e);
                    break;
                }
                catch (Throwable e)
                {
                    handleStop("Exiting on unexpected error in read task", e);
                    break;
                }
            }
            log.debug("End ReadTask.run");
        }

        /**
         * Passes an inbound message to the handler. This is presumably invoked by the message
         * itself after {@code msg.setHandler(this)}; verify in BaseMsg. On failure, the error
         * is logged at a level chosen by its type, recorded on the message, and the message is
         * sent back one-way as the reply.
         */
        public void handleMsg(BaseMsg msg)
        {
            try
            {
                handler.handleMsg(msg);
            }
            catch (Throwable e)
            {
                if (e instanceof JMSException)
                    log.trace("Failed to handle: " + msg.toString(), e);
                else if (e instanceof RuntimeException || e instanceof Error)
                    log.error("Failed to handle: " + msg.toString(), e);
                else
                    log.debug("Failed to handle: " + msg.toString(), e);
                msg.setError(e);
                try
                {
                    internalSendMessage(msg, false);
                }
                catch (Exception ie)
                {
                    log.debug("Failed to send error reply", ie);
                }
            }
        }

        /**
         * Tears down the read side. It marks the manager as not running, fails all pending
         * round-trip sends with {@code e}, and reports to and closes the handler. It then
         * interrupts the write thread if that thread is still started, and closes the input
         * stream and socket.
         */
        private void handleStop(String error, Throwable e)
        {
            synchronized (running)
            {
                readState = STOPPING;
                running.set(false);
            }

            if (e instanceof IOException || e instanceof InterruptedException)
            {
                if (trace)
                    log.trace(error, e);
            }
            else
                log.debug(error, e);

            replyAll(e);
            if (handler != null)
            {
                handler.asynchFailure(error, e);
                handler.close();
            }

            synchronized (running)
            {
                readState = STOPPED;
                if (writeState == STARTED)
                {
                    writeState = STOPPING;
                    writeThread.interrupt();
                }
            }

            // 'in' is still null if creating the ObjectInputStream failed.
            if (in != null)
            {
                try
                {
                    in.close();
                }
                catch (Exception ignored)
                {
                    if (trace)
                        log.trace(ignored.getMessage(), ignored);
                }
            }
            try
            {
                socket.close();
            }
            catch (Exception ignored)
            {
                if (trace)
                    log.trace(ignored.getMessage(), ignored);
            }
        }

        /**
         * Records {@code e} on every pending round-trip message, notifies its waiter, and
         * removes it from the reply map.
         */
        private void replyAll(Throwable e)
        {
            // Clear this thread's interrupt status — presumably so a pending interrupt does
            // not disturb the notifications and handler callbacks that follow.
            Thread.interrupted();

            for (Iterator iterator = replyMap.keySet().iterator(); iterator.hasNext();)
            {
                BaseMsg msg = (BaseMsg) iterator.next();
                msg.setError(e);
                synchronized (msg)
                {
                    msg.notify();
                }
                iterator.remove();
            }
        }
    }

    /**
     * Takes messages from the send queue and writes them to the socket until an error or
     * interrupt.
     */
    public class WriteTask implements Runnable
    {
        public void run()
        {
            log.debug("Begin WriteTask.run");
            try
            {
                bufferedOutput =
                    new NotifyingBufferedOutputStream(socket.getOutputStream(), bufferSize, chunkSize, handler);
                out = new ObjectOutputStream(bufferedOutput);
                log.debug("Created ObjectOutputStream");
            }
            catch (IOException e)
            {
                handleStop(null, "Failed to create ObjectOutputStream", e);
                return;
            }

            while (true)
            {
                BaseMsg msg = null;
                try
                {
                    msg = (BaseMsg) sendQueue.take();
                    if (trace)
                        log.trace("Write msg: " + msg);
                    msg.write(out);
                    // reset() clears the stream's back-reference table so previously written
                    // objects are re-serialized rather than referenced (and not retained).
                    out.reset();
                    out.flush();
                }
                catch (InterruptedException e)
                {
                    handleStop(msg, "WriteTask was interrupted", e);
                    break;
                }
                catch (IOException e)
                {
                    handleStop(msg, "Exiting on IOE", e);
                    break;
                }
                catch (Throwable e)
                {
                    handleStop(msg, "Failed to write msgType:" + msg, e);
                    break;
                }
            }
            log.debug("End WriteTask.run");
        }

        /**
         * Tears down the write side. It marks the manager as not running, fails the in-flight
         * message (if any) and notifies its waiter, and interrupts the read thread if that
         * thread is still started. It then closes the output stream and socket.
         */
        private void handleStop(BaseMsg msg, String error, Throwable e)
        {
            synchronized (running)
            {
                writeState = STOPPING;
                running.set(false);
            }

            if (e instanceof InterruptedException || e instanceof IOException)
            {
                if (trace)
                    log.trace(error, e);
            }
            else
                log.debug(error, e);

            if (msg != null)
            {
                msg.setError(e);
                synchronized (msg)
                {
                    msg.notify();
                }
            }

            synchronized (running)
            {
                writeState = STOPPED;
                if (readState == STARTED)
                {
                    readState = STOPPING;
                    readThread.interrupt();
                }
            }

            // 'out' is still null if creating the ObjectOutputStream failed.
            if (out != null)
            {
                try
                {
                    out.close();
                }
                catch (Exception ignored)
                {
                    if (trace)
                        log.trace(ignored.getMessage(), ignored);
                }
            }
            try
            {
                socket.close();
            }
            catch (Exception ignored)
            {
                if (trace)
                    log.trace(ignored.getMessage(), ignored);
            }
        }
    }

    /** Thread factory for the message pool; numbers threads as "UIL2(id)#n". */
    static class UILThreadFactory implements ThreadFactory
    {
        private final String id;
        /** Guarded by this. */
        private int count;

        UILThreadFactory(String id)
        {
            this.id = id;
        }

        public Thread newThread(Runnable command)
        {
            // Capture the number while holding the lock so concurrent callers get distinct names.
            int threadNumber;
            synchronized (this)
            {
                threadNumber = ++count;
            }
            return new Thread(command, "UIL2(" + id + ")#" + threadNumber);
        }
    }
}