package org.jboss.security.plugins;
import java.lang.reflect.Method;
import java.lang.reflect.UndeclaredThrowableException;
import java.security.Principal;
import java.security.acl.Group;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import javax.security.auth.Subject;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.login.LoginContext;
import javax.security.auth.login.LoginException;
import org.jboss.logging.Logger;
import org.jboss.security.AnybodyPrincipal;
import org.jboss.security.NobodyPrincipal;
import org.jboss.security.RealmMapping;
import org.jboss.security.SecurityAssociation;
import org.jboss.security.SubjectSecurityManager;
import org.jboss.security.auth.callback.SecurityAssociationHandler;
import org.jboss.system.ServiceMBeanSupport;
import org.jboss.util.CachePolicy;
import org.jboss.util.TimedCachePolicy;
/**
 * A JAAS based security manager for a single security domain.
 *
 * Authentication is performed by pushing the principal/credential pair into
 * the configured CallbackHandler (via its reflectively located
 * setSecurityInfo(Principal, Object) method) and running a LoginContext for
 * {@link #getSecurityDomain()}. Successful authentications are optionally
 * cached in a CachePolicy keyed by the principal, so subsequent isValid calls
 * with a matching credential can skip the login module stack.
 *
 * Role checks are made against the Group named "Roles" found on the subject
 * returned by SubjectActions.getActiveSubject().
 */
