package org.jboss.security.auth.spi;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.Principal;
import java.security.acl.Group;
import java.security.cert.X509Certificate;
import java.util.Map;
import java.util.Enumeration;
import java.util.ArrayList;
import java.io.IOException;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.FailedLoginException;
import javax.security.auth.login.LoginException;
import org.jboss.security.SecurityDomain;
import org.jboss.security.auth.callback.ObjectCallback;
import org.jboss.security.auth.certs.X509CertificateVerifier;
/**
 * Base login module that authenticates a caller by an X.509 certificate.
 *
 * <p>The caller's alias and certificate are collected through a {@link NameCallback}
 * and an {@link ObjectCallback}. If the {@code verifier} module option names an
 * {@link X509CertificateVerifier}, the certificate is validated by it. Otherwise it
 * must equal the certificate stored under the alias in the keystore of the
 * {@link SecurityDomain} named by the {@code securityDomain} option (default
 * {@code java:/jaas/other}).</p>
 *
 * <p>This base class assigns no roles: {@link #getRoleSets()} returns an empty array.</p>
 */
public class BaseCertLoginModule extends AbstractServerLoginModule
{
    /** JNDI name used when the {@code securityDomain} option is absent. */
    private static final String DEFAULT_SECURITY_DOMAIN = "java:/jaas/other";
    /** Shared-state key holding the login name. */
    private static final String SHARED_NAME = "javax.security.auth.login.name";
    /** Shared-state key holding the login credential. */
    private static final String SHARED_PASSWORD = "javax.security.auth.login.password";

    /** Principal established by the current login(); reset at the start of each login(). */
    private Principal identity;
    /** Caller certificate from the current login(); published by commit() when non-null. */
    private X509Certificate credential;
    /** Domain supplying keystore/truststore; null if the JNDI lookup failed. */
    private SecurityDomain domain = null;
    /** Optional custom verifier; when null, certificates are compared against the keystore. */
    private X509CertificateVerifier verifier;
    /** Cached result of log.isTraceEnabled(), captured in initialize(). */
    private boolean trace;

    /**
     * Initializes the module, resolving the security domain and the optional verifier.
     *
     * <p>Recognized options: {@code securityDomain} (JNDI name of a {@link SecurityDomain},
     * default {@code java:/jaas/other}) and {@code verifier} (fully-qualified name of an
     * {@link X509CertificateVerifier} implementation with a public no-arg constructor).</p>
     *
     * @throws IllegalArgumentException if the {@code verifier} option names a class that
     *         cannot be loaded or instantiated as an {@link X509CertificateVerifier}
     */
    public void initialize(Subject subject, CallbackHandler callbackHandler,
        Map sharedState, Map options)
    {
        super.initialize(subject, callbackHandler, sharedState, options);
        trace = log.isTraceEnabled();

        String sd = (String) options.get("securityDomain");
        if (sd == null)
        {
            sd = DEFAULT_SECURITY_DOMAIN;
        }
        if (trace)
        {
            log.trace("securityDomain=" + sd);
        }
        domain = lookupSecurityDomain(sd);

        String option = (String) options.get("verifier");
        if (option != null)
        {
            verifier = createVerifier(option);
        }

        if (trace)
        {
            log.trace("exit: initialize(Subject, CallbackHandler, Map, Map)");
        }
    }

    /**
     * Looks up a security domain in JNDI, always closing the naming context.
     *
     * @param sd JNDI name of the domain
     * @return the domain, or null if the lookup failed or the bound object is not a
     *         SecurityDomain (both cases are logged as errors)
     */
    private SecurityDomain lookupSecurityDomain(String sd)
    {
        InitialContext ctx = null;
        try
        {
            ctx = new InitialContext();
            Object tempDomain = ctx.lookup(sd);
            if (tempDomain instanceof SecurityDomain)
            {
                if (trace)
                {
                    log.trace("found domain: " + tempDomain.getClass().getName());
                }
                return (SecurityDomain) tempDomain;
            }
            log.error("The domain " + sd + " is not a SecurityDomain. All authentication using this module will fail!");
        }
        catch (NamingException e)
        {
            log.error("Unable to find the securityDomain named: " + sd, e);
        }
        finally
        {
            if (ctx != null)
            {
                try
                {
                    ctx.close();
                }
                catch (NamingException ignored)
                {
                    // Best effort: the lookup result (or failure) has already been handled.
                }
            }
        }
        return null;
    }

    /**
     * Instantiates the configured certificate verifier.
     *
     * @param className fully-qualified class name from the {@code verifier} option
     * @return a new verifier instance
     * @throws IllegalArgumentException (with the original failure as cause) if the class
     *         cannot be loaded, instantiated, or is not an X509CertificateVerifier
     */
    private X509CertificateVerifier createVerifier(String className)
    {
        try
        {
            ClassLoader loader = Thread.currentThread().getContextClassLoader();
            if (loader == null)
            {
                // No context loader set on this thread: fall back to this class's loader.
                loader = BaseCertLoginModule.class.getClassLoader();
            }
            Class verifierClass = loader.loadClass(className);
            return (X509CertificateVerifier) verifierClass.newInstance();
        }
        catch (Throwable e)
        {
            if (trace)
            {
                log.trace("Failed to create X509CertificateVerifier", e);
            }
            IllegalArgumentException ex = new IllegalArgumentException("Invalid verifier: " + className);
            ex.initCause(e);
            // Fail loudly: silently ignoring a configured verifier would change how
            // certificates are validated without the administrator noticing.
            throw ex;
        }
    }

