package org.jboss.mx.loading;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.CodeSource;
import java.security.PermissionCollection;
import java.security.Policy;
import java.security.ProtectionDomain;
import java.security.cert.Certificate;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Vector;
import java.util.Collections;
import java.util.Set;
import javax.management.MalformedObjectNameException;
import javax.management.ObjectName;
import org.jboss.logging.Logger;
import org.jboss.util.loading.Translator;
import EDU.oswego.cs.dl.util.concurrent.ReentrantLock;
import EDU.oswego.cs.dl.util.concurrent.ConcurrentReaderHashMap;
/**
 * Base class for URLClassLoaders that participate in a shared
 * {@link LoaderRepository}.
 * <p>
 * Class loading goes through {@link #loadClass(String, boolean)}, which first
 * consults the repository's class cache and otherwise runs a
 * {@code ClassLoadingTask} through {@code LoadMgr3}. Names that fail to load
 * or resolve locally are remembered in per-loader class/resource black lists so
 * repeated misses are cheap; both lists are cleared when a URL is added or the
 * loader is unregistered. Successful local resource lookups are cached.
 */
public abstract class RepositoryClassLoader extends URLClassLoader
{
   private static final Logger log = Logger.getLogger(RepositoryClassLoader.class);
   /** Returned by {@link #getURLs()} in place of the real classpath. */
   private static final URL[] EMPTY_URL_ARRAY = {};
   /** The repository this loader belongs to; set to null by {@link #unregister()}. */
   protected LoaderRepository repository = null;
   /** Stack captured by {@link #unregister()}; used as the cause when a destroyed loader is used. */
   protected Exception unregisterTrace;
   /** Order value set via {@link #setAddedOrder(int)}; used as the stop point by {@link #loadClassBefore(String)}. */
   private int addedOrder;
   /** The parent loader passed to the constructor. */
   protected ClassLoader parent = null;
   /** Class names that previously failed to load locally. */
   private final Set classBlackList = Collections.synchronizedSet(new HashSet());
   /** Resource names that previously failed to resolve locally. */
   private final Set resourceBlackList = Collections.synchronizedSet(new HashSet());
   /** Resource name to URL cache for successful local lookups. */
   private final ConcurrentReaderHashMap resourceCache = new ConcurrentReaderHashMap();
   /** Reentrant lock serializing class loading through this loader. */
   protected ReentrantLock loadLock = new ReentrantLock();
   /** Nesting depth of {@link #loadClassImpl}; only reported in trace logging. */
   private int loadClassDepth;

   /**
    * Creates a loader over the given URLs.
    *
    * @param urls   the initial classpath URLs
    * @param parent the parent class loader, also kept in {@link #parent}
    */
   protected RepositoryClassLoader(URL[] urls, ClassLoader parent)
   {
      super(urls, parent);
      this.parent = parent;
   }

   /**
    * Returns the JMX name of this loader.
    *
    * @return the object name
    * @throws MalformedObjectNameException if the name cannot be built
    */
   public abstract ObjectName getObjectName() throws MalformedObjectNameException;

   /**
    * Returns the repository this loader is registered with.
    *
    * @return the repository, or null if unset or unregistered
    */
   public LoaderRepository getLoaderRepository()
   {
      return repository;
   }

   /**
    * Associates this loader with a repository.
    *
    * @param repository the repository to use
    */
   public void setRepository(LoaderRepository repository)
   {
      log.debug("setRepository, repository="+repository+", cl=" + this);
      this.repository = repository;
   }

   /**
    * Returns the order value previously set with {@link #setAddedOrder(int)}.
    *
    * @return the added order
    */
   public int getAddedOrder()
   {
      return addedOrder;
   }

   /**
    * Sets the order value used as the stop point by {@link #loadClassBefore(String)}.
    *
    * @param addedOrder the order value
    */
   public void setAddedOrder(int addedOrder)
   {
      this.addedOrder = addedOrder;
   }

