package org.jboss.security.srp;
import java.lang.reflect.Method;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.lang.reflect.UndeclaredThrowableException;
import java.rmi.server.RMIClientSocketFactory;
import java.rmi.server.RMIServerSocketFactory;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.naming.InitialContext;
import javax.naming.Name;
import javax.naming.NamingException;
import org.jboss.invocation.Invocation;
import org.jboss.invocation.MarshalledInvocation;
import org.jboss.naming.NonSerializableFactory;
import org.jboss.security.srp.SRPRemoteServer;
import org.jboss.security.srp.SRPServerListener;
import org.jboss.security.srp.SRPServerInterface;
import org.jboss.security.srp.SRPServerSession;
import org.jboss.security.srp.SRPVerifierStore;
import org.jboss.system.ServiceMBeanSupport;
import org.jboss.util.CachePolicy;
import org.jboss.util.TimedCachePolicy;
/**
 * MBean service that hosts an {@link SRPRemoteServer}, publishes a dynamic proxy
 * for it in JNDI, and maintains a cache of verified {@link SRPServerSession}s.
 *
 * <p>On start the service looks up an {@link SRPVerifierStore} in JNDI, creates the
 * SRP server on the configured port and socket factories, and registers itself as
 * the server's {@link SRPServerListener}. Through that listener, verified sessions
 * are inserted into the authentication cache and closed sessions are removed from
 * it. The cache is an existing {@link CachePolicy} found at the configured JNDI
 * name. If none is found there, a new {@link TimedCachePolicy} is created and bound
 * at that name.
 */
