package org.jnp.server;
import java.io.InputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.net.ServerSocket;
import java.net.UnknownHostException;
import java.rmi.Remote;
import java.rmi.MarshalledObject;
import java.rmi.server.RMIClientSocketFactory;
import java.rmi.server.RMIServerSocketFactory;
import java.rmi.server.UnicastRemoteObject;
import java.lang.reflect.Method;
import javax.net.ServerSocketFactory;
import org.jnp.interfaces.Naming;
import org.jnp.interfaces.NamingContext;
import org.jboss.logging.Logger;
import org.jboss.net.sockets.DefaultSocketFactory;
import org.jboss.util.threadpool.ThreadPool;
import org.jboss.util.threadpool.BasicThreadPool;
/**
 * Bootstrap for the JNP naming service.
 * <p>
 * {@link #start()} creates a {@link NamingServer} (optionally installing it as the
 * local server of {@link NamingContext}), exports it through RMI when no naming
 * proxy has been supplied via {@link #setNamingProxy(Object)}, and then listens on
 * the bootstrap port, writing the marshalled server stub to every client that
 * connects. {@link #stop()} closes the bootstrap listener and unexports the server
 * if this instance exported it.
 */
public class Main implements MainMBean
{
    /** The naming server; created lazily by {@link #start()}. */
    protected NamingServer theServer;
    /** Marshalled stub or proxy written to each bootstrap client. */
    protected MarshalledObject serverStub;
    /** True while {@link #theServer} is exported by {@link #initJnpInvoker()}. */
    protected boolean isStubExported;
    /**
     * Bootstrap listener socket. Volatile because {@link #stop()} clears it from a
     * different thread than the accept loop, which polls it to decide when to exit.
     */
    protected volatile ServerSocket serverSocket;
    protected RMIClientSocketFactory clientSocketFactory;
    protected RMIServerSocketFactory serverSocketFactory;
    protected ServerSocketFactory jnpServerSocketFactory;
    protected String clientSocketFactoryName;
    protected String serverSocketFactoryName;
    protected String jnpServerSocketFactoryName;
    /** Address the bootstrap listener binds to; null means all interfaces. */
    protected InetAddress bindAddress;
    /** Address for RMI sockets; falls back to {@link #bindAddress} when null. */
    protected InetAddress rmiBindAddress;
    protected int backlog = 50;
    protected int port = 1099;
    protected int rmiPort = 0;
    protected boolean InstallGlobalService = true;
    protected Logger log;
    protected ThreadPool lookupPool;

    /**
     * Starts a naming service with the default settings.
     *
     * @param args ignored
     * @throws Exception if startup fails
     */
    public static void main(String[] args)
        throws Exception
    {
        new Main().start();
    }

    /** Creates a service logging to the {@code org.jboss.naming.Naming} category. */
    public Main()
    {
        this("org.jboss.naming.Naming");
    }

    /**
     * Creates a service logging to the given category.
     * <p>
     * A {@code jnp.properties} resource, if found by this class's loader, is loaded
     * into the system properties; then {@code jnp.port} and {@code jnp.rmiPort}
     * system properties (if set) override the default ports.
     *
     * @param categoryName the logger category name
     */
    public Main(String categoryName)
    {
        log = Logger.getLogger(categoryName);
        InputStream is = null;
        try
        {
            ClassLoader loader = getClass().getClassLoader();
            is = loader.getResourceAsStream("jnp.properties");
            // The properties file is optional; a missing resource is not an error
            if( is != null )
                System.getProperties().load(is);
        }
        catch (Exception e)
        {
            // Best effort: fall back to defaults / existing system properties
            log.debug("Failed to load jnp.properties", e);
        }
        finally
        {
            if( is != null )
            {
                try
                {
                    is.close();
                }
                catch (IOException ignored)
                {
                    // nothing useful to do if closing a read-only stream fails
                }
            }
        }
        setPort(Integer.getInteger("jnp.port", getPort()).intValue());
        setRmiPort(Integer.getInteger("jnp.rmiPort", getRmiPort()).intValue());
    }

    /** @return the naming server, or null before {@link #start()} */
    public Naming getServer()
    {
        return theServer;
    }

    /** @return the pool used to run the accept loop and bootstrap requests */
    public ThreadPool getLookupPool()
    {
        return lookupPool;
    }

    /** @param lookupPool the pool to use; a default pool is created if unset */
    public void setLookupPool(ThreadPool lookupPool)
    {
        this.lookupPool = lookupPool;
    }

    /**
     * Supplies the object sent to bootstrap clients instead of an RMI stub of the
     * naming server. When set before {@link #start()}, no RMI export is done.
     *
     * @param proxy the naming proxy to marshal
     * @throws IOException if the proxy cannot be marshalled
     */
    public void setNamingProxy(Object proxy)
        throws IOException
    {
        serverStub = new MarshalledObject(proxy);
    }

    public void setRmiPort(int p)
    {
        rmiPort = p;
    }

    public int getRmiPort()
    {
        return rmiPort;
    }

    public void setPort(int p)
    {
        port = p;
    }

    public int getPort()
    {
        return port;
    }