   /**
    * Loads a class using only the standard URLClassLoader delegation
    * (parent, then this loader's URLs), bypassing the repository.
    * <p>
    * A failed lookup adds the name to the class black list. Array class names
    * (starting with '[') are retried via {@link Class#forName(String, boolean, ClassLoader)}
    * and removed from the black list on success.
    *
    * @param name    the binary class name
    * @param resolve whether to resolve the class
    * @return the loaded class
    * @throws ClassNotFoundException if the class is black listed or not found
    */
   public Class loadClassLocally(String name, boolean resolve)
      throws ClassNotFoundException
   {
      boolean trace = log.isTraceEnabled();
      if( trace )
         log.trace("loadClassLocally, " + this + " name=" + name);
      Class result = null;
      try
      {
         if (isClassBlackListed(name))
         {
            if( trace )
               log.trace("Class in blacklist, name="+name);
            throw new ClassNotFoundException("Class Not Found(blacklist): " + name);
         }

         try
         {
            result = super.loadClass(name, resolve);
            return result;
         }
         catch (ClassNotFoundException cnfe)
         {
            addToClassBlackList(name);
            // Array classes are not defined from bytes; resolve them through Class.forName
            if( name.charAt(0) == '[' )
            {
               result = Class.forName(name, true, this);
               removeFromClassBlackList(name);
               return result;
            }
            if( trace )
               log.trace("CFNE: Adding to blacklist: "+name);
            throw cnfe;
         }
      }
      finally
      {
         if (trace)
         {
            if (result != null)
               log.trace("loadClassLocally, " + this + " name=" + name + " class=" + result + " cl=" + result.getClassLoader());
            else
               log.trace("loadClassLocally, " + this + " name=" + name + " not found");
         }
      }
   }

   /**
    * Resolves a resource against this loader only, consulting the resource
    * cache and black list first.
    *
    * @param name the resource name
    * @return the resource URL, or null if not found locally
    */
   public URL getResourceLocally(String name)
   {
      URL resURL = (URL) resourceCache.get(name);
      if (resURL != null)
         return resURL;
      if (isResourceBlackListed(name))
         return null;
      resURL = super.getResource(name);
      if( log.isTraceEnabled() == true )
         log.trace("getResourceLocally("+this+"), name="+name+", resURL:"+resURL);
      if (resURL == null)
         addToResourceBlackList(name);
      else
         resourceCache.put(name, resURL);
      return resURL;
   }

   /**
    * Returns the first URL of this loader's classpath.
    *
    * @return the first URL, or null if the classpath is empty
    */
   public URL getURL()
   {
      URL[] urls = super.getURLs();
      if (urls.length > 0)
         return urls[0];
      else
         return null;
   }

   /**
    * Removes this loader from its repository, clears the black lists and
    * resource cache, and records the call stack for later diagnostics.
    */
   public void unregister()
   {
      log.debug("Unregistering cl=" + this);
      if (repository != null)
         repository.removeClassLoader(this);
      clearBlackLists();
      resourceCache.clear();
      repository = null;
      this.unregisterTrace = new Exception();
   }

   /**
    * Returns this loader's actual classpath URLs (unlike {@link #getURLs()}).
    *
    * @return the classpath URLs
    */
   public URL[] getClasspath()
   {
      return super.getURLs();
   }

   /**
    * Returns the URLs known to the repository.
    * <p>
    * NOTE(review): throws NullPointerException once the loader has been
    * unregistered, since {@link #repository} is then null.
    *
    * @return the repository's URLs
    */
   public URL[] getAllURLs()
   {
      return repository.getURLs();
   }

   /** Marks a class name as not loadable locally. */
   public void addToClassBlackList(String name)
   {
      classBlackList.add(name);
   }

   /** Removes a class name from the class black list. */
   public void removeFromClassBlackList(String name)
   {
      classBlackList.remove(name);
   }

   /** @return true if the class name is black listed */
   public boolean isClassBlackListed(String name)
   {
      return classBlackList.contains(name);
   }

   /** Clears the class black list. */
   public void clearClassBlackList()
   {
      classBlackList.clear();
   }

   /** Marks a resource name as not resolvable locally. */
   public void addToResourceBlackList(String name)
   {
      resourceBlackList.add(name);
   }

   /** Removes a resource name from the resource black list. */
   public void removeFromResourceBlackList(String name)
   {
      resourceBlackList.remove(name);
   }

   /** @return true if the resource name is black listed */
   public boolean isResourceBlackListed(String name)
   {
      return resourceBlackList.contains(name);
   }

   /** Clears the resource black list. */
   public void clearResourceBlackList()
   {
      resourceBlackList.clear();
   }

