package org.jboss.util.threadpool;
import org.jboss.logging.Logger;
/**
 * A wrapper around a {@link Task} that tracks the task's lifecycle state and
 * forwards lifecycle callbacks (accepted, rejected, started, completed, stop)
 * to it.
 *
 * <p>Every write to the lifecycle state is made while holding an internal lock,
 * and threads blocked in {@link #waitForTask()} are notified on that lock
 * whenever the state leaves the "not yet started" states.
 */
public class BasicTaskWrapper implements TaskWrapper
{
   private static final Logger log = Logger.getLogger(BasicTaskWrapper.class);

   /** Initial state: the task has not been accepted yet. */
   public static final int TASK_NOT_ACCEPTED = 0;

   /** The task's {@code accepted} callback completed normally. */
   public static final int TASK_ACCEPTED = 1;

   /** {@link #run()} has begun executing the task. */
   public static final int TASK_STARTED = 2;

   /** {@link #run()} has finished executing the task. */
   public static final int TASK_COMPLETED = 3;

   /** The task was rejected. */
   public static final int TASK_REJECTED = -1;

   /** {@link #stopTask()} was called. */
   public static final int TASK_STOPPED = -2;

   /**
    * Current lifecycle state. Written only while holding {@link #stateLock};
    * volatile so that {@link #isComplete()} can read it without locking.
    */
   private volatile int state = TASK_NOT_ACCEPTED;

   /** Guards writes to {@link #state} and is the monitor waiters block on. */
   private final Object stateLock = new Object();

   /** The wrapped task. */
   private Task task;

   /** Cached {@code task.toString()}, used in log and exception messages. */
   private String taskString;

   /** Wall-clock time in milliseconds at which the task was set. */
   private long startTime;

   /** Start timeout copied from the task. */
   private long startTimeout;

   /** Completion timeout copied from the task. */
   private long completionTimeout;

   /** Priority copied from the task. */
   private int priority;

   /** Wait type copied from the task. */
   private int waitType;

   /**
    * The thread currently inside {@link #run()}, or null when no thread is.
    * Volatile because it is written by the running thread and read by
    * whichever thread calls {@link #stopTask()}.
    */
   private volatile Thread runThread;

   /**
    * Creates a wrapper with no task. No task is installed until
    * {@link #setTask(Task)} is called.
    */
   protected BasicTaskWrapper()
   {
   }

   /**
    * Creates a wrapper for the given task.
    *
    * @param task the task to wrap
    * @throws IllegalArgumentException if {@code task} is null
    */
   public BasicTaskWrapper(Task task)
   {
      setTask(task);
   }

   /**
    * @return the wait type captured from the task when it was set
    */
   public int getTaskWaitType()
   {
      return waitType;
   }

   /**
    * @return the priority captured from the task when it was set
    */
   public int getTaskPriority()
   {
      return priority;
   }

   /**
    * @return the start timeout captured from the task when it was set
    */
   public long getTaskStartTimeout()
   {
      return startTimeout;
   }

   /**
    * @return the completion timeout captured from the task when it was set
    */
   public long getTaskCompletionTimeout()
   {
      return completionTimeout;
   }

   /**
    * Invokes the task's {@code accepted} callback and records the outcome.
    *
    * <p>Does nothing unless the state is {@link #TASK_NOT_ACCEPTED}. The
    * callback runs outside the lock; afterwards the state becomes
    * {@link #TASK_ACCEPTED} or {@link #TASK_REJECTED}, but only if no other
    * thread changed it in the meantime, so that a concurrent stop or
    * rejection is not overwritten. Waiters are notified in either case.
    */
   public void acceptTask()
   {
      synchronized (stateLock)
      {
         if (state != TASK_NOT_ACCEPTED)
            return;
      }

      boolean accepted = taskAccepted();

      synchronized (stateLock)
      {
         // Re-check: the state may have changed while the callback was running
         if (state == TASK_NOT_ACCEPTED)
            state = accepted ? TASK_ACCEPTED : TASK_REJECTED;
         stateLock.notifyAll();
      }
   }