    /**
     * Authenticates the caller.
     *
     * <p>If the base class accepts shared-state credentials (see {@code super.login()}),
     * the identity and certificate are taken from the shared state. Otherwise the alias
     * and certificate are collected via callbacks and validated with
     * {@link #validateCredential(String, X509Certificate)}. If neither an alias nor a
     * certificate is supplied, the inherited {@code unauthenticatedIdentity} is used.</p>
     *
     * @return true if authentication succeeded; false if the shared-state password is
     *         present but is not an X509Certificate
     * @throws FailedLoginException if the supplied certificate does not validate
     * @throws LoginException if callbacks fail or a principal cannot be created from shared state
     */
    public boolean login() throws LoginException
    {
        if (trace)
        {
            log.trace("enter: login()");
        }
        // A module instance may be reused for several login() calls. Clear the previous
        // result so a stale identity can never skip certificate validation below.
        identity = null;
        credential = null;

        if (super.login())
        {
            return loginFromSharedState();
        }

        super.loginOk = false;
        Object[] info = getAliasAndCert();
        String alias = (String) info[0];
        credential = (X509Certificate) info[1];

        if (alias == null && credential == null)
        {
            identity = unauthenticatedIdentity;
            if (trace)
            {
                log.trace("Authenticating as unauthenticatedIdentity=" + identity);
            }
        }

        // identity is still null unless the unauthenticated identity was applied above.
        if (identity == null)
        {
            try
            {
                identity = createIdentity(alias);
            }
            catch (Exception e)
            {
                log.debug("Failed to create identity for alias:" + alias, e);
            }
            if (!validateCredential(alias, credential))
            {
                log.debug("Bad credential for alias=" + alias);
                throw new FailedLoginException("Supplied Credential did not match existing credential for " + alias);
            }
        }

        if (getUseFirstPass())
        {
            sharedState.put(SHARED_NAME, alias);
            sharedState.put(SHARED_PASSWORD, credential);
        }
        super.loginOk = true;
        if (trace)
        {
            log.trace("User '" + identity + "' authenticated, loginOk=" + loginOk);
            log.trace("exit: login()");
        }
        return true;
    }

    /**
     * Completes login from the shared-state name and password.
     *
     * @return true on success; false if the shared password is present but is not an X509Certificate
     * @throws LoginException if no name is present or a principal cannot be created from it
     */
    private boolean loginFromSharedState() throws LoginException
    {
        Object username = sharedState.get(SHARED_NAME);
        if (username instanceof Principal)
        {
            identity = (Principal) username;
        }
        else
        {
            if (username == null)
            {
                throw new LoginException("No " + SHARED_NAME + " found in shared state");
            }
            String name = username.toString();
            try
            {
                identity = createIdentity(name);
            }
            catch (Exception e)
            {
                log.debug("Failed to create principal", e);
                LoginException le = new LoginException("Failed to create principal: " + e.getMessage());
                le.initCause(e);
                throw le;
            }
        }

        Object password = sharedState.get(SHARED_PASSWORD);
        if (password instanceof X509Certificate)
        {
            credential = (X509Certificate) password;
        }
        else if (password != null)
        {
            log.debug("javax.security.auth.login.password is not X509Certificate");
            super.loginOk = false;
            return false;
        }
        return true;
    }

    /**
     * Commits the login and adds the caller certificate to the subject's public credentials.
     *
     * @return the result of {@code super.commit()}
     */
    public boolean commit() throws LoginException
    {
        boolean ok = super.commit();
        // The unauthenticated-identity path leaves credential null. Subject credential
        // sets may reject null elements (JDK 9+), and a null credential carries no information.
        if (ok && credential != null)
        {
            subject.getPublicCredentials().add(credential);
        }
        return ok;
    }

    /** Returns no roles; subclasses override this to supply role groups. */
    protected Group[] getRoleSets() throws LoginException
    {
        return new Group[0];
    }

    /** Returns the principal established by the last login(), or null. */
    protected Principal getIdentity()
    {
        return identity;
    }

    /** Returns the caller certificate from the last login(), or null. */
    protected Object getCredentials()
    {
        return credential;
    }

    /** Returns the name of {@link #getIdentity()}, or null when there is no identity. */
    protected String getUsername()
    {
        Principal p = getIdentity();
        return p == null ? null : p.getName();
    }