   /** Clears both the class and resource black lists. */
   public void clearBlackLists()
   {
      clearClassBlackList();
      clearResourceBlackList();
   }

   /**
    * Loads a class, first from the repository's class cache and otherwise via
    * {@link #loadClassImpl(String, boolean, int)} with no stop point.
    *
    * @param name    the binary class name
    * @param resolve whether to resolve the class
    * @return the loaded class
    * @throws ClassNotFoundException if the class cannot be loaded
    */
   public Class loadClass(String name, boolean resolve)
      throws ClassNotFoundException
   {
      boolean trace = log.isTraceEnabled();
      if (trace)
         log.trace("loadClass " + this + " name=" + name+", loadClassDepth="+loadClassDepth);
      Class clazz = null;
      try
      {
         if (repository != null)
         {
            clazz = repository.getCachedClass(name);
            if (clazz != null)
            {
               if( trace )
               {
                  StringBuffer buffer = new StringBuffer("Loaded class from cache, ");
                  ClassToStringAction.toString(clazz, buffer);
                  log.trace(buffer.toString());
               }
               return clazz;
            }
         }
         clazz = loadClassImpl(name, resolve, Integer.MAX_VALUE);
         return clazz;
      }
      finally
      {
         if (trace)
         {
            if (clazz != null)
               log.trace("loadClass " + this + " name=" + name + " class=" + clazz + " cl=" + clazz.getClassLoader());
            else
               log.trace("loadClass " + this + " name=" + name + " not found");
         }
      }
   }

   /**
    * Loads a class using this loader's {@link #getAddedOrder() added order}
    * as the stop point, without resolving it.
    *
    * @param name the binary class name
    * @return the loaded class
    * @throws ClassNotFoundException if the class cannot be loaded
    */
   public Class loadClassBefore(String name)
      throws ClassNotFoundException
   {
      boolean trace = log.isTraceEnabled();
      if (trace)
         log.trace("loadClassBefore " + this + " name=" + name);
      Class clazz = null;
      try
      {
         clazz = loadClassImpl(name, false, addedOrder);
         return clazz;
      }
      finally
      {
         if (trace)
         {
            if (clazz != null)
               log.trace("loadClassBefore " + this + " name=" + name + " class=" + clazz + " cl=" + clazz.getClassLoader());
            else
               log.trace("loadClassBefore " + this + " name=" + name + " not found");
         }
      }
   }

   /**
    * Loads a class by running a {@code ClassLoadingTask} through
    * {@code LoadMgr3} while holding {@link #loadLock}.
    * <p>
    * Threads that cannot acquire the lock wait on this object's monitor until
    * notified by a thread that releases it.
    * NOTE(review): the {@code resolve} flag is only logged here and is not
    * passed to the task.
    *
    * @param name    the binary class name
    * @param resolve whether resolution was requested
    * @param stopAt  the stop point handed to the task
    * @return the loaded class
    * @throws ClassNotFoundException if the loader was unregistered, the class
    *         is not found, or loading failed with another error (chained as cause)
    * @throws IllegalStateException if the task finished with neither a class nor an error
    */
   public synchronized Class loadClassImpl(String name, boolean resolve, int stopAt)
      throws ClassNotFoundException
   {
      boolean trace = log.isTraceEnabled();
      if( trace )
         log.trace("loadClassImpl, name="+name+", resolve="+resolve);
      if( repository == null )
      {
         // Checked before loadClassDepth is incremented so this early exit
         // cannot leave the depth counter permanently raised.
         String msg = "Invalid use of destroyed classloader, UCL destroyed at:";
         throw new ClassNotFoundException(msg, this.unregisterTrace);
      }
      loadClassDepth ++;

      boolean acquired = attempt(1);
      while( acquired == false )
      {
         try
         {
            if( trace )
               log.trace("Waiting for loadClass lock");
            this.wait();
         }
         catch(InterruptedException ignore)
         {
            // Interrupts are deliberately ignored while waiting; retrying the lock below.
         }
         acquired = attempt(1);
      }

      ClassLoadingTask task = null;
      try
      {
         Thread t = Thread.currentThread();
         // Only the outermost (non-reentrant) acquisition registers the thread
         if( loadLock.holds() == 1 )
            LoadMgr3.registerLoaderThread(this, t);

         task = new ClassLoadingTask(name, this, t, stopAt);
         UnifiedLoaderRepository3 ulr3 = (UnifiedLoaderRepository3) repository;
         if( LoadMgr3.beginLoadTask(task, ulr3) == false )
         {
            while( task.threadTaskCount != 0 )
            {
               try
               {
                  LoadMgr3.nextTask(t, task, ulr3);
               }
               catch(InterruptedException e)
               {
                  break;
               }
            }
         }
      }
      finally
      {
         if( loadLock.holds() == 1 )
            LoadMgr3.endLoadTask(task);
         this.release();
         this.notifyAll();
         loadClassDepth --;
      }

      if( task.loadedClass == null )
      {
         if( task.loadException instanceof ClassNotFoundException )
            throw (ClassNotFoundException) task.loadException;
         else if( task.loadException != null )
         {
            if( log.isTraceEnabled() )
               log.trace("Unexpected error during load of:"+name, task.loadException);
            String msg = "Unexpected error during load of: "+name
               + ", msg="+task.loadException.getMessage();
            // Chain the original failure so its stack trace is not lost
            throw new ClassNotFoundException(msg, task.loadException);
         }
         else
            throw new IllegalStateException("ClassLoadingTask.loadedTask is null, name: "+name);
      }

      return task.loadedClass;
   }