public class JaasSecurityManager extends ServiceMBeanSupport
    implements SubjectSecurityManager, RealmMapping
{
    /**
     * A cache entry describing one successful authentication: the
     * LoginContext used, a copy of the authenticated Subject, the credential
     * that was presented, the resolved caller principal and an absolute
     * expiration time.
     */
    public static class DomainInfo implements TimedCachePolicy.TimedEntry
    {
        private static final Logger log = Logger.getLogger(DomainInfo.class);
        private static final boolean trace = log.isTraceEnabled();

        private LoginContext loginCtx;
        private Subject subject;
        private Object credential;
        private Principal callerPrincipal;
        private long expirationTime;

        /**
         * @param lifetime the entry lifetime. It is multiplied by 1000 here,
         *        so it is presumably seconds converted to milliseconds
         *        (NOTE(review): confirm against TimedCachePolicy). Until
         *        {@link #init(long)} is called the field holds only this
         *        relative duration.
         */
        public DomainInfo(int lifetime)
        {
            // Use long arithmetic: with int math a lifetime larger than
            // Integer.MAX_VALUE / 1000 overflowed to a negative duration and
            // the entry was never considered current.
            expirationTime = 1000L * lifetime;
        }

        /** Converts the relative lifetime into an absolute expiration time. */
        public void init(long now)
        {
            expirationTime += now;
        }

        /** @return true while {@code now} is before the expiration time */
        public boolean isCurrent(long now)
        {
            return expirationTime > now;
        }

        /** Entries are never refreshed in place; always returns false. */
        public boolean refresh()
        {
            return false;
        }

        /**
         * Logs out the LoginContext that produced this entry. Logout failures
         * are only traced because the entry is being discarded anyway.
         */
        public void destroy()
        {
            if( trace )
                log.trace("destroy, subject="+subject+", this="+this);
            if( loginCtx == null )
                return;   // nothing to log out of
            try
            {
                loginCtx.logout();
            }
            catch(Exception e)
            {
                if( trace )
                    log.trace("Cache entry logout failed", e);
            }
        }

        /** @return this entry itself */
        public Object getValue()
        {
            return this;
        }

        /**
         * Debug representation. The credential's class is printed, not its
         * value, so secrets are not written to logs.
         */
        public String toString()
        {
            StringBuffer tmp = new StringBuffer(super.toString());
            tmp.append('[');
            tmp.append("Subject(");
            tmp.append(System.identityHashCode(subject));
            tmp.append(").principals=");
            // subject may still be unset on a freshly constructed entry
            tmp.append(subject != null ? subject.getPrincipals() : null);
            tmp.append(",credential.class=");
            if( credential != null )
            {
                Class c = credential.getClass();
                tmp.append(c.getName());
                tmp.append('@');
                tmp.append(System.identityHashCode(c));
            }
            else
            {
                tmp.append("null");
            }
            tmp.append(",expirationTime=");
            tmp.append(expirationTime);
            tmp.append(']');
            return tmp.toString();
        }
    }

    private String securityDomain;
    private CachePolicy domainCache;
    private CallbackHandler handler;
    /** The handler's setSecurityInfo(Principal, Object) method. */
    private Method setSecurityInfo;
    protected Logger log;
    protected boolean trace;

    /** Creates a manager for the "other" domain using a SecurityAssociationHandler. */
    public JaasSecurityManager()
    {
        this("other", new SecurityAssociationHandler());
    }

    /**
     * @param securityDomain the JAAS login configuration name to authenticate against
     * @param handler the CallbackHandler passed to the LoginContext. Its class
     *        must declare a public setSecurityInfo(Principal, Object) method.
     * @throws UndeclaredThrowableException if that method cannot be found
     */
    public JaasSecurityManager(String securityDomain, CallbackHandler handler)
    {
        this.securityDomain = securityDomain;
        this.handler = handler;
        // Per-domain log category so each domain's tracing can be toggled separately
        String categoryName = getClass().getName()+'.'+securityDomain;
        this.log = Logger.getLogger(categoryName);
        this.trace = log.isTraceEnabled();

        Class[] sig = {Principal.class, Object.class};
        try
        {
            setSecurityInfo = handler.getClass().getMethod("setSecurityInfo", sig);
        }
        catch (Exception e)
        {
            String msg = "Failed to find setSecurityInfo(Principal, Object) method in handler";
            throw new UndeclaredThrowableException(e, msg);
        }
    }

    /** Sets the cache used for authenticated principals; null disables caching. */
    public void setCachePolicy(CachePolicy domainCache)
    {
        this.domainCache = domainCache;
        log.debug("CachePolicy set to: "+domainCache);
    }

    /** Flushes all cached authentications, if a cache is configured. */
    public void flushCache()
    {
        if( domainCache != null )
            domainCache.flush();
    }

    /** @return the JAAS security domain name */
    public String getSecurityDomain()
    {
        return securityDomain;
    }

    /** @return the subject from SecurityAssociation.getSubject() */
    public Subject getActiveSubject()
    {
        return SecurityAssociation.getSubject();
    }

    /** Equivalent to {@code isValid(principal, credential, null)}. */
    public boolean isValid(Principal principal, Object credential)
    {
        return isValid(principal, credential, null);
    }

    /**
     * Validates a principal/credential pair. A cached entry for the principal
     * is tried first. If there is none or the credential does not match, a
     * full JAAS login is performed.
     *
     * @param activeSubject if non-null, the authenticated subject's contents
     *        are copied into it on success
     * @return true if the credential is valid for the principal
     */
    public synchronized boolean isValid(Principal principal, Object credential,
        Subject activeSubject)
    {
        DomainInfo cacheInfo = getCacheInfo(principal, true);
        if( trace )
            log.trace("Begin isValid, cache info: "+cacheInfo);

        boolean isValid = false;
        if( cacheInfo != null )
            isValid = validateCache(cacheInfo, credential, activeSubject);
        if( isValid == false )
            isValid = authenticate(principal, credential, activeSubject);
        if( trace )
            log.trace("End isValid, "+isValid);
        return isValid;
    }

    /**
     * Maps a principal to the caller principal recorded when it was
     * authenticated (peeking at the cache without refreshing it).
     *
     * @return the cached caller principal, or {@code principal} itself when
     *         there is no cache, no entry, or no caller principal recorded
     */
    public Principal getPrincipal(Principal principal)
    {
        Principal result = principal;
        CachePolicy cache = domainCache;
        // Without a cache there is nothing to map; previously this threw an
        // NPE from synchronized(null).
        if( cache == null )
            return result;
        synchronized( cache )
        {
            DomainInfo info = getCacheInfo(principal, false);
            if( trace )
                log.trace("getPrincipal, cache info: "+info);
            if( info != null )
            {
                result = info.callerPrincipal;
                if( result == null )
                    result = principal;
            }
        }
        return result;
    }

    /**
     * Checks whether the active subject has any of the given roles.
     * The {@code principal} argument is not consulted; the roles come from
     * the subject returned by SubjectActions.getActiveSubject().
     *
     * @param rolePrincipals a set of Principal role names
     * @return true if at least one role matches
     */
    public boolean doesUserHaveRole(Principal principal, Set rolePrincipals)
    {
        boolean hasRole = false;
        Subject subject = SubjectActions.getActiveSubject();
        if( subject != null )
        {
            if( trace )
                log.trace("doesUserHaveRole(Set), subject: "+subject);
            Group roles = getSubjectRoles(subject);
            if( trace )
                log.trace("roles="+roles);
            if( roles != null )
            {
                Iterator iter = rolePrincipals.iterator();
                while( hasRole == false && iter.hasNext() )
                {
                    Principal role = (Principal) iter.next();
                    hasRole = doesRoleGroupHaveRole(role, roles);
                }
            }
            if( trace )
                log.trace("hasRole="+hasRole);
        }
        return hasRole;
    }

    /**
     * Checks whether the active subject has the given role.
     * The {@code principal} argument is not consulted.
     */
    public boolean doesUserHaveRole(Principal principal, Principal role)
    {
        boolean hasRole = false;
        Subject subject = SubjectActions.getActiveSubject();
        if( subject != null )
        {
            if( trace )
                log.trace("doesUserHaveRole(Principal), subject: "+subject);
            Group roles = getSubjectRoles(subject);
            if( roles != null )
            {
                hasRole = doesRoleGroupHaveRole(role, roles);
            }
        }
        return hasRole;
    }

    /**
     * Returns the members of the active subject's "Roles" group.
     * The {@code principal} argument is not consulted.
     *
     * @return a new set of role principals, or null if there is no active
     *         subject or it has no "Roles" group
     */
    public Set getUserRoles(Principal principal)
    {
        HashSet userRoles = null;
        Subject subject = SubjectActions.getActiveSubject();
        if( subject != null )
        {
            if( trace )
                log.trace("getUserRoles, subject: "+subject);
            Group roles = getSubjectRoles(subject);
            if( roles != null )
            {
                userRoles = new HashSet();
                Enumeration members = roles.members();
                while( members.hasMoreElements() )
                {
                    Principal role = (Principal) members.nextElement();
                    userRoles.add(role);
                }
            }
        }
        return userRoles;
    }

    /**
     * Role membership test: a NobodyPrincipal role never matches, a role in
     * {@code userRoles} matches, and otherwise an AnybodyPrincipal role matches.
     */
    protected boolean doesRoleGroupHaveRole(Principal role, Group userRoles)
    {
        if (role instanceof NobodyPrincipal)
            return false;

        boolean isMember = userRoles.isMember(role);
        if (isMember == false)
        {
            isMember = (role instanceof AnybodyPrincipal);
        }
        return isMember;
    }

    /**
     * Runs a full JAAS login. On success the subject is copied into
     * {@code theSubject} (when non-null) and the cache is updated. The login
     * exception, or null on success, is stored under the context key
     * "org.jboss.security.exception" via SubjectActions.setContextInfo.
     */
    private boolean authenticate(Principal principal, Object credential,
        Subject theSubject)
    {
        boolean authenticated = false;
        LoginException authException = null;
        try
        {
            LoginContext lc = defaultLogin(principal, credential);
            Subject subject = lc.getSubject();
            if( subject != null )
            {
                if( theSubject != null )
                    SubjectActions.copySubject(subject, theSubject);
                authenticated = true;
                updateCache(lc, subject, principal, credential);
            }
        }
        catch(LoginException e)
        {
            if( trace )
                log.trace("Login failure", e);
            authException = e;
        }
        SubjectActions.setContextInfo("org.jboss.security.exception", authException);
        return authenticated;
    }

    /**
     * Pushes the principal/credential into the handler and performs
     * LoginContext.login() against the security domain with a fresh Subject.
     *
     * @throws LoginException if the handler rejects the security info or the
     *         login fails
     */
    private LoginContext defaultLogin(Principal principal, Object credential)
        throws LoginException
    {
        Object[] securityInfo = {principal, credential};
        try
        {
            setSecurityInfo.invoke(handler, securityInfo);
        }
        catch (Exception e)
        {
            if( trace )
                log.trace("Failed to setSecurityInfo on handler", e);
            // An InvocationTargetException has a null message; report the
            // handler's own exception and keep it as the cause instead of
            // discarding the stack trace.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            LoginException le = new LoginException("Failed to setSecurityInfo on handler, msg="
                + cause.getMessage());
            le.initCause(cause);
            throw le;
        }
        Subject subject = new Subject();
        LoginContext lc = SubjectActions.createLoginContext(securityDomain, subject, handler);
        lc.login();
        return lc;
    }

    /**
     * Compares the presented credential with the one stored in a cache entry.
     * Comparable credentials use compareTo, and char[]/byte[]/Object[] use
     * Arrays.equals. Anything else, including other primitive arrays, uses
     * equals; for arrays that is identity, so the check fails closed and a
     * full login is performed instead.
     *
     * @param theSubject if non-null and the credential matches, the cached
     *        subject's contents are copied into it
     */
    private boolean validateCache(DomainInfo info, Object credential,
        Subject theSubject)
    {
        if( trace )
        {
            StringBuffer tmp = new StringBuffer("Begin validateCache, info=");
            tmp.append(info.toString());
            tmp.append(";credential.class=");
            if( credential != null )
            {
                Class c = credential.getClass();
                tmp.append(c.getName());
                tmp.append('@');
                tmp.append(System.identityHashCode(c));
            }
            else
            {
                tmp.append("null");
            }
            log.trace(tmp.toString());
        }

        Object subjectCredential = info.credential;
        boolean isValid = false;
        if( credential == null || subjectCredential == null )
        {
            // Only matches when both are null
            isValid = (credential == null) && (subjectCredential == null);
        }
        else if( subjectCredential.getClass().isAssignableFrom(credential.getClass()) )
        {
            if( subjectCredential instanceof Comparable )
            {
                Comparable c = (Comparable) subjectCredential;
                isValid = c.compareTo(credential) == 0;
            }
            else if( subjectCredential instanceof char[] )
            {
                isValid = Arrays.equals((char[]) subjectCredential, (char[]) credential);
            }
            else if( subjectCredential instanceof byte[] )
            {
                isValid = Arrays.equals((byte[]) subjectCredential, (byte[]) credential);
            }
            else if( subjectCredential instanceof Object[] )
            {
                // Previously any array was cast to Object[], which threw
                // ClassCastException for int[], long[], etc.
                isValid = Arrays.equals((Object[]) subjectCredential, (Object[]) credential);
            }
            else
            {
                isValid = subjectCredential.equals(credential);
            }
        }

        if( isValid && theSubject != null )
        {
            SubjectActions.copySubject(info.subject, theSubject);
        }
        if( trace )
            log.trace("End validateCache, isValid="+isValid);
        return isValid;
    }

    /**
     * Looks up the cache entry for a principal.
     *
     * @param allowRefresh true to use CachePolicy.get, false to use peek
     * @return the entry, or null if there is no cache or no entry
     */
    private DomainInfo getCacheInfo(Principal principal, boolean allowRefresh)
    {
        if( domainCache == null )
            return null;

        synchronized( domainCache )
        {
            Object entry = allowRefresh ? domainCache.get(principal)
                                        : domainCache.peek(principal);
            return (DomainInfo) entry;
        }
    }

    /**
     * Records a successful login in the cache, replacing any existing entry
     * for the principal.
     *
     * @return the cached copy of the subject, or {@code subject} unchanged
     *         when no cache is configured
     */
    private Subject updateCache(LoginContext lc, Subject subject,
        Principal principal, Object credential)
    {
        if( domainCache == null )
            return subject;

        // Use the cache's default lifetime when it is time based, otherwise 0
        int lifetime = 0;
        if( domainCache instanceof TimedCachePolicy )
        {
            TimedCachePolicy cache = (TimedCachePolicy) domainCache;
            lifetime = cache.getDefaultLifetime();
        }
        DomainInfo info = new DomainInfo(lifetime);
        info.loginCtx = lc;
        info.subject = new Subject();
        SubjectActions.copySubject(subject, info.subject, true);
        info.credential = credential;

        if( trace )
            log.trace("updateCache, subject="+subject);

        // The first member of a Group named "CallerPrincipal" becomes the
        // caller principal returned by getPrincipal().
        Set subjectGroups = subject.getPrincipals(Group.class);
        Iterator iter = subjectGroups.iterator();
        while( iter.hasNext() )
        {
            Group grp = (Group) iter.next();
            String name = grp.getName();
            if( name.equals("CallerPrincipal") )
            {
                Enumeration members = grp.members();
                if( members.hasMoreElements() )
                    info.callerPrincipal = (Principal) members.nextElement();
            }
        }

        // No principal was supplied and no CallerPrincipal group exists: fall
        // back to a non-Group principal of the subject. The loop keeps the
        // last one seen, so with several candidates the pick depends on the
        // Set's iteration order.
        if( principal == null && info.callerPrincipal == null )
        {
            Set subjectPrincipals = subject.getPrincipals(Principal.class);
            iter = subjectPrincipals.iterator();
            while( iter.hasNext() )
            {
                Principal p = (Principal) iter.next();
                if( (p instanceof Group) == false )
                    info.callerPrincipal = p;
            }
        }

        synchronized( domainCache )
        {
            if( domainCache.peek(principal) != null )
                domainCache.remove(principal);
            domainCache.insert(principal, info);
            if( trace )
                log.trace("Inserted cache info: "+info);
        }
        return info.subject;
    }

    /**
     * @return the subject's Group named "Roles" (the last one encountered if
     *         there are several), or null if none exists
     */
    private Group getSubjectRoles(Subject theSubject)
    {
        Set subjectGroups = theSubject.getPrincipals(Group.class);
        Iterator iter = subjectGroups.iterator();
        Group roles = null;
        while( iter.hasNext() )
        {
            Group grp = (Group) iter.next();
            String name = grp.getName();
            if( name.equals("Roles") )
                roles = grp;
        }
        return roles;
    }
}