package org.jboss.security.plugins;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.security.Principal;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Set;
import java.util.List;
import java.beans.PropertyEditorManager;
import javax.management.MBeanServer;
import javax.naming.CommunicationException;
import javax.naming.Context;
import javax.naming.InitialContext;
import javax.naming.Name;
import javax.naming.NameClassPair;
import javax.naming.NameParser;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.OperationNotSupportedException;
import javax.naming.RefAddr;
import javax.naming.Reference;
import javax.naming.StringRefAddr;
import javax.naming.spi.ObjectFactory;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.Subject;
import javax.security.jacc.PolicyContext;
import org.jboss.logging.Logger;
import org.jboss.security.AuthenticationManager;
import org.jboss.security.SecurityAssociation;
import org.jboss.security.SecurityProxyFactory;
import org.jboss.security.SecurityDomain;
import org.jboss.security.jacc.SubjectPolicyContextHandler;
import org.jboss.security.propertyeditor.PrincipalEditor;
import org.jboss.security.propertyeditor.SecurityDomainEditor;
import org.jboss.system.ServiceMBeanSupport;
import org.jboss.util.CachePolicy;
import org.jboss.util.TimedCachePolicy;
/**
 * JMX service that creates and manages one {@link AuthenticationManager} per
 * security domain and publishes them through JNDI.
 * <p>
 * {@link #startService()} binds three JNDI entries:
 * <ul>
 * <li>{@code java:/jaas}: a {@link Context} proxy whose {@code lookup("domain")}
 * returns that domain's security manager, creating it on first use;</li>
 * <li>{@code java:/timedCacheFactory}: a {@link Context} proxy whose lookups return a
 * newly created and started {@link TimedCachePolicy};</li>
 * <li>{@code java:/SecurityProxyFactory}: an instance of the configured
 * {@link SecurityProxyFactory} class.</li>
 * </ul>
 * {@link #stopService()} removes those bindings again.
 * <p>
 * All configuration and the domain map live in static fields, so they are shared by
 * every instance of this class.
 */