    /** @return the bootstrap bind address as a literal IP string, or null */
    public String getBindAddress()
    {
        return bindAddress == null ? null : bindAddress.getHostAddress();
    }

    /**
     * @param host host name or IP; null or empty means all interfaces
     * @throws UnknownHostException if the host cannot be resolved
     */
    public void setBindAddress(String host) throws UnknownHostException
    {
        if( host == null || host.length() == 0 )
            bindAddress = null;
        else
            bindAddress = InetAddress.getByName(host);
    }

    /** @return the RMI bind address as a literal IP string, or null */
    public String getRmiBindAddress()
    {
        return rmiBindAddress == null ? null : rmiBindAddress.getHostAddress();
    }

    /**
     * @param host host name or IP; null or empty clears the RMI bind address
     * @throws UnknownHostException if the host cannot be resolved
     */
    public void setRmiBindAddress(String host) throws UnknownHostException
    {
        if( host == null || host.length() == 0 )
            rmiBindAddress = null;
        else
            rmiBindAddress = InetAddress.getByName(host);
    }

    public int getBacklog()
    {
        return backlog;
    }

    /** @param backlog listen backlog; values &lt;= 0 reset it to 50 */
    public void setBacklog(int backlog)
    {
        this.backlog = backlog <= 0 ? 50 : backlog;
    }

    public boolean getInstallGlobalService()
    {
        return InstallGlobalService;
    }

    /** @param flag whether {@link #start()} installs the server via NamingContext.setLocal */
    public void setInstallGlobalService(boolean flag)
    {
        this.InstallGlobalService = flag;
    }

    public String getClientSocketFactory()
    {
        return clientSocketFactoryName;
    }

    /**
     * Loads and instantiates the RMI client socket factory through the thread
     * context class loader.
     */
    public void setClientSocketFactory(String factoryClassName)
        throws ClassNotFoundException, InstantiationException, IllegalAccessException
    {
        this.clientSocketFactoryName = factoryClassName;
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        Class clazz = loader.loadClass(clientSocketFactoryName);
        clientSocketFactory = (RMIClientSocketFactory) clazz.newInstance();
    }

    public String getServerSocketFactory()
    {
        return serverSocketFactoryName;
    }

    /**
     * Loads and instantiates the RMI server socket factory through the thread
     * context class loader.
     */
    public void setServerSocketFactory(String factoryClassName)
        throws ClassNotFoundException, InstantiationException, IllegalAccessException
    {
        this.serverSocketFactoryName = factoryClassName;
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        Class clazz = loader.loadClass(serverSocketFactoryName);
        serverSocketFactory = (RMIServerSocketFactory) clazz.newInstance();
    }

    /**
     * Loads and instantiates the factory used for the bootstrap listener socket
     * through the thread context class loader.
     */
    public void setJNPServerSocketFactory(String factoryClassName)
        throws ClassNotFoundException, InstantiationException, IllegalAccessException
    {
        this.jnpServerSocketFactoryName = factoryClassName;
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        Class clazz = loader.loadClass(jnpServerSocketFactoryName);
        jnpServerSocketFactory = (ServerSocketFactory) clazz.newInstance();
    }

    /**
     * Creates the naming server if needed, configures socket factories, exports
     * the server over RMI unless a proxy was supplied (and {@code port >= 0}),
     * and starts the bootstrap listener when a stub is available.
     *
     * @throws Exception if the RMI export fails
     */
    public void start()
        throws Exception
    {
        if( theServer == null )
        {
            theServer = new NamingServer();
            if( InstallGlobalService )
                NamingContext.setLocal(theServer);
        }
        initCustomSocketFactories();
        if( serverStub == null && port >= 0 )
            initJnpInvoker();
        if( serverStub != null )
            initBootstrapListener();
    }

    /**
     * Closes the bootstrap listener and unexports the naming server if this
     * instance exported it. Each step is attempted independently so a failure
     * closing the socket does not leave the server exported.
     */
    public void stop()
    {
        try
        {
            ServerSocket s = serverSocket;
            if( s != null )
            {
                // Clear first so the accept loop treats the close as a shutdown
                serverSocket = null;
                s.close();
            }
        }
        catch (Exception e)
        {
            log.error("Exception during shutdown", e);
        }

        try
        {
            if( isStubExported )
            {
                if( UnicastRemoteObject.unexportObject(theServer, false) )
                {
                    isStubExported = false;
                    // The stub now refers to an unexported object; drop it so a
                    // subsequent start() exports the server again.
                    serverStub = null;
                }
                else
                {
                    log.warn("NamingServer has pending calls, it was not unexported");
                }
            }
        }
        catch (Exception e)
        {
            log.error("Exception during shutdown", e);
        }
    }

    /**
     * Exports the naming server over RMI using the configured port and socket
     * factories, and stores the marshalled stub for bootstrap clients.
     *
     * @throws IOException if the export or marshalling fails
     */
    protected void initJnpInvoker() throws IOException
    {
        log.debug("Creating NamingServer stub, theServer="+theServer
            +",rmiPort="+rmiPort+",clientSocketFactory="+clientSocketFactory
            +",serverSocketFactory="+serverSocketFactory);
        Remote stub = UnicastRemoteObject.exportObject(theServer, rmiPort,
            clientSocketFactory, serverSocketFactory);
        // Record the export so stop() knows to unexport the server
        isStubExported = true;
        log.debug("NamingServer stub: "+stub);
        serverStub = new MarshalledObject(stub);
    }