   /**
    * Marks the task as rejected, wakes any waiters and then invokes the
    * task's {@code rejected} callback.
    *
    * @param e the reason for the rejection, passed on to the task
    */
   public void rejectTask(RuntimeException e)
   {
      synchronized (stateLock)
      {
         state = TASK_REJECTED;
         stateLock.notifyAll();
      }
      taskRejected(e);
   }

   /**
    * @return true when the state is {@link #TASK_COMPLETED}
    */
   public boolean isComplete()
   {
      return state == TASK_COMPLETED;
   }

   /**
    * Stops the task.
    *
    * <p>The state always becomes {@link #TASK_STOPPED} and waiters are woken.
    * If the task had started, the executing thread (when known) is
    * interrupted and the task's {@code stop} callback is invoked. Otherwise,
    * if a thread is still inside {@link #run()} and its interrupt flag is
    * set, a forcible {@code Thread.stop()} is attempted.
    */
   public void stopTask()
   {
      boolean started;
      synchronized (stateLock)
      {
         started = (state == TASK_STARTED);
         state = TASK_STOPPED;
         // Release threads blocked in waitForTask(); they re-check the state
         stateLock.notifyAll();
      }

      // Read once so the null check and the use see the same thread
      Thread thread = runThread;
      if (started)
      {
         if (thread != null)
         {
            thread.interrupt();
         }
         taskStop();
      }
      else if (thread != null && thread.isInterrupted())
      {
         try
         {
            thread.stop();
         }
         catch (UnsupportedOperationException e)
         {
            // Thread.stop() is not supported on recent JVMs; keep the stop best-effort
            log.warn("Unable to forcibly stop thread for task: " + taskString, e);
         }
      }
   }

   /**
    * Blocks the caller according to the task's wait type.
    *
    * <p>For {@code Task.WAIT_FOR_START} this waits until the state is neither
    * {@link #TASK_NOT_ACCEPTED} nor {@link #TASK_ACCEPTED}. Interrupts do not
    * end the wait; the interrupt flag is restored before returning. Every
    * other wait type returns immediately.
    *
    * <p>NOTE(review): {@link #run()} notifies waiters on completion when the
    * wait type is {@code Task.WAIT_FOR_COMPLETE}, yet nothing here waits for
    * completion - confirm whether that is intended.
    */
   public void waitForTask()
   {
      switch (waitType)
      {
         case Task.WAIT_FOR_START:
         {
            boolean interrupted = false;
            synchronized (stateLock)
            {
               while (state == TASK_NOT_ACCEPTED || state == TASK_ACCEPTED)
               {
                  try
                  {
                     stateLock.wait();
                  }
                  catch (InterruptedException e)
                  {
                     // Keep waiting, but remember to restore the flag
                     interrupted = true;
                  }
               }
               if (interrupted)
                  Thread.currentThread().interrupt();
               return;
            }
         }
         default:
         {
            return;
         }
      }
   }

   /**
    * Executes the wrapped task on the calling thread.
    *
    * <p>The task is rejected with a {@code StartTimeoutException} when the
    * start timeout is positive and has already elapsed, and with a
    * {@code TaskStoppedException} when it was stopped before starting.
    * Otherwise the {@code started} callback is invoked, the task is
    * executed, and the {@code completed} callback receives any throwable
    * raised by the execution.
    */
   public void run()
   {
      this.runThread = Thread.currentThread();
      try
      {
         long runTime = getElapsedTime();
         if (startTimeout > 0L && runTime >= startTimeout)
         {
            synchronized (stateLock)
            {
               // Record the rejection so waitForTask() callers are released
               if (state != TASK_STOPPED)
                  state = TASK_REJECTED;
               stateLock.notifyAll();
            }
            taskRejected(new StartTimeoutException("Start Timeout exceeded for task " + taskString));
            return;
         }

         boolean stopped = false;
         synchronized (stateLock)
         {
            if (state == TASK_STOPPED)
            {
               stopped = true;
            }
            else
            {
               state = TASK_STARTED;
               taskStarted();
               if (waitType == Task.WAIT_FOR_START)
                  stateLock.notifyAll();
            }
         }
         if (stopped)
         {
            taskRejected(new TaskStoppedException("Task stopped for task " + taskString));
            return;
         }

         Throwable throwable = null;
         try
         {
            task.execute();
         }
         catch (Throwable t)
         {
            // Handed to the completed callback rather than propagated
            throwable = t;
         }
         taskCompleted(throwable);

         synchronized (stateLock)
         {
            state = TASK_COMPLETED;
            if (waitType == Task.WAIT_FOR_COMPLETE)
               stateLock.notifyAll();
         }
      }
      finally
      {
         // The thread no longer belongs to this task; a later stopTask()
         // must not interrupt or stop it
         this.runThread = null;
      }
   }