    /**
     * Collects the caller's alias and certificate through the callback handler.
     *
     * <p>The certificate callback may yield an X509Certificate, or an X509Certificate[]
     * whose first element is used (null if the array is empty).</p>
     *
     * @return a two-element array: {@code [0]} the alias (String, may be null),
     *         {@code [1]} the certificate (X509Certificate, may be null)
     * @throws LoginException if there is no callback handler, the handler fails or does
     *         not support a callback, or the credential is of an unsupported type
     */
    protected Object[] getAliasAndCert() throws LoginException
    {
        if (trace)
        {
            log.trace("enter: getAliasAndCert()");
        }
        if (callbackHandler == null)
        {
            throw new LoginException("Error: no CallbackHandler available to collect authentication information");
        }
        NameCallback nc = new NameCallback("Alias: ");
        ObjectCallback oc = new ObjectCallback("Certificate: ");
        Callback[] callbacks = { nc, oc };
        String alias = null;
        X509Certificate cert = null;
        try
        {
            callbackHandler.handle(callbacks);
            alias = nc.getName();
            Object tmpCert = oc.getCredential();
            if (tmpCert instanceof X509Certificate)
            {
                cert = (X509Certificate) tmpCert;
                if (trace)
                {
                    log.trace("found cert " + cert.getSerialNumber().toString(16) + ":" + cert.getSubjectDN().getName());
                }
            }
            else if (tmpCert instanceof X509Certificate[])
            {
                X509Certificate[] certChain = (X509Certificate[]) tmpCert;
                if (certChain.length > 0)
                {
                    cert = certChain[0];
                }
            }
            else if (tmpCert != null)
            {
                String msg = "Don't know how to obtain X509Certificate from: "
                    + tmpCert.getClass();
                log.warn(msg);
                throw new LoginException(msg);
            }
        }
        catch (IOException e)
        {
            log.debug("Failed to invoke callback", e);
            LoginException le = new LoginException("Failed to invoke callback: " + e.toString());
            le.initCause(e);
            throw le;
        }
        catch (UnsupportedCallbackException uce)
        {
            LoginException le = new LoginException("CallbackHandler does not support: "
                + uce.getCallback());
            le.initCause(uce);
            throw le;
        }
        Object[] info = { alias, cert };
        if (trace)
        {
            log.trace("exit: getAliasAndCert()");
        }
        return info;
    }

    /**
     * Validates the supplied certificate for the alias.
     *
     * <p>If a verifier is configured, it decides, using the domain keystore and truststore
     * (the truststore defaults to the keystore when the domain has none). Otherwise the
     * certificate must equal the keystore's certificate for the alias. If there is no
     * keystore or no certificate, the result is false.</p>
     *
     * @param alias keystore alias of the caller (may be null)
     * @param cert  certificate presented by the caller (may be null)
     * @return true if the certificate is accepted
     */
    protected boolean validateCredential(String alias, X509Certificate cert)
    {
        if (trace)
        {
            log.trace("enter: validateCredential(String, X509Certificate)");
        }
        boolean isValid = false;
        KeyStore keyStore = null;
        KeyStore trustStore = null;
        if (domain != null)
        {
            keyStore = domain.getKeyStore();
            trustStore = domain.getTrustStore();
        }
        if (trustStore == null)
        {
            trustStore = keyStore;
        }
        if (verifier != null)
        {
            if (trace)
            {
                log.trace("Validating cert using: " + verifier);
            }
            isValid = verifier.verify(cert, alias, keyStore, trustStore);
        }
        else if (keyStore != null && cert != null)
        {
            X509Certificate storeCert = null;
            try
            {
                storeCert = (X509Certificate) keyStore.getCertificate(alias);
                if (trace)
                {
                    StringBuffer buf = new StringBuffer("\n\tSupplied Credential: ");
                    buf.append(cert.getSerialNumber().toString(16));
                    buf.append("\n\t\t");
                    buf.append(cert.getSubjectDN().getName());
                    buf.append("\n\n\tExisting Credential: ");
                    if (storeCert != null)
                    {
                        buf.append(storeCert.getSerialNumber().toString(16));
                        buf.append("\n\t\t");
                        buf.append(storeCert.getSubjectDN().getName());
                        buf.append("\n");
                    }
                    else
                    {
                        ArrayList aliases = new ArrayList();
                        Enumeration en = keyStore.aliases();
                        while (en.hasMoreElements())
                        {
                            aliases.add(en.nextElement());
                        }
                        buf.append("No match for alias: " + alias + ", we have aliases " + aliases);
                    }
                    log.trace(buf.toString());
                }
            }
            catch (KeyStoreException e)
            {
                log.warn("failed to find the certificate for " + alias, e);
            }
            isValid = cert.equals(storeCert);
        }
        else
        {
            log.warn("Domain, KeyStore, or cert is null. Unable to validate the certificate.");
        }
        if (trace)
        {
            log.trace("The supplied certificate "
                + (isValid ? "matched" : "DID NOT match")
                + " the certificate in the keystore.");
            log.trace("exit: validateCredential(String, X509Certificate)");
        }
        return isValid;
    }
}