package org.jboss.web;
import java.io.BufferedReader;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.IOException;
import java.net.InetAddress;
import java.net.MalformedURLException;
import java.net.Socket;
import java.net.ServerSocket;
import java.net.UnknownHostException;
import java.net.URL;
import java.util.Properties;
import org.jboss.logging.Logger;
import org.jboss.util.StringPropertyReplacer;
import org.jboss.util.threadpool.BasicThreadPoolMBean;
import org.jboss.util.threadpool.BasicThreadPool;
import EDU.oswego.cs.dl.util.concurrent.ConcurrentReaderHashMap;
/**
 * A minimal HTTP/1.0 server that serves class files and resources out of
 * registered class loaders.
 *
 * <p>Each registered loader is stored under a key (see
 * {@link #getClassLoaderKey(ClassLoader)}). A request path of the form
 * {@code /<loaderKey>/<path>} is served from that loader, where the key ends
 * in {@code ']'}. A path without a loader key is served from the handling
 * thread's context class loader when {@link #getDownloadServerClasses()} is
 * true. {@link #addClassLoader(ClassLoader)} derives the loader's URL from the
 * {@code java.rmi.server.codebase} system property.</p>
 *
 * <p>Connections are accepted on thread-pool tasks. After accepting, a task
 * schedules a new task (via {@link #listen()}) to accept the next connection
 * before serving its own request.</p>
 */