   /**
    * Resolves a resource through the repository.
    *
    * @param name the resource name
    * @return the resource URL, or null if not found or the loader is unregistered
    */
   public URL getResource(String name)
   {
      // Read the field once so a concurrent unregister() cannot null it mid-call
      LoaderRepository repo = repository;
      if (repo != null)
         return repo.getResource(name, this);
      return null;
   }

   /**
    * Collects all matching resources through the repository.
    *
    * @param name the resource name
    * @return the matching URLs; empty if the loader has been unregistered
    * @throws IOException declared for compatibility with {@link ClassLoader#findResources(String)}
    */
   public Enumeration findResources(String name) throws IOException
   {
      Vector resURLs = new Vector();
      // Mirrors getResource(): an unregistered loader finds nothing instead of throwing NPE
      LoaderRepository repo = repository;
      if (repo != null)
         repo.getResources(name, this, resURLs);
      return resURLs.elements();
   }

   /**
    * Finds resources on this loader's own URLs only.
    *
    * @param name the resource name
    * @return the matching URLs
    * @throws IOException on I/O errors
    */
   public Enumeration findResourcesLocally(String name) throws IOException
   {
      return super.findResources(name);
   }

   /**
    * Defines a class from this loader's URLs.
    * <p>
    * When the repository supplies a {@link Translator}, the raw bytecode is
    * passed through it before definition; otherwise
    * {@link #findClassLocally(String)} is used and a miss is black listed.
    * NOTE(review): this dereferences {@link #repository} and will throw NPE if
    * called after {@link #unregister()}.
    *
    * @param name the binary class name
    * @return the defined class
    * @throws ClassNotFoundException if the class is black listed, not found,
    *         or translation/definition fails (chained as cause)
    */
   protected Class findClass(String name) throws ClassNotFoundException
   {
      boolean trace = log.isTraceEnabled();
      if( trace )
         log.trace("findClass, name="+name);
      if (isClassBlackListed(name))
      {
         if( trace )
            log.trace("Class in blacklist, name="+name);
         throw new ClassNotFoundException("Class Not Found(blacklist): " + name);
      }

      Translator translator = repository.getTranslator();
      if (translator != null)
      {
         try
         {
            URL classUrl = getClassURL(name);
            byte[] rawcode = loadByteCode(classUrl);
            URL codeSourceUrl = getCodeSourceURL(name, classUrl);
            ProtectionDomain pd = getProtectionDomain(codeSourceUrl);
            byte[] bytecode = translator.transform(this, name, null, pd, rawcode);
            // A null result means the translator left the class unchanged
            if( bytecode == null )
               bytecode = rawcode;
            definePackage(name);
            return defineClass(name, bytecode, 0, bytecode.length, pd);
         }
         catch(ClassNotFoundException e)
         {
            throw e;
         }
         catch (Throwable ex)
         {
            throw new ClassNotFoundException(name, ex);
         }
      }

      Class clazz = null;
      try
      {
         clazz = findClassLocally(name);
      }
      catch(ClassNotFoundException e)
      {
         if( trace )
            log.trace("CFNE: Adding to blacklist: "+name);
         addToClassBlackList(name);
         throw e;
      }
      return clazz;
   }