public class SRPService extends ServiceMBeanSupport
   implements SRPServiceMBean, SRPServerListener
{
   /** The SRP server created in startService; null before the service starts. */
   private SRPRemoteServer server;
   /** Port handed to the SRPRemoteServer constructor. */
   private int serverPort = 10099;
   /** Verifier store looked up from verifierSourceJndiName by loadStore(). */
   private SRPVerifierStore verifierStore;
   private String verifierSourceJndiName = "srp/DefaultVerifierSource";
   /** JNDI name for the SRPServerInterface proxy; null or empty disables binding. */
   private String serverJndiName = "srp/SRPServerInterface";
   private String cacheJndiName = "srp/AuthenticationCache";
   /** Cache of SRPSessionKey -> SRPServerSession; assigned in startService. */
   private CachePolicy cachePolicy;
   private int cacheTimeout = 1800;
   private int cacheResolution = 60;
   /** When true, a newly verified session replaces an existing cached one. */
   private boolean overwriteSessions;
   private boolean requireAuxChallenge;
   private RMIClientSocketFactory clientSocketFactory;
   private RMIServerSocketFactory serverSocketFactory;
   private String clientSocketFactoryName;
   private String serverSocketFactoryName;
   /** Method-hash -> Method map for SRPRemoteServerInterface, set on MarshalledInvocations. */
   private Map marshalledInvocationMapping = new HashMap();

   /** @return the JNDI name from which the SRPVerifierStore is looked up */
   public String getVerifierSourceJndiName()
   {
      return verifierSourceJndiName;
   }

   /** @param jndiName the JNDI name from which the SRPVerifierStore is looked up */
   public void setVerifierSourceJndiName(String jndiName)
   {
      this.verifierSourceJndiName = jndiName;
   }

   /** @return the JNDI name under which the SRPServerInterface proxy is bound */
   public String getJndiName()
   {
      return serverJndiName;
   }

   /**
    * @param jndiName the JNDI name under which the SRPServerInterface proxy is bound;
    *        a null or empty value disables binding the proxy
    */
   public void setJndiName(String jndiName)
   {
      this.serverJndiName = jndiName;
   }

   /** @return the JNDI name of the authentication cache */
   public String getAuthenticationCacheJndiName()
   {
      return cacheJndiName;
   }

   /** @param jndiName the JNDI name of the authentication cache */
   public void setAuthenticationCacheJndiName(String jndiName)
   {
      this.cacheJndiName = jndiName;
   }

   /** @return the timeout passed to a newly created TimedCachePolicy */
   public int getAuthenticationCacheTimeout()
   {
      return cacheTimeout;
   }

   /**
    * @param timeoutInSecs the timeout passed to a newly created TimedCachePolicy
    *        (seconds, per the parameter name)
    */
   public void setAuthenticationCacheTimeout(int timeoutInSecs)
   {
      this.cacheTimeout = timeoutInSecs;
   }

   /** @return the resolution passed to a newly created TimedCachePolicy */
   public int getAuthenticationCacheResolution()
   {
      return cacheResolution;
   }

   /**
    * @param resInSecs the resolution passed to a newly created TimedCachePolicy
    *        (seconds, per the parameter name)
    */
   public void setAuthenticationCacheResolution(int resInSecs)
   {
      this.cacheResolution = resInSecs;
   }

   /** @return the flag forwarded to SRPRemoteServer.setRequireAuxChallenge on start */
   public boolean getRequireAuxChallenge()
   {
      return this.requireAuxChallenge;
   }

   /** @param flag the value forwarded to SRPRemoteServer.setRequireAuxChallenge on start */
   public void setRequireAuxChallenge(boolean flag)
   {
      this.requireAuxChallenge = flag;
   }

   /** @return whether a newly verified session replaces an existing cached session */
   public boolean getOverwriteSessions()
   {
      return this.overwriteSessions;
   }

   /** @param flag whether a newly verified session replaces an existing cached session */
   public void setOverwriteSessions(boolean flag)
   {
      this.overwriteSessions = flag;
   }

   /**
    * @return the class name last passed to {@link #setClientSocketFactory(String)},
    *         or null if none was set
    */
   public String getClientSocketFactory()
   {
      // Previously returned serverSocketFactoryName by mistake.
      return clientSocketFactoryName;
   }

   /**
    * Loads the named class with the thread context class loader and instantiates it
    * with its no-arg constructor as the RMIClientSocketFactory for the server.
    *
    * @param factoryClassName fully qualified name of an RMIClientSocketFactory class
    * @throws ClassNotFoundException if the class cannot be loaded
    * @throws InstantiationException if the class cannot be instantiated
    * @throws IllegalAccessException if the no-arg constructor is not accessible
    */
   public void setClientSocketFactory(String factoryClassName)
      throws ClassNotFoundException, InstantiationException, IllegalAccessException
   {
      this.clientSocketFactoryName = factoryClassName;
      ClassLoader loader = Thread.currentThread().getContextClassLoader();
      Class clazz = loader.loadClass(clientSocketFactoryName);
      clientSocketFactory = (RMIClientSocketFactory) clazz.newInstance();
   }

   /**
    * @return the class name last passed to {@link #setServerSocketFactory(String)},
    *         or null if none was set
    */
   public String getServerSocketFactory()
   {
      return serverSocketFactoryName;
   }

   /**
    * Loads the named class with the thread context class loader and instantiates it
    * with its no-arg constructor as the RMIServerSocketFactory for the server.
    *
    * @param factoryClassName fully qualified name of an RMIServerSocketFactory class
    * @throws ClassNotFoundException if the class cannot be loaded
    * @throws InstantiationException if the class cannot be instantiated
    * @throws IllegalAccessException if the no-arg constructor is not accessible
    */
   public void setServerSocketFactory(String factoryClassName)
      throws ClassNotFoundException, InstantiationException, IllegalAccessException
   {
      this.serverSocketFactoryName = factoryClassName;
      ClassLoader loader = Thread.currentThread().getContextClassLoader();
      Class clazz = loader.loadClass(serverSocketFactoryName);
      serverSocketFactory = (RMIServerSocketFactory) clazz.newInstance();
   }

   /** @return the port passed to the SRPRemoteServer constructor */
   public int getServerPort()
   {
      return serverPort;
   }

   /** @param serverPort the port passed to the SRPRemoteServer constructor on start */
   public void setServerPort(int serverPort)
   {
      this.serverPort = serverPort;
   }

   /**
    * SRPServerListener callback for a successfully verified user.
    *
    * <p>Caches the session under key. If an entry already exists, it is replaced only
    * when overwriteSessions is true; otherwise the new session is ignored. Any
    * exception, including use before startService has created the cache, is logged
    * and not propagated.
    *
    * @param key the session key
    * @param session the verified session
    */
   public void verifiedUser(SRPSessionKey key, SRPServerSession session)
   {
      try
      {
         // Lock the cache so that the peek/remove/insert sequence is not interleaved
         // with other callers that also synchronize on it.
         synchronized( cachePolicy )
         {
            if( cachePolicy.peek(key) == null )
            {
               cachePolicy.insert(key, session);
               log.trace("Cached SRP session for user="+key);
            }
            else if( overwriteSessions )
            {
               cachePolicy.remove(key);
               cachePolicy.insert(key, session);
               log.trace("Replaced SRP session for user="+key);
            }
            else
            {
               log.debug("Ignoring SRP session due to existing session for user="+key);
            }
         }
      }
      catch(Exception e)
      {
         log.error("Failed to update SRP cache for user="+key, e);
      }
   }

   /**
    * SRPServerListener callback for a closed session.
    *
    * <p>Removes the cached session for key, logging a warning if none was present.
    * Any exception is logged and not propagated.
    *
    * @param key the session key
    */
   public void closedUserSession(SRPSessionKey key)
   {
      try
      {
         synchronized( cachePolicy )
         {
            if( cachePolicy.peek(key) == null )
            {
               log.warn("No SRP session found for user="+key);
            }
            cachePolicy.remove(key);
         }
      }
      catch(Exception e)
      {
         log.error("Failed to update SRP cache for user="+key, e);
      }
   }

   /** @return the fixed service name "SRPService" */
   public String getName()
   {
      return "SRPService";
   }

   /**
    * Dispatches an invocation reflectively to the underlying SRPRemoteServer.
    *
    * <p>For a MarshalledInvocation, the method-hash map built in startService is
    * installed first so that the invocation can resolve its Method.
    *
    * @param invocation the invocation to dispatch
    * @return the value returned by the target method
    * @throws Exception the target's exception, if it is an Exception; any other
    *         Throwable is wrapped in an UndeclaredThrowableException
    */
   public Object invoke(Invocation invocation) throws Exception
   {
      if (invocation instanceof MarshalledInvocation)
      {
         MarshalledInvocation mi = (MarshalledInvocation) invocation;
         mi.setMethodMap(marshalledInvocationMapping);
      }
      Method method = invocation.getMethod();
      Object[] args = invocation.getArguments();
      Object value = null;
      try
      {
         value = method.invoke(server, args);
      }
      catch(InvocationTargetException e)
      {
         // Unwrap the reflection wrapper so callers see the server's own exception.
         Throwable t = e.getTargetException();
         if( t instanceof Exception )
            throw (Exception) t;
         else
            throw new UndeclaredThrowableException(t, method.toString());
      }
      return value;
   }

   /**
    * Starts the service. In order, this method:
    * <ul>
    * <li>loads the verifier store;</li>
    * <li>creates the SRPRemoteServer and registers this object as its listener;</li>
    * <li>binds an SRPServerInterface proxy when a JNDI name is configured;</li>
    * <li>locates or creates the authentication cache;</li>
    * <li>builds the method-hash map used by {@link #invoke(Invocation)}.</li>
    * </ul>
    */
   protected void startService() throws Exception
   {
      loadStore();
      server = new SRPRemoteServer(verifierStore, serverPort,
         clientSocketFactory, serverSocketFactory);
      server.addSRPServerListener(this);
      server.setRequireAuxChallenge(this.requireAuxChallenge);
      InitialContext ctx = new InitialContext();
      try
      {
         if( serverJndiName != null && serverJndiName.length() > 0 )
         {
            SRPServerProxy proxyHandler = new SRPServerProxy(server);
            ClassLoader loader = Thread.currentThread().getContextClassLoader();
            Class[] interfaces = {SRPServerInterface.class};
            Object proxy = Proxy.newProxyInstance(loader, interfaces, proxyHandler);
            org.jboss.naming.Util.rebind(ctx, serverJndiName, proxy);
            log.debug("Bound SRPServerProxy at "+serverJndiName);
         }
         try
         {
            // Reuse a cache that something else has already bound at this name.
            cachePolicy = (CachePolicy) ctx.lookup(cacheJndiName);
            log.debug("Found AuthenticationCache at: "+cacheJndiName);
         }
         catch(Exception e)
         {
            log.trace("Failed to find existing cache at: "+cacheJndiName, e);
            cachePolicy = new TimedCachePolicy(cacheTimeout, true, cacheResolution);
            cachePolicy.create();
            cachePolicy.start();
            // The cache object is bound through NonSerializableFactory.
            Name name = ctx.getNameParser("").parse(cacheJndiName);
            NonSerializableFactory.rebind(name, cachePolicy, true);
            log.debug("Bound AuthenticationCache at "+cacheJndiName);
         }
      }
      finally
      {
         ctx.close();
      }
      // Map MarshalledInvocation method hashes to SRPRemoteServerInterface methods.
      HashMap tmpMap = new HashMap(13);
      Method[] methods = SRPRemoteServerInterface.class.getMethods();
      for(int m = 0; m < methods.length; m ++)
      {
         Method method = methods[m];
         Long hash = new Long(MarshalledInvocation.calculateHash(method));
         tmpMap.put(hash, method);
      }
      marshalledInvocationMapping = Collections.unmodifiableMap(tmpMap);
   }

   /**
    * Stops the service by unbinding the server proxy, if one was bound, and the
    * authentication cache from JNDI.
    */
   protected void stopService() throws Exception
   {
      InitialContext ctx = new InitialContext();
      try
      {
         // startService binds the proxy only for a non-empty name. Mirror that
         // condition so that unbinding an unset name is never attempted.
         if( serverJndiName != null && serverJndiName.length() > 0 )
         {
            ctx.unbind(serverJndiName);
            log.debug("Unbound SRPServerProxy at "+serverJndiName);
         }
         NonSerializableFactory.unbind(cacheJndiName);
         ctx.unbind(cacheJndiName);
         log.debug("Unbound AuthenticationCache at "+cacheJndiName);
      }
      finally
      {
         ctx.close();
      }
   }

   /**
    * Looks up the SRPVerifierStore at verifierSourceJndiName and, if the server
    * already exists, passes the store to it.
    *
    * @throws NamingException if the lookup fails
    */
   private void loadStore() throws NamingException
   {
      InitialContext ctx = new InitialContext();
      try
      {
         verifierStore = (SRPVerifierStore) ctx.lookup(verifierSourceJndiName);
      }
      finally
      {
         ctx.close();
      }
      if( server != null )
      {
         server.setVerifierStore(verifierStore);
      }
   }
}