package org.jboss.ha.jndi;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.UndeclaredThrowableException;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MulticastSocket;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.UnknownHostException;
import java.rmi.MarshalledObject;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import javax.management.ObjectInstance;
import javax.management.ObjectName;
import javax.management.Query;
import javax.management.QueryExp;
import javax.net.ServerSocketFactory;
import org.jboss.ha.framework.interfaces.HAPartition;
import org.jboss.ha.framework.server.ClusterPartition;
import org.jboss.ha.framework.server.ClusterPartitionMBean;
import org.jboss.invocation.Invocation;
import org.jboss.invocation.MarshalledInvocation;
import org.jboss.logging.Logger;
import org.jboss.mx.util.MBeanProxy;
import org.jboss.system.ServiceMBeanSupport;
import org.jboss.system.server.ServerConfigUtil;
import org.jboss.util.threadpool.BasicThreadPool;
import org.jboss.util.threadpool.BasicThreadPoolMBean;
import org.jboss.util.threadpool.ThreadPool;
import org.jnp.interfaces.Naming;
import org.jnp.interfaces.NamingContext;
/**
 * HA-JNDI naming service MBean.
 * <p>
 * During {@link #createService()} it creates an {@link HAJNDI} server for the partition
 * named by {@link #getPartitionName()} and registers it with {@link NamingContext}.
 * During {@link #startService()} it:
 * <ul>
 * <li>fetches the naming transport proxy from the proxy factory MBean;</li>
 * <li>when {@code port >= 0}, runs a TCP bootstrap listener that writes a
 * {@link MarshalledObject} wrapping that proxy to each client;</li>
 * <li>when discovery is enabled, runs a multicast responder that answers discovery
 * packets with this node's {@code address:port} string.</li>
 * </ul>
 * Calls passed to {@link #invoke(Invocation)} are dispatched reflectively onto the
 * {@code HAJNDI} instance.
 */