   /**
    * Installs the task and captures its string form, wait type, priority and
    * timeouts. The elapsed-time clock starts now.
    *
    * @param task the task to wrap
    * @throws IllegalArgumentException if {@code task} is null
    */
   protected void setTask(Task task)
   {
      if (task == null)
         throw new IllegalArgumentException("Null task");
      this.task = task;
      this.taskString = task.toString();
      this.startTime = System.currentTimeMillis();
      this.waitType = task.getWaitType();
      this.priority = task.getPriority();
      this.startTimeout = task.getStartTimeout();
      this.completionTimeout = task.getCompletionTimeout();
   }

   /**
    * Invokes the task's {@code accepted} callback with the elapsed time.
    *
    * @return true if the callback returned normally, false if it threw
    *         (the error is logged)
    */
   protected boolean taskAccepted()
   {
      try
      {
         task.accepted(getElapsedTime());
         return true;
      }
      catch (Throwable t)
      {
         log.warn("Unexpected error during 'accepted' for task: " + taskString, t);
         return false;
      }
   }

   /**
    * Invokes the task's {@code rejected} callback with the elapsed time and
    * the rejection reason.
    *
    * @param e the reason for the rejection
    * @return true if the callback returned normally, false if it threw
    *         (both the error and any original reason are logged)
    */
   protected boolean taskRejected(RuntimeException e)
   {
      try
      {
         task.rejected(getElapsedTime(), e);
         return true;
      }
      catch (Throwable t)
      {
         log.warn("Unexpected error during 'rejected' for task: " + taskString, t);
         if (e != null)
            log.warn("Original reason for rejection of task: " + taskString, e);
         return false;
      }
   }

   /**
    * Invokes the task's {@code started} callback with the elapsed time.
    *
    * @return true if the callback returned normally, false if it threw
    *         (the error is logged)
    */
   protected boolean taskStarted()
   {
      try
      {
         task.started(getElapsedTime());
         return true;
      }
      catch (Throwable t)
      {
         log.warn("Unexpected error during 'started' for task: " + taskString, t);
         return false;
      }
   }

   /**
    * Invokes the task's {@code completed} callback with the elapsed time and
    * the outcome of the execution.
    *
    * @param throwable the error raised by the execution, or null if none
    * @return true if the callback returned normally, false if it threw
    *         (both the error and any original execution error are logged)
    */
   protected boolean taskCompleted(Throwable throwable)
   {
      try
      {
         task.completed(getElapsedTime(), throwable);
         return true;
      }
      catch (Throwable t)
      {
         log.warn("Unexpected error during 'completed' for task: " + taskString, t);
         if (throwable != null)
            log.warn("Original error during 'run' for task: " + taskString, throwable);
         return false;
      }
   }

   /**
    * Invokes the task's {@code stop} callback.
    *
    * @return true if the callback returned normally, false if it threw
    *         (the error is logged)
    */
   protected boolean taskStop()
   {
      try
      {
         task.stop();
         return true;
      }
      catch (Throwable t)
      {
         log.warn("Unexpected error during 'stop' for task: " + taskString, t);
         return false;
      }
   }

   /**
    * @return milliseconds of wall-clock time since the task was set
    */
   protected long getElapsedTime()
   {
      return System.currentTimeMillis() - startTime;
   }
}