   /**
    * Defines a class using the plain URLClassLoader lookup.
    *
    * @param name the binary class name
    * @return the defined class
    * @throws ClassNotFoundException if not found
    */
   protected Class findClassLocally(String name) throws ClassNotFoundException
   {
      return super.findClass(name);
   }

   /**
    * Defines the package of the given class name with no manifest attributes,
    * ignoring the case where it is already defined.
    *
    * @param className the binary class name; default-package names are skipped
    */
   protected void definePackage(String className)
   {
      int i = className.lastIndexOf('.');
      if (i == -1)
         return;

      try
      {
         definePackage(className.substring(0, i), null, null, null, null, null, null, null);
      }
      catch (IllegalArgumentException alreadyDone)
      {
         // Package already defined by this loader; nothing to do
      }
   }

   /**
    * Adds a URL to this loader if the repository accepts it. Any query string
    * is stripped before the URL is added, and the black lists are cleared
    * since previously missing classes/resources may now be found.
    *
    * @param url the URL to add
    * @throws IllegalArgumentException if {@code url} is null
    */
   public void addURL(URL url)
   {
      if( url == null )
         throw new IllegalArgumentException("url cannot be null");

      if( repository.addClassLoaderURL(this, url) == true )
      {
         log.debug("Added url: "+url+", to ucl: "+this);
         String query = url.getQuery();
         if( query != null )
         {
            String ext = url.toExternalForm();
            // Drop "?" plus the query from the end of the external form
            String ext2 = ext.substring(0, ext.length() - query.length() - 1);
            try
            {
               url = new URL (ext2);
            }
            catch(MalformedURLException e)
            {
               log.warn("Failed to strip query from: "+url, e);
            }
         }
         super.addURL(url);
         clearBlackLists();
      }
      else if( log.isTraceEnabled() )
      {
         log.trace("Ignoring duplicate url: "+url+", for ucl: "+this);
      }
   }

   /**
    * Always returns an empty array; the real URLs are available from
    * {@link #getClasspath()}.
    * NOTE(review): presumably hides the classpath from code that walks
    * URLClassLoader URLs — confirm intent.
    *
    * @return an empty array
    */
   public URL[] getURLs()
   {
      return EMPTY_URL_ARRAY;
   }

   /** Public access to {@link ClassLoader#getPackage(String)}. */
   public Package getPackage(String name)
   {
      return super.getPackage(name);
   }

   /** Public access to {@link ClassLoader#getPackages()}. */
   public Package[] getPackages()
   {
      return super.getPackages();
   }

   /** Delegates to the inherited implementation; final so subclasses cannot redefine equality. */
   public final boolean equals(Object other)
   {
      return super.equals(other);
   }

   /** Delegates to the inherited implementation; final to stay consistent with {@link #equals(Object)}. */
   public final int hashCode()
   {
      return super.hashCode();
   }

   /** @return the inherited description plus this loader's first classpath URL */
   public String toString()
   {
      return super.toString() + "{ url=" + getURL() + " }";
   }

   /**
    * Tries to acquire {@link #loadLock} within the given time.
    * <p>
    * The thread's interrupt status is cleared before the attempt, so a pending
    * interrupt does not abort it, and restored afterwards.
    *
    * @param waitMS maximum wait in milliseconds
    * @return true if the lock was acquired
    */
   protected boolean attempt(long waitMS)
   {
      boolean acquired = false;
      boolean trace = log.isTraceEnabled();
      boolean threadWasInterrupted = Thread.interrupted();
      try
      {
         acquired = loadLock.attempt(waitMS);
      }
      catch(InterruptedException e)
      {
         // Treated as a failed attempt; the prior interrupt status is restored below
      }
      finally
      {
         if( threadWasInterrupted )
            Thread.currentThread().interrupt();
      }
      if( trace )
         log.trace("attempt("+loadLock.holds()+") was: "+acquired+" for :"+this);
      return acquired;
   }