    /**
     * Opens the bootstrap listener socket and submits the accept loop to the
     * lookup pool, creating a default pool if none was configured. A failure to
     * open the socket is logged, not thrown.
     */
    protected void initBootstrapListener()
    {
        try
        {
            if( jnpServerSocketFactory == null )
                jnpServerSocketFactory = ServerSocketFactory.getDefault();
            serverSocket = jnpServerSocketFactory.createServerSocket(port, backlog, bindAddress);
            // An ephemeral port was requested; report the one actually bound
            if( port == 0 )
                port = serverSocket.getLocalPort();
            String msg = "Started jndi bootstrap jnpPort=" + port +", rmiPort=" + rmiPort
                + ", backlog="+backlog+", bindAddress="+bindAddress
                + ", Client SocketFactory="+clientSocketFactory
                + ", Server SocketFactory="+serverSocketFactory;
            log.info(msg);
        }
        catch (IOException e)
        {
            log.error("Could not start on port " + port, e);
        }
        if( lookupPool == null )
            lookupPool = new BasicThreadPool("NamingBootstrap Pool");
        lookupPool.run(new AcceptHandler());
    }

    /**
     * Pushes the RMI bind address (falling back to the bootstrap bind address)
     * into the client and server socket factories via a reflective
     * {@code setBindAddress(String)}, and creates a default server socket
     * factory when none is configured.
     */
    protected void initCustomSocketFactories()
    {
        InetAddress addr = rmiBindAddress;
        if( addr == null )
            addr = bindAddress;

        // Invoke on the client factory itself (previously invoked on the server
        // factory by mistake) and skip when there is no address to set.
        if( clientSocketFactory != null && addr != null )
            setFactoryBindAddress(clientSocketFactory, addr);

        try
        {
            if( serverSocketFactory == null )
                serverSocketFactory = new DefaultSocketFactory(addr);
            else if( addr != null )
                setFactoryBindAddress(serverSocketFactory, addr);
        }
        catch (Exception e)
        {
            log.error("operation failed", e);
            serverSocketFactory = null;
        }
    }

    /**
     * Calls {@code setBindAddress(String)} reflectively on a socket factory,
     * logging a warning if the factory lacks the method or the call fails.
     */
    private void setFactoryBindAddress(Object factory, InetAddress addr)
    {
        try
        {
            Method m = factory.getClass().getMethod("setBindAddress", new Class[] {String.class});
            m.invoke(factory, new Object[] {addr.getHostAddress()});
        }
        catch (NoSuchMethodException e)
        {
            log.warn("Socket factory does not support setBindAddress(String)");
        }
        catch (Exception e)
        {
            log.warn("Failed to setBindAddress="+addr+" on socket factory", e);
        }
    }

    /** Accepts bootstrap connections until the listener socket is cleared or closed. */
    private class AcceptHandler implements Runnable
    {
        public void run()
        {
            boolean trace = log.isTraceEnabled();
            ServerSocket listener;
            // Read the volatile field once per iteration so a concurrent stop()
            // cannot null it between the check and accept()
            while( (listener = serverSocket) != null )
            {
                Socket socket = null;
                try
                {
                    socket = listener.accept();
                    if( trace )
                        log.trace("Accepted bootstrap client: "+socket);
                    BootstrapRequestHandler handler = new BootstrapRequestHandler(socket);
                    lookupPool.run(handler);
                }
                catch (IOException e)
                {
                    if (serverSocket == null)
                        return;
                    log.error("Naming accept handler stopping", e);
                    // A socket closed outside stop() would otherwise spin forever
                    if (listener.isClosed())
                        return;
                }
                catch(Throwable e)
                {
                    log.error("Unexpected exception during accept", e);
                    // The handler may never have been dispatched; don't leak the client
                    closeQuietly(socket);
                }
            }
        }
    }

    /** Writes the marshalled server stub to one client, then closes the connection. */
    private class BootstrapRequestHandler implements Runnable
    {
        private final Socket socket;

        BootstrapRequestHandler(Socket socket)
        {
            this.socket = socket;
        }

        public void run()
        {
            try
            {
                OutputStream os = socket.getOutputStream();
                ObjectOutputStream out = new ObjectOutputStream(os);
                out.writeObject(serverStub);
                out.close();
            }
            catch (IOException ex)
            {
                if (log.isDebugEnabled())
                {
                    log.debug("Error writing response to " + socket.getInetAddress(), ex);
                }
            }
            finally
            {
                closeQuietly(socket);
            }
        }
    }

    /** Closes a client socket, ignoring null and close failures. */
    private static void closeQuietly(Socket socket)
    {
        if( socket == null )
            return;
        try
        {
            socket.close();
        }
        catch (IOException ignored)
        {
            // the connection is being discarded; nothing further to do
        }
    }
}