public class JaasSecurityManagerService
   extends ServiceMBeanSupport
   implements JaasSecurityManagerServiceMBean
{
   /** JNDI name of the security domain {@link Context} proxy. */
   private static final String SECURITY_MGR_PATH = "java:/jaas";
   /** JNDI name of the default {@link TimedCachePolicy} factory {@link Context} proxy. */
   private static final String DEFAULT_CACHE_POLICY_PATH = "java:/timedCacheFactory";
   /** JNDI name under which the {@link SecurityProxyFactory} instance is bound. */
   private static final String SECURITY_PROXY_FACTORY_PATH = "java:/SecurityProxyFactory";
   private static final Logger log = Logger.getLogger(JaasSecurityManagerService.class);

   /** Name of the {@link AuthenticationManager} implementation instantiated per domain. */
   private static String securityMgrClassName = "org.jboss.security.plugins.JaasSecurityManager";
   /**
    * Loaded form of {@link #securityMgrClassName}. Starts out null and is loaded on
    * first use by {@link #getSecurityMgrClass()} unless the setter ran first.
    */
   private static Class securityMgrClass;
   /** Name of the {@link CallbackHandler} implementation handed to each security manager. */
   private static String callbackHandlerClassName = "org.jboss.security.auth.callback.SecurityAssociationHandler";
   private static Class callbackHandlerClass = org.jboss.security.auth.callback.SecurityAssociationHandler.class;
   /** JNDI name under which per-domain authentication caches are looked up. */
   private static String cacheJndiName = DEFAULT_CACHE_POLICY_PATH;
   /** Lifetime (seconds) given to each TimedCachePolicy made by the default cache factory. */
   private static int defaultCacheTimeout = 30*60;
   /** Resolution (seconds) given to each TimedCachePolicy made by the default cache factory. */
   private static int defaultCacheResolution = 60;
   /** Name of the {@link SecurityProxyFactory} implementation bound at start. */
   private static String securityProxyFactoryClassName = "org.jboss.security.SubjectSecurityProxyFactory";
   private static Class securityProxyFactoryClass = org.jboss.security.SubjectSecurityProxyFactory.class;
   /** Security domain name (String) -> SecurityDomainContext. */
   private static Hashtable securityDomainCtxMap = new Hashtable();
   /** Parses String names handed to the java:/jaas proxy; assigned by startService. */
   private static NameParser parser;
   private static String defaultUnauthenticatedPrincipal = "Unauthenticated Principal";

   /** Creates the service; all of its state lives in static fields. */
   public JaasSecurityManagerService()
   {
   }

   // ------------------------------------------------------------------ configuration

   /** @return the configured AuthenticationManager implementation class name */
   public String getSecurityManagerClassName()
   {
      return securityMgrClassName;
   }

   /**
    * Sets the AuthenticationManager implementation instantiated for new domains.
    * The configuration is only changed once the class has been loaded and verified.
    *
    * @param className fully qualified class name, loaded via the thread context class loader
    * @throws ClassNotFoundException if the class cannot be loaded
    * @throws ClassCastException if the class does not implement AuthenticationManager
    */
   public void setSecurityManagerClassName(String className)
      throws ClassNotFoundException, ClassCastException
   {
      Class mgrClass = loadConfiguredClass(className);
      if( AuthenticationManager.class.isAssignableFrom(mgrClass) == false )
         throw new ClassCastException(mgrClass+" does not implement "+AuthenticationManager.class);
      securityMgrClass = mgrClass;
      securityMgrClassName = className;
   }

   /** @return the configured SecurityProxyFactory implementation class name */
   public String getSecurityProxyFactoryClassName()
   {
      return securityProxyFactoryClassName;
   }

   /**
    * Sets the SecurityProxyFactory implementation bound by startService.
    *
    * @param className fully qualified class name, loaded via the thread context class loader
    * @throws ClassNotFoundException if the class cannot be loaded
    * @throws ClassCastException if the class does not implement SecurityProxyFactory
    */
   public void setSecurityProxyFactoryClassName(String className)
      throws ClassNotFoundException
   {
      Class factoryClass = loadConfiguredClass(className);
      if( SecurityProxyFactory.class.isAssignableFrom(factoryClass) == false )
         throw new ClassCastException(factoryClass+" does not implement "+SecurityProxyFactory.class);
      securityProxyFactoryClass = factoryClass;
      securityProxyFactoryClassName = className;
   }

   /** @return the configured CallbackHandler implementation class name */
   public String getCallbackHandlerClassName()
   {
      return JaasSecurityManagerService.callbackHandlerClassName;
   }

   /**
    * Sets the CallbackHandler implementation instantiated for each new security manager.
    *
    * @param className fully qualified class name, loaded via the thread context class loader
    * @throws ClassNotFoundException if the class cannot be loaded
    * @throws ClassCastException if the class does not implement CallbackHandler
    */
   public void setCallbackHandlerClassName(String className)
      throws ClassNotFoundException
   {
      Class handlerClass = loadConfiguredClass(className);
      if( CallbackHandler.class.isAssignableFrom(handlerClass) == false )
         throw new ClassCastException(handlerClass+" does not implement "+CallbackHandler.class);
      callbackHandlerClass = handlerClass;
      callbackHandlerClassName = className;
   }

   /** @return the JNDI name under which authentication caches are looked up */
   public String getAuthenticationCacheJndiName()
   {
      return cacheJndiName;
   }

   /** @param jndiName JNDI name under which authentication caches are looked up */
   public void setAuthenticationCacheJndiName(String jndiName)
   {
      cacheJndiName = jndiName;
   }

   /** @return lifetime in seconds used for caches created by the default cache factory */
   public int getDefaultCacheTimeout()
   {
      return defaultCacheTimeout;
   }

   /** @param timeoutInSecs lifetime used for caches created after this call */
   public void setDefaultCacheTimeout(int timeoutInSecs)
   {
      defaultCacheTimeout = timeoutInSecs;
   }

   /** @return resolution in seconds used for caches created by the default cache factory */
   public int getDefaultCacheResolution()
   {
      return defaultCacheResolution;
   }

   /** @param resInSecs resolution used for caches created after this call */
   public void setDefaultCacheResolution(int resInSecs)
   {
      defaultCacheResolution = resInSecs;
   }

   /** @return the configured default unauthenticated principal name */
   public String getDefaultUnauthenticatedPrincipal()
   {
      return defaultUnauthenticatedPrincipal;
   }

   /** @param principal the default unauthenticated principal name */
   public void setDefaultUnauthenticatedPrincipal(String principal)
   {
      defaultUnauthenticatedPrincipal = principal;
   }

   // ------------------------------------------------------------------ cache management

   /**
    * Updates the lifetime and resolution of a domain's cache when it is a
    * TimedCachePolicy; otherwise logs a warning.
    *
    * @param securityDomain domain name, optionally prefixed with "java:/jaas/"
    */
   public void setCacheTimeout(String securityDomain, int timeoutInSecs, int resInSecs)
   {
      CachePolicy cache = getCachePolicy(securityDomain);
      if( cache instanceof TimedCachePolicy )
      {
         TimedCachePolicy tcp = (TimedCachePolicy) cache;
         synchronized( tcp )
         {
            tcp.setDefaultLifetime(timeoutInSecs);
            tcp.setResolution(resInSecs);
         }
      }
      else
      {
         warnNoCachePolicy(securityDomain);
      }
   }

   /**
    * Flushes every entry of a domain's authentication cache.
    *
    * @param securityDomain domain name, optionally prefixed with "java:/jaas/"
    */
   public void flushAuthenticationCache(String securityDomain)
   {
      CachePolicy cache = getCachePolicy(securityDomain);
      if( cache != null )
      {
         cache.flush();
      }
      else
      {
         warnNoCachePolicy(securityDomain);
      }
   }

   /**
    * Removes a single principal from a domain's authentication cache.
    *
    * @param securityDomain domain name, optionally prefixed with "java:/jaas/"
    * @param user the cache key to remove
    */
   public void flushAuthenticationCache(String securityDomain, Principal user)
   {
      CachePolicy cache = getCachePolicy(securityDomain);
      if( cache != null )
      {
         cache.remove(user);
      }
      else
      {
         warnNoCachePolicy(securityDomain);
      }
   }

   /**
    * @param securityDomain domain name, optionally prefixed with "java:/jaas/"
    * @return the valid keys of the domain's cache, or null when the cache is
    *    missing or is not a TimedCachePolicy
    */
   public List getAuthenticationCachePrincipals(String securityDomain)
   {
      CachePolicy cache = getCachePolicy(securityDomain);
      List validPrincipals = null;
      if( cache instanceof TimedCachePolicy )
      {
         TimedCachePolicy tcache = (TimedCachePolicy) cache;
         validPrincipals = tcache.getValidKeys();
      }
      return validPrincipals;
   }

   // ------------------------------------------------------------------ authentication queries

   /**
    * Validates a principal/credential pair against a security domain.
    *
    * @return the security manager's verdict, or false if the domain could not be obtained
    */
   public boolean isValid(String securityDomain, Principal principal, Object credential)
   {
      boolean isValid = false;
      try
      {
         SecurityDomainContext sdc = lookupSecurityDomain(securityDomain);
         isValid = sdc.getSecurityManager().isValid(principal, credential);
      }
      catch(NamingException e)
      {
         log.debug("isValid("+securityDomain+") failed", e);
      }
      return isValid;
   }

   /**
    * Maps a principal through the domain's realm mapping.
    *
    * @return the mapped principal, or null if the domain could not be obtained
    */
   public Principal getPrincipal(String securityDomain, Principal principal)
   {
      Principal realmPrincipal = null;
      try
      {
         SecurityDomainContext sdc = lookupSecurityDomain(securityDomain);
         realmPrincipal = sdc.getRealmMapping().getPrincipal(principal);
      }
      catch(NamingException e)
      {
         log.debug("getPrincipal("+securityDomain+") failed", e);
      }
      return realmPrincipal;
   }

   /**
    * Authenticates the principal and then asks the domain's realm mapping whether it
    * has one of the given roles. The Subject filled by isValid is pushed as the
    * thread's subject context for the duration of the role check.
    *
    * @return false if authentication fails or the domain could not be obtained
    */
   public boolean doesUserHaveRole(String securityDomain, Principal principal,
      Object credential, Set roles)
   {
      boolean doesUserHaveRole = false;
      try
      {
         SecurityDomainContext sdc = lookupSecurityDomain(securityDomain);
         Subject subject = new Subject();
         // Role information must only be released for a successfully validated user
         if( sdc.getSecurityManager().isValid(principal, credential, subject) )
         {
            SubjectActions.pushSubjectContext(principal, credential, subject);
            try
            {
               doesUserHaveRole = sdc.getRealmMapping().doesUserHaveRole(principal, roles);
            }
            finally
            {
               // Never leave the pushed subject on the thread, even on failure
               SubjectActions.popSubjectContext();
            }
         }
      }
      catch(NamingException e)
      {
         log.debug("doesUserHaveRole("+securityDomain+") failed", e);
      }
      return doesUserHaveRole;
   }

   /**
    * Authenticates the principal and then returns its roles from the domain's realm
    * mapping, with the validated Subject pushed as the thread's subject context
    * (the same sequence doesUserHaveRole uses).
    *
    * @return the roles, or null if authentication fails or the domain could not be obtained
    */
   public Set getUserRoles(String securityDomain, Principal principal, Object credential)
   {
      Set userRoles = null;
      try
      {
         SecurityDomainContext sdc = lookupSecurityDomain(securityDomain);
         Subject subject = new Subject();
         if( sdc.getSecurityManager().isValid(principal, credential, subject) )
         {
            SubjectActions.pushSubjectContext(principal, credential, subject);
            try
            {
               userRoles = sdc.getRealmMapping().getUserRoles(principal);
            }
            finally
            {
               SubjectActions.popSubjectContext();
            }
         }
      }
      catch(NamingException e)
      {
         log.debug("getUserRoles("+securityDomain+") failed", e);
      }
      return userRoles;
   }

   // ------------------------------------------------------------------ lifecycle

   /**
    * Registers the JACC subject handler, binds the java:/jaas and cache factory
    * Context proxies plus the SecurityProxyFactory, and registers property editors.
    */
   protected void startService() throws Exception
   {
      SecurityAssociation.setServer();
      SubjectPolicyContextHandler handler = new SubjectPolicyContextHandler();
      // Third argument true: replace any handler already registered for this key
      PolicyContext.registerHandler(SubjectPolicyContextHandler.SUBJECT_CONTEXT_KEY,
         handler, true);

      Context ctx = new InitialContext();
      try
      {
         parser = ctx.getNameParser("");
         ctx.rebind(SECURITY_MGR_PATH,
            newContextReference("JSM", SecurityDomainObjectFactory.class));
         log.debug("securityMgrCtxPath="+SECURITY_MGR_PATH);
         ctx.rebind(DEFAULT_CACHE_POLICY_PATH,
            newContextReference("JSMCachePolicy", DefaultCacheObjectFactory.class));
         log.debug("cachePolicyCtxPath="+cacheJndiName);
         SecurityProxyFactory proxyFactory = (SecurityProxyFactory) securityProxyFactoryClass.newInstance();
         // rebind, not bind: a stop/start cycle must not fail with NameAlreadyBoundException
         ctx.rebind(SECURITY_PROXY_FACTORY_PATH, proxyFactory);
         log.debug("SecurityProxyFactory="+proxyFactory);
      }
      finally
      {
         closeQuietly(ctx);
      }
      PropertyEditorManager.registerEditor(Principal.class, PrincipalEditor.class);
      PropertyEditorManager.registerEditor(SecurityDomain.class, SecurityDomainEditor.class);
   }

   /** Removes the JNDI bindings created by startService. */
   protected void stopService() throws Exception
   {
      InitialContext ic = new InitialContext();
      try
      {
         unbindQuietly(ic, SECURITY_PROXY_FACTORY_PATH);
         unbindQuietly(ic, DEFAULT_CACHE_POLICY_PATH);
         ic.unbind(SECURITY_MGR_PATH);
      }
      catch(CommunicationException e)
      {
         // Deliberately ignored: the naming service cannot be reached, so there is
         // nothing more this method can clean up.
      }
      finally
      {
         ic.close();
      }
   }

   /**
    * Registers an externally created security domain under the given name, wiring it
    * to the authentication cache found by lookupCachePolicy.
    */
   public void registerSecurityDomain(String securityDomain, SecurityDomain instance)
   {
      log.debug("Added "+securityDomain+", "+instance+" to map");
      CachePolicy authCache = lookupCachePolicy(securityDomain);
      SecurityDomainContext sdc = new SecurityDomainContext(instance, authCache);
      securityDomainCtxMap.put(securityDomain, sdc);
      setSecurityDomainCache(instance, authCache);
   }

   // ------------------------------------------------------------------ helpers

   /** Loads a class via the thread context class loader, falling back to this class's loader. */
   private static Class loadConfiguredClass(String className) throws ClassNotFoundException
   {
      ClassLoader loader = Thread.currentThread().getContextClassLoader();
      if( loader == null )
         loader = JaasSecurityManagerService.class.getClassLoader();
      return loader.loadClass(className);
   }

   /**
    * Returns the AuthenticationManager class, loading it from securityMgrClassName
    * when the setter has not been called (the field has no initial value).
    */
   private static Class getSecurityMgrClass() throws ClassNotFoundException
   {
      Class mgrClass = securityMgrClass;
      if( mgrClass == null )
      {
         mgrClass = loadConfiguredClass(securityMgrClassName);
         securityMgrClass = mgrClass;
      }
      return mgrClass;
   }

   /** Builds a Reference to a javax.naming.Context produced by the given ObjectFactory. */
   private static Reference newContextReference(String nns, Class factoryClass)
   {
      RefAddr refAddr = new StringRefAddr("nns", nns);
      return new Reference("javax.naming.Context", refAddr, factoryClass.getName(), null);
   }

   /** Best-effort unbind used during shutdown; failures are logged at debug level. */
   private static void unbindQuietly(Context ctx, String jndiName)
   {
      try
      {
         ctx.unbind(jndiName);
      }
      catch(NamingException e)
      {
         log.debug("Failed to unbind "+jndiName, e);
      }
   }

   /** Closes a naming context, logging instead of propagating any failure; null is ignored. */
   private static void closeQuietly(Context ctx)
   {
      if( ctx == null )
         return;
      try
      {
         ctx.close();
      }
      catch(NamingException e)
      {
         log.debug("Failed to close naming context", e);
      }
   }

   /** Logs the shared "no cache policy" warning for a domain. */
   private static void warnNoCachePolicy(String securityDomain)
   {
      log.warn("Failed to find cache policy for securityDomain='"
         + securityDomain + "'");
   }

   /**
    * Returns a domain's authentication cache, accepting either "domain" or
    * "java:/jaas/domain". Because it goes through lookupSecurityDomain, an unknown
    * domain is created as a side effect.
    *
    * @return the cache, or null if securityDomain is null or the domain is unavailable
    */
   private static CachePolicy getCachePolicy(String securityDomain)
   {
      if( securityDomain == null )
         return null;
      // Require the separator so "java:/jaas" or "java:/jaasX" are not mis-stripped
      String prefix = SECURITY_MGR_PATH + '/';
      if( securityDomain.startsWith(prefix) )
         securityDomain = securityDomain.substring(prefix.length());
      CachePolicy cache = null;
      try
      {
         SecurityDomainContext sdc = lookupSecurityDomain(securityDomain);
         if( sdc != null )
            cache = sdc.getAuthenticationCache();
      }
      catch(NamingException e)
      {
         log.debug("getCachePolicy("+securityDomain+") failure", e);
      }
      return cache;
   }

   /**
    * Looks up the authentication cache for a domain: first at cacheJndiName/domain,
    * then at cacheJndiName itself.
    *
    * @return the cache, or null (with a warning) if neither lookup yields a CachePolicy
    */
   private static CachePolicy lookupCachePolicy(String securityDomain)
   {
      CachePolicy authCache = null;
      String domainCachePath = cacheJndiName + '/' + securityDomain;
      InitialContext iniCtx = null;
      try
      {
         iniCtx = new InitialContext();
         try
         {
            authCache = (CachePolicy) iniCtx.lookup(domainCachePath);
         }
         catch(Exception e)
         {
            // No per-domain cache (or not a CachePolicy): fall back to the shared location
            authCache = (CachePolicy) iniCtx.lookup(cacheJndiName);
         }
      }
      catch(Exception e2)
      {
         log.warn("Failed to locate auth CachePolicy at: "+cacheJndiName
            + " for securityDomain="+securityDomain);
      }
      finally
      {
         closeQuietly(iniCtx);
      }
      return authCache;
   }

   /**
    * Reflectively calls setCachePolicy(CachePolicy) on the security manager's own
    * class. Failures (including a missing method) are logged at debug level only.
    */
   private static void setSecurityDomainCache(AuthenticationManager securityMgr,
      CachePolicy cachePolicy)
   {
      try
      {
         Class[] setCachePolicyTypes = {CachePolicy.class};
         // Use the instance's class: registered domains need not be of securityMgrClass
         Method m = securityMgr.getClass().getMethod("setCachePolicy", setCachePolicyTypes);
         Object[] setCachePolicyArgs = {cachePolicy};
         m.invoke(securityMgr, setCachePolicyArgs);
         log.debug("setCachePolicy, c="+setCachePolicyArgs[0]);
      }
      catch(Exception e2)
      {
         log.debug("setCachePolicy failed", e2);
      }
   }

   /**
    * Returns the context for a domain, creating and registering it on first use.
    * NOTE(review): the get/put pair is not atomic, so concurrent first lookups of the
    * same domain may each create a context; the last put wins.
    *
    * @throws NamingException if the security manager cannot be created
    */
   private static SecurityDomainContext lookupSecurityDomain(String securityDomain)
      throws NamingException
   {
      SecurityDomainContext securityDomainCtx = (SecurityDomainContext) securityDomainCtxMap.get(securityDomain);
      if( securityDomainCtx == null )
      {
         securityDomainCtx = newSecurityDomainCtx(securityDomain);
         securityDomainCtxMap.put(securityDomain, securityDomainCtx);
         log.debug("Added "+securityDomain+", "+securityDomainCtx+" to map");
      }
      return securityDomainCtx;
   }

   /**
    * Instantiates the configured AuthenticationManager through its
    * (String, CallbackHandler) constructor and pairs it with the domain's cache.
    *
    * @throws NamingException wrapping (as root cause) any creation failure
    */
   private static SecurityDomainContext newSecurityDomainCtx(String securityDomain)
      throws NamingException
   {
      try
      {
         Class[] parameterTypes = {String.class, CallbackHandler.class};
         Constructor ctor = getSecurityMgrClass().getConstructor(parameterTypes);
         CallbackHandler handler = (CallbackHandler) callbackHandlerClass.newInstance();
         Object[] args = {securityDomain, handler};
         AuthenticationManager securityMgr = (AuthenticationManager) ctor.newInstance(args);
         log.debug("Created securityMgr="+securityMgr);
         CachePolicy cachePolicy = lookupCachePolicy(securityDomain);
         SecurityDomainContext sdc = new SecurityDomainContext(securityMgr, cachePolicy);
         setSecurityDomainCache(securityMgr, cachePolicy);
         return sdc;
      }
      catch(Exception e2)
      {
         log.error("Failed to create sec mgr", e2);
         NamingException ne = new NamingException("Failed to create sec mgr:"+e2.getMessage());
         // Keep the real failure visible to callers
         ne.setRootCause(e2);
         throw ne;
      }
   }

   // ------------------------------------------------------------------ JNDI object factories

   /**
    * ObjectFactory for java:/jaas. It returns a Context proxy that supports lookup and
    * list plus the java.lang.Object methods; any other Context operation throws
    * OperationNotSupportedException.
    */
   public static class SecurityDomainObjectFactory
      implements InvocationHandler, ObjectFactory
   {
      public Object getObjectInstance(Object obj, Name name, Context nameCtx,
         Hashtable environment)
         throws Exception
      {
         ClassLoader loader = SubjectActions.getContextClassLoader();
         Class[] interfaces = {Context.class};
         Context ctx = (Context) Proxy.newProxyInstance(loader, interfaces, this);
         return ctx;
      }

      /**
       * lookup("domain") returns the domain's security manager; lookup("domain/x")
       * returns the result of the domain context's lookup("x"). An empty name returns
       * this proxy. list() enumerates the known domain names.
       */
      public Object invoke(Object obj, Method method, Object[] args) throws Throwable
      {
         String methodName = method.getName();
         int argCount = method.getParameterTypes().length;
         if( methodName.equals("toString") )
            return SECURITY_MGR_PATH + " Context proxy";
         // Proxy routes hashCode/equals here too; answer them instead of throwing
         if( methodName.equals("hashCode") && argCount == 0 )
            return new Integer(System.identityHashCode(obj));
         if( methodName.equals("equals") && argCount == 1 )
            return Boolean.valueOf(obj == args[0]);
         if( methodName.equals("list") )
            return new DomainEnumeration(securityDomainCtxMap.keys(), securityDomainCtxMap);
         if( methodName.equals("lookup") == false )
            throw new OperationNotSupportedException("Only lookup is supported, op="+method);

         Name name;
         if( args[0] instanceof String )
            name = parser.parse((String) args[0]);
         else
            name = (Name) args[0];
         // Per Context.lookup, an empty name denotes this context itself
         if( name.size() == 0 )
            return obj;
         String securityDomain = name.get(0);
         SecurityDomainContext securityDomainCtx = lookupSecurityDomain(securityDomain);
         Object binding = securityDomainCtx.getSecurityManager();
         if( name.size() == 2 )
         {
            String request = name.get(1);
            binding = securityDomainCtx.lookup(request);
         }
         return binding;
      }
   }

   /** NamingEnumeration of NameClassPair over the domain map's keys. */
   static class DomainEnumeration implements NamingEnumeration
   {
      Enumeration domains;
      Hashtable ctxMap;

      DomainEnumeration(Enumeration domains, Hashtable ctxMap)
      {
         this.domains = domains;
         this.ctxMap = ctxMap;
      }

      public void close()
      {
      }

      public boolean hasMoreElements()
      {
         return domains.hasMoreElements();
      }

      public boolean hasMore()
      {
         return domains.hasMoreElements();
      }

      /** @return a NameClassPair of the domain name and the class of its context */
      public Object next()
      {
         String name = (String) domains.nextElement();
         Object value = ctxMap.get(name);
         // NameClassPair accepts a null class name; guard against a missing entry
         String className = value != null ? value.getClass().getName() : null;
         return new NameClassPair(name, className);
      }

      /** Enumeration view: yields the same NameClassPair elements as next(). */
      public Object nextElement()
      {
         return next();
      }
   }

   /**
    * ObjectFactory for java:/timedCacheFactory. The returned Context proxy answers
    * every call except the java.lang.Object methods and close() by creating, starting
    * and returning a new TimedCachePolicy built from the current defaults.
    */
   public static class DefaultCacheObjectFactory implements InvocationHandler, ObjectFactory
   {
      public Object getObjectInstance(Object obj, Name name, Context nameCtx, Hashtable environment)
         throws Exception
      {
         ClassLoader loader = SubjectActions.getContextClassLoader();
         Class[] interfaces = {Context.class};
         Context ctx = (Context) Proxy.newProxyInstance(loader, interfaces, this);
         return ctx;
      }

      public Object invoke(Object obj, Method method, Object[] args) throws Throwable
      {
         String methodName = method.getName();
         int argCount = method.getParameterTypes().length;
         // Without these guards, toString/hashCode/equals would start a cache and then
         // fail with ClassCastException, and close() would start and leak a cache
         if( methodName.equals("toString") && argCount == 0 )
            return DEFAULT_CACHE_POLICY_PATH + " Context proxy";
         if( methodName.equals("hashCode") && argCount == 0 )
            return new Integer(System.identityHashCode(obj));
         if( methodName.equals("equals") && argCount == 1 )
            return Boolean.valueOf(obj == args[0]);
         if( methodName.equals("close") && argCount == 0 )
            return null;
         TimedCachePolicy cachePolicy = new TimedCachePolicy(defaultCacheTimeout,
            true, defaultCacheResolution);
         cachePolicy.create();
         cachePolicy.start();
         return cachePolicy;
      }
   }
}