public class WebServer
   implements Runnable
{
   private static final Logger log = Logger.getLogger(WebServer.class);
   /** File extension (without the dot) to MIME type; static, so shared by all instances. */
   private static final Properties mimeTypes = new Properties();

   private int port = 8083;
   private InetAddress bindAddress;
   private int backlog = 50;
   /** Loader key to registered ClassLoader. */
   private final ConcurrentReaderHashMap loaderMap = new ConcurrentReaderHashMap();
   /**
    * The listening socket, or null when stopped. Volatile because stop()
    * clears it while pool threads read it in run().
    */
   private volatile ServerSocket server = null;
   private boolean downloadServerClasses = true;
   private BasicThreadPoolMBean threadPool;

   /** Sets the port to listen on; takes effect on the next {@link #start()}. */
   public void setPort(int p)
   {
      port = p;
   }

   /** @return the port the server listens on (default 8083) */
   public int getPort()
   {
      return port;
   }

   /** @return the textual bind address, or null when bound to all interfaces */
   public String getBindAddress()
   {
      String address = null;
      if (bindAddress != null)
         address = bindAddress.getHostAddress();
      return address;
   }

   /** @return the host name of the bind address, or null when no bind address is set */
   public String getBindHostname()
   {
      if (bindAddress == null)
         return null;
      return bindAddress.getHostName();
   }

   /**
    * Sets the address to bind to. Property references in {@code host} are
    * expanded with StringPropertyReplacer before resolution. A null host
    * leaves the current address unchanged. If the host cannot be resolved, an
    * error is logged and the previous address is kept.
    *
    * @param host host name or literal address, possibly containing property references
    */
   public void setBindAddress(String host)
   {
      try
      {
         if (host != null)
         {
            String h = StringPropertyReplacer.replaceProperties(host);
            bindAddress = InetAddress.getByName(h);
         }
      }
      catch (UnknownHostException e)
      {
         String msg = "Invalid host address specified: " + host;
         log.error(msg, e);
      }
   }

   /** @return the listen backlog passed to the server socket */
   public int getBacklog()
   {
      return backlog;
   }

   /**
    * Sets the listen backlog.
    *
    * @param backlog requested backlog; non-positive values fall back to 50
    */
   public void setBacklog(int backlog)
   {
      this.backlog = (backlog <= 0) ? 50 : backlog;
   }

   /** @return whether paths without a loader key may be served from the context class loader */
   public boolean getDownloadServerClasses()
   {
      return downloadServerClasses;
   }

   /** Enables or disables serving paths without a loader key from the context class loader. */
   public void setDownloadServerClasses(boolean flag)
   {
      downloadServerClasses = flag;
   }

   /** @return the thread pool used to accept and serve connections */
   public BasicThreadPoolMBean getThreadPool()
   {
      return threadPool;
   }

   /** Sets the thread pool; if none is set, {@link #start()} creates one. */
   public void setThreadPool(BasicThreadPoolMBean threadPool)
   {
      this.threadPool = threadPool;
   }

   /**
    * Registers a MIME type for a file extension. The mapping table is static,
    * so the registration affects every WebServer instance.
    *
    * @param extension file extension without the leading dot
    * @param type      MIME type to send in the Content-Type header
    */
   public void addMimeType(String extension, String type)
   {
      mimeTypes.put(extension, type);
   }

   /**
    * Opens the server socket and schedules the first accept task. If no
    * thread pool has been set, one named "ClassLoadingPool" is created.
    *
    * @throws Exception wrapping the BindException if the socket cannot be
    *         bound, or any other IOException from opening the socket
    */
   public void start() throws Exception
   {
      if (threadPool == null)
         threadPool = new BasicThreadPool("ClassLoadingPool");
      try
      {
         server = new ServerSocket(port, backlog, bindAddress);
      }
      catch (java.net.BindException be)
      {
         throw new Exception("Port " + port + " already in use.", be);
      }
      if (log.isDebugEnabled())
         log.debug("Started server: " + server);
      listen();
   }

   /**
    * Closes the server socket, which makes the pending accept() fail and ends
    * the accept loop. Safe to call when the server is not running. The thread
    * pool is left untouched.
    */
   public void stop()
   {
      ServerSocket srv = server;
      server = null;
      if (srv == null)
         return;
      try
      {
         srv.close();
      }
      catch (IOException e)
      {
         log.debug("Failed to close server socket", e);
      }
   }

   /**
    * Registers a class loader so its classes and resources can be served.
    *
    * @param cl the loader to register
    * @return the codebase URL {@code <java.rmi.server.codebase>/<key>/}, or
    *         null if that property is unset or yields a malformed URL
    */
   public URL addClassLoader(ClassLoader cl)
   {
      String key = resolveLoaderKey(cl);
      loaderMap.put(key, cl);
      URL loaderURL = null;
      String codebase = System.getProperty("java.rmi.server.codebase");
      if (codebase != null)
      {
         if (codebase.endsWith("/") == false)
            codebase += '/';
         codebase += key;
         codebase += '/';
         try
         {
            loaderURL = new URL(codebase);
         }
         catch (MalformedURLException e)
         {
            log.error("invalid url", e);
         }
      }
      log.trace("Added ClassLoader: " + cl + " URL: " + loaderURL);
      return loaderURL;
   }

   /**
    * Unregisters a class loader. The key is computed exactly as in
    * {@link #addClassLoader(ClassLoader)}, so a WebClassLoader registered
    * under its own key is actually removed.
    *
    * @param cl the loader to unregister
    */
   public void removeClassLoader(ClassLoader cl)
   {
      String key = resolveLoaderKey(cl);
      loaderMap.remove(key);
   }

   /**
    * Accepts one connection, schedules the next accept, then serves the
    * request on the current thread. Does nothing if the server is stopped.
    */
   public void run()
   {
      // Read the field once: stop() may clear it concurrently.
      ServerSocket srv = server;
      if (srv == null)
         return;
      Socket socket;
      try
      {
         socket = srv.accept();
      }
      catch (IOException e)
      {
         // accept() fails once stop() closes the socket; only report it while still running.
         if (server != null)
            log.error("Failed to accept connection", e);
         return;
      }
      try
      {
         // Let another pool task accept the next connection while this one is served.
         // Done inside the try so the accepted socket is closed even if scheduling fails.
         listen();
         DataOutputStream out = new DataOutputStream(socket.getOutputStream());
         try
         {
            handleRequest(socket, out);
         }
         catch (Throwable e)
         {
            writeErrorResponse(out, e);
         }
      }
      catch (IOException ex)
      {
         log.error("error writting response", ex);
      }
      finally
      {
         try
         {
            socket.close();
         }
         catch (IOException ignored)
         {
            // Nothing useful can be done if closing the connection fails.
         }
      }
   }

   /**
    * Reads one request, resolves the class or resource it names and writes
    * the HTTP response.
    *
    * @throws Exception if the request cannot be parsed or the content cannot be
    *         loaded; the caller turns this into a 400 response
    */
   private void handleRequest(Socket socket, DataOutputStream out) throws Exception
   {
      BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
      String rawPath = getPath(in);

      // Path layout is "<loaderKey>/<filePath>", the key ending in ']'; with no ']' the key is "".
      int endOfKey = rawPath.indexOf(']');
      int pathStart = endOfKey + 2;
      // A key without a trailing "/path" yields an empty file path instead of an exception.
      String filePath = (pathStart <= rawPath.length()) ? rawPath.substring(pathStart) : "";
      String loaderKey = rawPath.substring(0, endOfKey + 1);
      log.trace("loaderKey = " + loaderKey);
      log.trace("filePath = " + filePath);

      ClassLoader loader = (ClassLoader) loaderMap.get(loaderKey);
      if (loader == null && rawPath.indexOf('[') < 0 && downloadServerClasses)
      {
         // No loader key: serve the whole path from this thread's context class loader.
         filePath = rawPath;
         log.trace("No loader, reset filePath = " + filePath);
         loader = Thread.currentThread().getContextClassLoader();
      }
      log.trace("loader = " + loader);

      String httpCode = "200 OK";
      byte[] bytes = {};
      if (loader != null && filePath.endsWith(".class"))
      {
         bytes = getClassBytes(loader, filePath);
      }
      else if (loader != null && filePath.length() > 0 && downloadServerClasses)
      {
         log.trace("loading resource = " + filePath);
         URL resourceURL = loader.getResource(filePath);
         if (resourceURL == null)
         {
            httpCode = "404 Resource not found:" + filePath;
         }
         else
         {
            log.trace("resourceURL = " + resourceURL);
            bytes = getBytes(resourceURL);
         }
      }
      else
      {
         httpCode = "404 Not Found";
      }

      try
      {
         log.trace("HTTP code=" + httpCode + ", Content-Length: " + bytes.length);
         out.writeBytes("HTTP/1.0 " + httpCode + "\r\n");
         out.writeBytes("Content-Length: " + bytes.length + "\r\n");
         out.writeBytes("Content-Type: " + getMimeType(filePath));
         out.writeBytes("\r\n\r\n");
         out.write(bytes);
         out.flush();
      }
      catch (IOException ignored)
      {
         // The client went away mid-response. Appending a 400 to a partially
         // written response would only corrupt it, so give up quietly.
      }
   }

   /**
    * Loads the class named by a ".class" path and returns its bytes.
    *
    * <p>If the class's code source has no location, the bytes come from its
    * WebClassLoader. Otherwise they are read from the jar entry or from the
    * path resolved against the code source location. Locations whose file
    * contains "/org_jboss_aop_proxy$" are read as-is.</p>
    *
    * @throws Exception "Class not found" if no bytes are available, or any
    *         loading/IO failure
    */
   private byte[] getClassBytes(ClassLoader loader, String filePath) throws Exception
   {
      String className = filePath.substring(0, filePath.length() - 6).replace('/', '.');
      log.trace("loading className = " + className);
      Class clazz = loader.loadClass(className);
      // Bootstrap and some dynamically defined classes have no CodeSource.
      java.security.CodeSource source = clazz.getProtectionDomain().getCodeSource();
      URL clazzUrl = (source == null) ? null : source.getLocation();
      log.trace("clazzUrl = " + clazzUrl);
      if (clazzUrl == null)
      {
         ClassLoader owner = clazz.getClassLoader();
         byte[] bytes = null;
         if (owner instanceof WebClassLoader)
            bytes = ((WebClassLoader) owner).getBytes(clazz);
         if (bytes == null)
            throw new Exception("Class not found: " + className);
         return bytes;
      }
      if (clazzUrl.getFile().endsWith(".jar"))
      {
         clazzUrl = new URL("jar:" + clazzUrl + "!/" + filePath);
      }
      else if (clazzUrl.getFile().indexOf("/org_jboss_aop_proxy$") < 0)
      {
         clazzUrl = new URL(clazzUrl, filePath);
      }
      log.trace("new clazzUrl: " + clazzUrl);
      return getBytes(clazzUrl);
   }

   /**
    * Best-effort 400 response for a request that failed before a response
    * was written. The exception message becomes the status reason phrase.
    */
   private void writeErrorResponse(DataOutputStream out, Throwable e)
   {
      try
      {
         log.trace("HTTP code=400 " + e.getMessage(), e);
         out.writeBytes("HTTP/1.0 400 " + e.getMessage() + "\r\n");
         out.writeBytes("Content-Type: text/html\r\n\r\n");
         out.flush();
      }
      catch (IOException ignored)
      {
         // The client is unreachable; there is nobody left to report to.
      }
   }

   /** Returns the key under which {@code cl} is (or would be) stored in loaderMap. */
   private String resolveLoaderKey(ClassLoader cl)
   {
      if (cl instanceof WebClassLoader)
         return ((WebClassLoader) cl).getKey();
      return getClassLoaderKey(cl);
   }

   /**
    * Builds the default loader key: the loader's unqualified class name
    * followed by its hashCode in brackets, e.g. {@code "MyLoader[12345]"}.
    */
   protected String getClassLoaderKey(ClassLoader cl)
   {
      String className = cl.getClass().getName();
      int dot = className.lastIndexOf('.');
      if (dot >= 0)
         className = className.substring(dot + 1);
      String key = className + '[' + cl.hashCode() + ']';
      return key;
   }

   /** Submits this runnable to the thread pool (via getInstance()) to accept the next connection. */
   protected void listen()
   {
      threadPool.getInstance().run(this);
   }

   /**
    * Reads the request line and returns the request target without its
    * leading character (normally '/'). A line without an HTTP version token
    * ("GET /path") is accepted.
    *
    * @throws IOException if no request line arrives or it has no target
    */
   protected String getPath(BufferedReader in) throws IOException
   {
      String line = in.readLine();
      if (line == null)
         throw new IOException("No request line received");
      log.trace("raw request=" + line);
      int start = line.indexOf(' ') + 1;
      if (start == 0)
         throw new IOException("Malformed request line: " + line);
      int end = line.indexOf(' ', start + 1);
      if (end < 0)
         end = line.length();
      if (start + 1 > end)
         throw new IOException("Malformed request line: " + line);
      return line.substring(start + 1, end);
   }

   /**
    * Reads the full content behind a URL.
    *
    * @param url the location to read
    * @return the bytes read
    * @throws IOException if the URL cannot be opened or read; the stream is
    *         closed in every case
    */
   protected byte[] getBytes(URL url) throws IOException
   {
      if (log.isDebugEnabled())
         log.debug("Retrieving " + url);
      InputStream in = new BufferedInputStream(url.openStream());
      try
      {
         ByteArrayOutputStream out = new ByteArrayOutputStream();
         byte[] buffer = new byte[1024];
         int count;
         while ((count = in.read(buffer)) != -1)
            out.write(buffer, 0, count);
         return out.toByteArray();
      }
      finally
      {
         in.close();
      }
   }

   /**
    * Maps a path's extension to a registered MIME type.
    *
    * @return the registered type, or "text/html" when the path has no
    *         extension or the extension is not registered
    */
   protected String getMimeType(String path)
   {
      int dot = path.lastIndexOf(".");
      String type = "text/html";
      if (dot >= 0)
      {
         String suffix = path.substring(dot + 1);
         String mimeType = mimeTypes.getProperty(suffix);
         if (mimeType != null)
            type = mimeType;
      }
      return type;
   }
}