public class DetachedHANamingService
   extends ServiceMBeanSupport
   implements DetachedHANamingServiceMBean
{
   /** Backlog used when none, or a non-positive value, is configured. */
   private static final int DEFAULT_BACKLOG = 50;

   /**
    * Server socket of the bootstrap listener; null when not listening. Volatile because
    * stopService() clears it on a different thread from the one running the accept loop.
    */
   protected volatile ServerSocket bootstrapSocket;
   /** The HA-JNDI server that naming invocations are dispatched to. */
   protected HAJNDI theServer;
   /** Unmodifiable map of method hash (Long) to Naming Method, built in createService(). */
   protected Map marshalledInvocationMapping;
   /** Naming transport proxy obtained from the proxy factory MBean. */
   protected Naming stub;
   protected HAPartition partition;
   protected String partitionName = ServerConfigUtil.getDefaultPartitionName();
   private ObjectName proxyFactory;
   /** Address the bootstrap listener binds to; null means the factory's default. */
   protected InetAddress bindAddress;
   protected int backlog = DEFAULT_BACKLOG;
   /** Bootstrap listener port; a negative value disables the listener, 0 picks a free port. */
   protected int port = 1100;
   protected String adGroupAddress = NamingContext.DEFAULT_DISCOVERY_GROUP_ADDRESS;
   protected int adGroupPort = NamingContext.DEFAULT_DISCOVERY_GROUP_PORT;
   /** Local address for the discovery socket; falls back to bindAddress when null. */
   protected InetAddress discoveryBindAddress;
   protected AutomaticDiscovery autoDiscovery = null;
   protected boolean discoveryDisabled = false;
   protected int autoDiscoveryTTL = 16;
   protected ServerSocketFactory jnpServerSocketFactory;
   protected String jnpServerSocketFactoryName;
   /** Pool running the accept loop, the discovery loop and per-request handlers. */
   protected ThreadPool lookupPool;

   public DetachedHANamingService()
   {
   }

   /**
    * @return the unmodifiable hash-to-Method map for the Naming interface, or null
    *         before {@link #createService()} has run
    */
   public Map getMethodMap()
   {
      return marshalledInvocationMapping;
   }

   /** @return the name of the cluster partition this service attaches to */
   public String getPartitionName()
   {
      return partitionName;
   }

   /** @param partitionName name of the cluster partition to look up in createService() */
   public void setPartitionName(final String partitionName)
   {
      this.partitionName = partitionName;
   }

   /** @return the ObjectName of the MBean whose "Proxy" attribute supplies the naming stub */
   public ObjectName getProxyFactoryObjectName()
   {
      return proxyFactory;
   }

   /** @param proxyFactory ObjectName of the MBean whose "Proxy" attribute supplies the stub */
   public void setProxyFactoryObjectName(ObjectName proxyFactory)
   {
      this.proxyFactory = proxyFactory;
   }

   /** @param p bootstrap listener port; negative disables the listener, 0 picks a free port */
   public void setPort(int p)
   {
      port = p;
   }

   /** @return the bootstrap listener port (the actual bound port once started with 0) */
   public int getPort()
   {
      return port;
   }

   /** @return the bootstrap bind address as a literal IP string, or null if unset */
   public String getBindAddress()
   {
      return bindAddress == null ? null : bindAddress.getHostAddress();
   }

   /**
    * @param host host name or literal address for the bootstrap listener
    * @throws UnknownHostException if {@code host} cannot be resolved
    */
   public void setBindAddress(String host) throws java.net.UnknownHostException
   {
      bindAddress = InetAddress.getByName(host);
   }

   /** @return the accept backlog for the bootstrap listener */
   public int getBacklog()
   {
      return backlog;
   }

   /** @param backlog accept backlog; non-positive values select the default of 50 */
   public void setBacklog(int backlog)
   {
      this.backlog = backlog > 0 ? backlog : DEFAULT_BACKLOG;
   }

   /** @param disable true to skip starting the multicast discovery responder */
   public void setDiscoveryDisabled(boolean disable)
   {
      this.discoveryDisabled = disable;
   }

   /** @return true if the multicast discovery responder is disabled */
   public boolean getDiscoveryDisabled()
   {
      return this.discoveryDisabled;
   }

   /** @return the multicast group address used for discovery */
   public String getAutoDiscoveryAddress()
   {
      return this.adGroupAddress;
   }

   /** @param adAddress multicast group address used for discovery; null disables it */
   public void setAutoDiscoveryAddress(String adAddress)
   {
      this.adGroupAddress = adAddress;
   }

   /** @return the UDP port of the discovery multicast group */
   public int getAutoDiscoveryGroup()
   {
      return this.adGroupPort;
   }

   /** @param adGroup UDP port of the discovery multicast group */
   public void setAutoDiscoveryGroup(int adGroup)
   {
      this.adGroupPort = adGroup;
   }

   /** @return the configured discovery bind address as a literal IP string, or null */
   public String getAutoDiscoveryBindAddress()
   {
      return discoveryBindAddress == null ? null : discoveryBindAddress.getHostAddress();
   }

   /**
    * @param address local address for the discovery socket
    * @throws UnknownHostException if {@code address} cannot be resolved
    */
   public void setAutoDiscoveryBindAddress(String address)
      throws UnknownHostException
   {
      discoveryBindAddress = InetAddress.getByName(address);
   }

   /** @return the time-to-live set on the discovery multicast socket */
   public int getAutoDiscoveryTTL()
   {
      return autoDiscoveryTTL;
   }

   /** @param ttl time-to-live set on the discovery multicast socket */
   public void setAutoDiscoveryTTL(int ttl)
   {
      autoDiscoveryTTL = ttl;
   }

   /**
    * Loads and instantiates the ServerSocketFactory used for the bootstrap listener.
    *
    * @param factoryClassName fully qualified class name with a public no-arg constructor
    * @throws ClassNotFoundException if the class cannot be loaded
    * @throws InstantiationException if the class cannot be instantiated
    * @throws IllegalAccessException if the constructor is not accessible
    */
   public void setJNPServerSocketFactory(String factoryClassName)
      throws ClassNotFoundException, InstantiationException, IllegalAccessException
   {
      ClassLoader loader = Thread.currentThread().getContextClassLoader();
      if (loader == null)
         loader = getClass().getClassLoader();
      Class clazz = loader.loadClass(factoryClassName);
      jnpServerSocketFactory = (ServerSocketFactory) clazz.newInstance();
      // Record the name only after the factory was created, so name and instance agree.
      jnpServerSocketFactoryName = factoryClassName;
   }

   /** @param poolMBean pool used for the accept/discovery loops and request handlers */
   public void setLookupPool(BasicThreadPoolMBean poolMBean)
   {
      lookupPool = poolMBean.getInstance();
   }

   /**
    * Sets the partition field and then runs {@link #startService()}.
    * NOTE(review): theServer built in createService() is not rebuilt with this partition.
    */
   public void startService(HAPartition haPartition)
      throws Exception
   {
      this.partition = haPartition;
      this.startService();
   }

   /**
    * Locates the partition, creates and initializes the HAJNDI server, builds the
    * method-hash map and registers the server with NamingContext.
    * NOTE(review): if no matching ClusterPartition MBean exists, a null partition is
    * passed to HAJNDI unchanged.
    */
   protected void createService()
      throws Exception
   {
      if (log.isDebugEnabled())
         log.debug("Initializing HAJNDI server on partition: " + partitionName);

      partition = findHAPartitionWithName(partitionName);
      theServer = new HAJNDI(partition);
      log.debug("initialize HAJNDI");
      theServer.init();

      // MarshalledInvocation identifies methods by hash; map each hash back to its Method.
      HashMap tmpMap = new HashMap(13);
      Method[] methods = Naming.class.getMethods();
      for (int m = 0; m < methods.length; m++)
      {
         Method method = methods[m];
         Long hash = new Long(MarshalledInvocation.calculateHash(method));
         tmpMap.put(hash, method);
      }
      marshalledInvocationMapping = Collections.unmodifiableMap(tmpMap);

      NamingContext.setHANamingServerForPartition(partitionName, theServer);
   }

   /**
    * Installs the transport proxy on the server, starts the bootstrap listener when
    * {@code port >= 0} and, unless disabled, the multicast discovery responder.
    * A discovery start failure is logged, not propagated.
    */
   protected void startService()
      throws Exception
   {
      log.debug("Obtaining the transport proxy");
      stub = this.getNamingProxy();
      this.theServer.setHAStub(stub);

      if (port >= 0)
      {
         log.debug("Starting HAJNDI bootstrap listener");
         initBootstrapListener();
      }

      if (adGroupAddress != null && discoveryDisabled == false)
      {
         try
         {
            autoDiscovery = new AutomaticDiscovery();
            autoDiscovery.start();
            // initBootstrapListener() is skipped when port < 0, so the pool may not exist yet.
            ensureLookupPool();
            lookupPool.run(autoDiscovery);
         }
         catch (Exception e)
         {
            log.warn("Failed to start AutomaticDiscovery", e);
            // Release the multicast socket if start() opened it before the failure.
            if (autoDiscovery != null)
            {
               autoDiscovery.stop();
               autoDiscovery = null;
            }
         }
      }
   }

   /**
    * Unregisters the server from NamingContext, closes the bootstrap listener, stops the
    * HAJNDI server and stops discovery. Discovery is stopped even if stopping the
    * server throws.
    */
   protected void stopService() throws Exception
   {
      NamingContext.removeHANamingServerForPartition(partitionName);

      // Clear the field before closing, so the accept loop treats the failure of
      // accept() as a shutdown rather than an error.
      ServerSocket s = bootstrapSocket;
      bootstrapSocket = null;
      if (s != null)
      {
         log.debug("Closing the bootstrap listener");
         try
         {
            s.close();
         }
         catch (IOException e)
         {
            // Keep shutting down the remaining components.
            log.warn("Failed to close the bootstrap listener", e);
         }
      }

      try
      {
         log.debug("Stopping the HAJNDI service");
         if (theServer != null)
            theServer.stop();
      }
      finally
      {
         log.debug("Stopping AutomaticDiscovery");
         // Stop whatever was started, regardless of the current discoveryDisabled value.
         if (autoDiscovery != null)
         {
            autoDiscovery.stop();
            autoDiscovery = null;
         }
      }
   }

   /**
    * Dispatches an invocation reflectively onto the HAJNDI server.
    *
    * @param invocation the call; MarshalledInvocations get the method map installed first
    * @return the value returned by the target method
    * @throws Exception the target's Exception unwrapped; other throwables are wrapped
    *         in UndeclaredThrowableException
    */
   public Object invoke(Invocation invocation) throws Exception
   {
      if (invocation instanceof MarshalledInvocation)
      {
         MarshalledInvocation mi = (MarshalledInvocation) invocation;
         mi.setMethodMap(marshalledInvocationMapping);
      }
      Method method = invocation.getMethod();
      Object[] args = invocation.getArguments();
      try
      {
         return method.invoke(theServer, args);
      }
      catch (InvocationTargetException e)
      {
         Throwable t = e.getTargetException();
         if (t instanceof Exception)
            throw (Exception) t;
         throw new UndeclaredThrowableException(t, method.toString());
      }
   }

   /**
    * Opens the bootstrap server socket and submits the accept loop to the lookup pool.
    * A bind failure is logged, and no accept loop is started in that case.
    */
   protected void initBootstrapListener()
   {
      ServerSocket listener = null;
      try
      {
         if (jnpServerSocketFactory == null)
            jnpServerSocketFactory = ServerSocketFactory.getDefault();
         listener = jnpServerSocketFactory.createServerSocket(port, backlog, bindAddress);
         bootstrapSocket = listener;
         // Port 0 asks for any free port; record the one actually bound.
         if (port == 0)
            port = listener.getLocalPort();
         String msg = "Started ha-jndi bootstrap jnpPort=" + port
            + ", backlog=" + backlog + ", bindAddress=" + bindAddress;
         log.info(msg);
      }
      catch (IOException e)
      {
         log.error("Could not start on port " + port, e);
      }

      ensureLookupPool();
      if (listener != null)
         lookupPool.run(new AcceptHandler());
   }

   /** Creates the default lookup pool if none was injected via setLookupPool(). */
   private void ensureLookupPool()
   {
      if (lookupPool == null)
         lookupPool = new BasicThreadPool("HANamingBootstrap Pool");
   }

   /**
    * Queries the MBean server for a ClusterPartition MBean whose PartitionName matches.
    *
    * @param name partition name to match
    * @return the first matching MBean's HAPartition, or null if none is registered
    */
   protected HAPartition findHAPartitionWithName(String name) throws Exception
   {
      QueryExp exp = Query.and(Query.eq(Query.classattr(),
                                        Query.value(ClusterPartition.class.getName())),
                               Query.match(Query.attr("PartitionName"),
                                           Query.value(name)));
      Set mbeans = this.getServer().queryMBeans(null, exp);
      if (mbeans == null || mbeans.isEmpty())
         return null;

      ObjectInstance inst = (ObjectInstance) mbeans.iterator().next();
      ClusterPartitionMBean cp = (ClusterPartitionMBean) MBeanProxy.get(ClusterPartitionMBean.class,
                                                                        inst.getObjectName(),
                                                                        this.getServer());
      return cp.getHAPartition();
   }

   /** @return the "Proxy" attribute of the configured proxy factory MBean */
   protected Naming getNamingProxy() throws Exception
   {
      Naming proxy = (Naming) server.getAttribute(proxyFactory, "Proxy");
      return proxy;
   }

   /**
    * Multicast responder: joins the discovery group and hands every received packet
    * to a DiscoveryRequestHandler on the lookup pool.
    */
   private class AutomaticDiscovery
      implements Runnable
   {
      protected Logger log = Logger.getLogger(AutomaticDiscovery.class);
      protected MulticastSocket socket = null;
      /** "address:port" of the bootstrap listener, sent back to discovering clients. */
      protected byte[] ipAddress = null;
      protected InetAddress group = null;
      /** Set by stop(); volatile because run() polls it on a pool thread. */
      protected volatile boolean stopping = false;

      public AutomaticDiscovery() throws Exception
      {
      }

      /** Opens the multicast socket, joins the group and prepares the reply payload. */
      public void start() throws Exception
      {
         stopping = false;
         // Fall back to the JNDI bind address when no discovery bind address is set;
         // a null address makes InetSocketAddress use the wildcard address.
         InetAddress localAddress = discoveryBindAddress != null
            ? discoveryBindAddress : bindAddress;
         InetSocketAddress bindAddr = new InetSocketAddress(localAddress, adGroupPort);
         socket = new MulticastSocket(bindAddr);
         socket.setTimeToLive(autoDiscoveryTTL);
         group = InetAddress.getByName(adGroupAddress);
         socket.joinGroup(group);

         // An unset or wildcard bind address is replaced by the local host's address.
         String address = getBindAddress();
         if (address == null || address.equals("0.0.0.0"))
         {
            address = InetAddress.getLocalHost().getHostAddress();
         }
         ipAddress = (address + ":" + port).getBytes();
         log.info("Listening on " + socket.getInterface() + ":" + socket.getLocalPort()
            + ", group=" + adGroupAddress
            + ", HA-JNDI address=" + new String(ipAddress));
      }

      /** Leaves the group and always closes the socket, which unblocks run(). */
      public void stop()
      {
         stopping = true;
         MulticastSocket s = socket;
         if (s == null)
            return;
         try
         {
            if (group != null)
               s.leaveGroup(group);
         }
         catch (Exception ex)
         {
            log.error("Stopping AutomaticDiscovery failed", ex);
         }
         finally
         {
            s.close();
         }
      }

      public void run()
      {
         boolean trace = log.isTraceEnabled();
         log.debug("Discovery request thread begin");
         // Without the isClosed() check, a closed socket makes receive() fail in a tight loop.
         while (stopping == false && socket.isClosed() == false)
         {
            try
            {
               if (trace)
                  log.trace("HA-JNDI AutomaticDiscovery waiting for queries...");
               byte[] buf = new byte[256];
               DatagramPacket packet = new DatagramPacket(buf, buf.length);
               socket.receive(packet);
               if (trace)
                  log.trace("HA-JNDI AutomaticDiscovery Packet received.");
               DiscoveryRequestHandler handler = new DiscoveryRequestHandler(log,
                  packet, socket, ipAddress);
               lookupPool.run(handler);
               if (trace)
                  log.trace("Queued DiscoveryRequestHandler");
            }
            catch (Throwable t)
            {
               if (stopping == false)
                  log.warn("Ignored error while processing HAJNDI discovery request:", t);
            }
         }
         log.debug("Discovery request thread end");
      }
   }

   /**
    * Answers one discovery packet with this node's "address:port". If the request text
    * contains a ':' after position 0, the text after the first ':' is compared with
    * partitionName, and the request is ignored on a mismatch.
    */
   private class DiscoveryRequestHandler implements Runnable
   {
      private Logger log;
      private MulticastSocket socket;
      private DatagramPacket packet;
      private byte[] ipAddress;

      DiscoveryRequestHandler(Logger log, DatagramPacket packet,
         MulticastSocket socket, byte[] ipAddress)
      {
         this.log = log;
         this.packet = packet;
         this.socket = socket;
         this.ipAddress = ipAddress;
      }

      public void run()
      {
         boolean trace = log.isTraceEnabled();
         if (trace)
            log.trace("DiscoveryRequestHandler begin");
         try
         {
            // Decode only the bytes actually received, not the whole receive buffer.
            String requestData = new String(packet.getData(), packet.getOffset(),
               packet.getLength()).trim();
            if (trace)
               log.trace("RequestData: " + requestData);
            int colon = requestData.indexOf(':');
            if (colon > 0)
            {
               String name = requestData.substring(colon + 1);
               if (name.equals(partitionName) == false)
               {
                  log.debug("Ignoring discovery request for partition: " + name);
                  if (trace)
                     log.trace("DiscoveryRequestHandler end");
                  return;
               }
            }
            DatagramPacket p = new DatagramPacket(ipAddress, ipAddress.length,
               packet.getAddress(), packet.getPort());
            if (trace)
               log.trace("Sending AutomaticDiscovery answer: " + new String(ipAddress));
            socket.send(p);
            if (trace)
               log.trace("AutomaticDiscovery answer sent.");
         }
         catch (IOException ex)
         {
            log.error("Error writing response", ex);
         }
         if (trace)
            log.trace("DiscoveryRequestHandler end");
      }
   }

   /**
    * Accept loop for the bootstrap listener. It exits once bootstrapSocket is cleared
    * or closed, and hands each client to a BootstrapRequestHandler.
    */
   private class AcceptHandler implements Runnable
   {
      public void run()
      {
         boolean trace = log.isTraceEnabled();
         while (true)
         {
            // Re-read the volatile field every pass; stopService() nulls it on shutdown.
            ServerSocket serverSocket = bootstrapSocket;
            if (serverSocket == null || serverSocket.isClosed())
               return;
            Socket socket = null;
            try
            {
               socket = serverSocket.accept();
               if (trace)
                  log.trace("Accepted bootstrap client: " + socket);
               BootstrapRequestHandler handler = new BootstrapRequestHandler(socket);
               lookupPool.run(handler);
            }
            catch (IOException e)
            {
               if (bootstrapSocket == null || serverSocket.isClosed())
                  return;
               log.error("Failed to accept bootstrap client", e);
            }
            catch (Throwable e)
            {
               log.error("Unexpected exception during accept", e);
               // Presumably the handler was not queued; close the client so it is not leaked.
               if (socket != null)
               {
                  try
                  {
                     socket.close();
                  }
                  catch (IOException ignored)
                  {
                     // Nothing more can be done for this client.
                  }
               }
            }
         }
      }
   }

   /** Writes the naming stub wrapped in a MarshalledObject to one client, then closes it. */
   private class BootstrapRequestHandler implements Runnable
   {
      private Socket socket;

      BootstrapRequestHandler(Socket socket)
      {
         this.socket = socket;
      }

      public void run()
      {
         try
         {
            OutputStream os = socket.getOutputStream();
            ObjectOutputStream out = new ObjectOutputStream(os);
            MarshalledObject replyStub = new MarshalledObject(stub);
            out.writeObject(replyStub);
            out.close();
         }
         catch (IOException ex)
         {
            log.debug("Error writing response to " + socket, ex);
         }
         finally
         {
            try
            {
               socket.close();
            }
            catch (IOException ignored)
            {
               // Best-effort close; the client is being dropped anyway.
            }
         }
      }
   }
}