   /**
    * Acquires {@link #loadLock}, clearing and later restoring the thread's
    * interrupt status around the call.
    * NOTE(review): if the acquire is interrupted while blocked, the exception
    * is swallowed and the method returns without holding the lock.
    */
   protected void acquire()
   {
      boolean threadWasInterrupted = Thread.interrupted();
      try
      {
         loadLock.acquire();
      }
      catch(InterruptedException e)
      {
         // Deliberately ignored; see NOTE(review) above
      }
      finally
      {
         if( threadWasInterrupted )
            Thread.currentThread().interrupt();
      }
      if( log.isTraceEnabled() )
         log.trace("acquired("+loadLock.holds()+") for :"+this);
   }

   /** Releases one hold on {@link #loadLock}. */
   protected void release()
   {
      if( log.isTraceEnabled() )
         log.trace("release("+loadLock.holds()+") for :"+this);
      loadLock.release();
      if( log.isTraceEnabled() )
         log.trace("released, holds: "+loadLock.holds());
   }

   /**
    * Reads the bytecode of a class resolved against this loader's URLs.
    *
    * @param classname the binary class name
    * @return the class file bytes
    * @throws ClassNotFoundException if the class resource cannot be found
    * @throws IOException if reading fails
    */
   protected byte[] loadByteCode(String classname)
      throws ClassNotFoundException, IOException
   {
      URL classURL = getClassURL(classname);
      return readFully(classURL);
   }

   /**
    * Reads the bytecode found at the given URL.
    *
    * @param classURL the class file URL
    * @return the class file bytes
    * @throws ClassNotFoundException declared for subclass compatibility
    * @throws IOException if reading fails
    */
   protected byte[] loadByteCode(URL classURL)
      throws ClassNotFoundException, IOException
   {
      return readFully(classURL);
   }

   /** Reads the entire content of {@code url}, always closing the stream. */
   private static byte[] readFully(URL url) throws IOException
   {
      InputStream is = null;
      try
      {
         is = url.openStream();
         ByteArrayOutputStream baos = new ByteArrayOutputStream();
         byte[] tmp = new byte[1024];
         int read = 0;
         while( (read = is.read(tmp)) > 0 )
         {
            baos.write(tmp, 0, read);
         }
         return baos.toByteArray();
      }
      finally
      {
         if( is != null )
            is.close();
      }
   }

   /**
    * Builds a protection domain for the given code source URL, with no
    * certificates and the permissions granted by the current {@link Policy}.
    *
    * @param codesourceUrl the code source location
    * @return the protection domain
    */
   protected ProtectionDomain getProtectionDomain(URL codesourceUrl)
   {
      // Typed local disambiguates the CodeSource(URL, Certificate[]) constructor
      Certificate certs[] = null;
      CodeSource cs = new CodeSource(codesourceUrl, certs);
      PermissionCollection permissions = Policy.getPolicy().getPermissions(cs);
      if (log.isTraceEnabled())
         log.trace("getProtectionDomain, url=" + codesourceUrl +
                   " codeSource=" + cs + " permissions=" + permissions);
      return new ProtectionDomain(cs, permissions);
   }

   /**
    * Derives the code source root by cutting the class resource path off the
    * class file URL.
    *
    * @param classname the binary class name
    * @param classURL  the URL of the class file
    * @return the root URL, or {@code classURL} itself if the resource path is not found
    * @throws MalformedURLException if the truncated URL is invalid
    */
   private URL getCodeSourceURL(String classname, URL classURL) throws java.net.MalformedURLException
   {
      String classRsrcName = classname.replace('.', '/') + ".class";
      String urlAsString = classURL.toString();
      // The resource path is the trailing component, so match its last occurrence
      int idx = urlAsString.lastIndexOf(classRsrcName);
      if (idx == -1) return classURL;
      urlAsString = urlAsString.substring(0, idx);
      return new URL(urlAsString);
   }

   /**
    * Locates the class file for a class name on this loader's URLs.
    *
    * @param classname the binary class name
    * @return the class file URL
    * @throws ClassNotFoundException if the resource cannot be found
    */
   private URL getClassURL(String classname) throws ClassNotFoundException
   {
      String classRsrcName = classname.replace('.', '/') + ".class";
      URL classURL = this.getResourceLocally(classRsrcName);
      if( classURL == null )
      {
         String msg = "Failed to find: "+classname+" as resource: "+classRsrcName;
         throw new ClassNotFoundException(msg);
      }
      return classURL;
   }
}