package org.jboss.security.srp.jaas;
import java.security.MessageDigest;
import java.security.Principal;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import javax.crypto.spec.SecretKeySpec;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.security.auth.Subject;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.security.auth.spi.LoginModule;
import org.jboss.logging.Logger;
import org.jboss.security.auth.callback.SecurityAssociationCallback;
import org.jboss.security.srp.SRPParameters;
import org.jboss.security.srp.SRPServerSession;
import org.jboss.security.srp.SRPSessionKey;
import org.jboss.util.CachePolicy;
/**
 * A JAAS LoginModule that authenticates a caller against an SRP server session
 * found in a {@link CachePolicy} bound in JNDI.
 *
 * <p>Module options:
 * <ul>
 *   <li>{@code cacheJndiName} (required) - JNDI name of the {@link CachePolicy}
 *       holding {@link SRPServerSession} instances.</li>
 *   <li>{@code domainName} - read but not otherwise used by this class.</li>
 * </ul>
 *
 * <p>The caller identity and challenge are obtained through a
 * {@link SecurityAssociationCallback}. The credential must be a {@code byte[]}
 * equal to the cached session's client response.
 */
public class SRPCacheLoginModule implements LoginModule
{
   private static final Logger log = Logger.getLogger(SRPCacheLoginModule.class);
   private Subject subject;
   private CallbackHandler handler;
   private Map sharedState;
   private String domainName;
   private String cacheJndiName;
   private byte[] clientChallenge;
   private SRPServerSession session;
   private Principal userPrincipal;
   private boolean loginFailed;
   /** True once commit() has populated the Subject; used by abort() to undo it. */
   private boolean commitSucceeded;
   /** The exact SRPParameters instance added to the private credentials by commit(). */
   private SRPParameters sessionParams;
   /**
    * The exact private credential (SecretKeySpec or raw byte[]) added by commit().
    * Kept so logout() removes the same object, since byte[] uses identity equality.
    */
   private Object sessionCredential;

   public SRPCacheLoginModule()
   {
   }

   /**
    * Stores the JAAS collaborators and reads the module options.
    *
    * @param subject the Subject to populate on commit
    * @param handler the handler used to obtain the caller principal and challenge
    * @param sharedState state shared with other modules; login() writes name/password into it
    * @param options module options ({@code cacheJndiName}, {@code domainName})
    */
   public void initialize(Subject subject, CallbackHandler handler, Map sharedState, Map options)
   {
      this.subject = subject;
      this.handler = handler;
      this.sharedState = sharedState;
      cacheJndiName = (String) options.get("cacheJndiName");
      log.trace("cacheJndiName="+cacheJndiName);
      domainName = (String) options.get("domainName");
   }

   /**
    * Authenticates the caller by matching its challenge against the cached SRP session.
    *
    * @return true when the caller's challenge matches the cached session's client response
    * @throws LoginException if configuration is missing, no session is cached,
    *         the cached entry has the wrong type, or validation fails
    */
   public boolean login() throws LoginException
   {
      loginFailed = true;
      if( cacheJndiName == null )
         throw new LoginException("Required cacheJndiName option not set");
      getUserInfo();
      String username = userPrincipal.getName();
      CachePolicy cache = lookupCache();
      SRPSessionKey key = createSessionKey(username);
      Object cacheCredential = cache.get(key);
      if( cacheCredential == null )
         throw new LoginException("No SRP session found for: "+key);
      log.trace("Found SRP cache credential: "+cacheCredential);
      if( !(cacheCredential instanceof SRPServerSession) )
         throw new LoginException("Unknown type of cache credential: "+cacheCredential.getClass());
      SRPServerSession srpSession = (SRPServerSession) cacheCredential;
      if( validateCache(srpSession) == false )
         throw new LoginException("Failed to validate SRP session key for: "+key);
      // Only retain the session once it has been validated.
      session = srpSession;
      log.trace("Login succeeded");
      sharedState.put("javax.security.auth.login.name", username);
      sharedState.put("javax.security.auth.login.password", clientChallenge);
      loginFailed = false;
      return true;
   }

   /**
    * Adds the principal, the public challenge, the SRP parameters and the session
    * key (as a SecretKeySpec when a cipher algorithm is configured) to the Subject.
    *
    * @return false if login() did not succeed, true otherwise
    */
   public boolean commit() throws LoginException
   {
      if( loginFailed )
         return false;
      subject.getPrincipals().add(userPrincipal);
      subject.getPublicCredentials().add(clientChallenge);
      byte[] sessionKey = session.getSessionKey();
      SRPParameters params = session.getParameters();
      Set privateCredentials = subject.getPrivateCredentials();
      privateCredentials.add(params);
      sessionParams = params;
      if( params.cipherAlgorithm != null )
         sessionCredential = new SecretKeySpec(sessionKey, params.cipherAlgorithm);
      else
         sessionCredential = sessionKey;
      privateCredentials.add(sessionCredential);
      commitSucceeded = true;
      return true;
   }

   /**
    * Aborts authentication. If commit() already populated the Subject, that
    * information is removed again. All saved state is then discarded.
    *
    * @return true always
    */
   public boolean abort() throws LoginException
   {
      if( commitSucceeded )
         logout();
      else
         reset();
      return true;
   }

   /**
    * Removes the principal and credentials that commit() added to the Subject.
    * Nothing is removed when the Subject is read-only.
    *
    * @return true on success
    * @throws LoginException if removing the commit information fails
    */
   public boolean logout() throws LoginException
   {
      try
      {
         if( subject.isReadOnly() == false )
         {
            // getPrincipals(Class) returns an unbacked copy; remove from the live set instead.
            if( userPrincipal != null )
               subject.getPrincipals().remove(userPrincipal);
            if( clientChallenge != null )
               subject.getPublicCredentials().remove(clientChallenge);
            Set privateCredentials = subject.getPrivateCredentials();
            if( sessionCredential != null )
               privateCredentials.remove(sessionCredential);
            if( sessionParams != null )
               privateCredentials.remove(sessionParams);
         }
      }
      catch(Exception e)
      {
         LoginException le = new LoginException("Failed to remove commit information, "+e.getMessage());
         le.initCause(e);
         throw le;
      }
      reset();
      return true;
   }

   /** Discards all per-login state held by this module. */
   private void reset()
   {
      userPrincipal = null;
      clientChallenge = null;
      session = null;
      sessionParams = null;
      sessionCredential = null;
      commitSucceeded = false;
      loginFailed = true;
   }

   /**
    * Looks up the SRP session cache in JNDI, always closing the naming context.
    *
    * @return the cache bound under {@code cacheJndiName}
    * @throws LoginException if the JNDI lookup fails
    */
   private CachePolicy lookupCache() throws LoginException
   {
      InitialContext iniCtx = null;
      try
      {
         iniCtx = new InitialContext();
         return (CachePolicy) iniCtx.lookup(cacheJndiName);
      }
      catch(NamingException e)
      {
         log.error("Failed to load SRP auth cache", e);
         LoginException le = new LoginException("Failed to load SRP auth cache: "+e.toString(true));
         le.initCause(e);
         throw le;
      }
      finally
      {
         if( iniCtx != null )
         {
            try
            {
               iniCtx.close();
            }
            catch(NamingException ignored)
            {
               // Best-effort release of context resources; the lookup result is unaffected.
            }
         }
      }
   }

   /**
    * Builds the cache key for the caller, including the SRP session id when the
    * principal is an SRPPrincipal.
    */
   private SRPSessionKey createSessionKey(String username)
   {
      if( userPrincipal instanceof SRPPrincipal )
      {
         SRPPrincipal srpPrincipal = (SRPPrincipal) userPrincipal;
         return new SRPSessionKey(username, srpPrincipal.getSessionID());
      }
      return new SRPSessionKey(username);
   }

   /**
    * Obtains the caller principal and byte[] challenge from the CallbackHandler.
    *
    * @throws LoginException if no handler is configured, the handler fails,
    *         no principal is supplied, or the credential is not a byte[]
    */
   private void getUserInfo() throws LoginException
   {
      if( handler == null )
         throw new LoginException("No CallbackHandler provided");
      SecurityAssociationCallback sac = new SecurityAssociationCallback();
      Callback[] callbacks = { sac };
      Object credential;
      try
      {
         handler.handle(callbacks);
         userPrincipal = sac.getPrincipal();
         credential = sac.getCredential();
      }
      catch(java.io.IOException e)
      {
         LoginException le = new LoginException(e.toString());
         le.initCause(e);
         throw le;
      }
      catch(UnsupportedCallbackException uce)
      {
         LoginException le = new LoginException("UnsupportedCallback: " + uce.getCallback().toString());
         le.initCause(uce);
         throw le;
      }
      finally
      {
         // Always drop the callback's reference to the credential.
         sac.clearCredential();
      }
      if( credential != null && !(credential instanceof byte[]) )
         throw new LoginException("Credential info is not of type byte[], "+ credential.getClass().getName());
      clientChallenge = (byte[]) credential;
      if( userPrincipal == null )
         throw new LoginException("No principal supplied by the CallbackHandler");
   }

   /**
    * Checks the caller's challenge against the session's recorded client response.
    * Missing values never validate, and the comparison is constant-time.
    *
    * @param srpSession the cached SRP server session
    * @return true only if both values are present and equal
    */
   private boolean validateCache(SRPServerSession srpSession)
   {
      byte[] challenge = srpSession.getClientResponse();
      if( challenge == null || clientChallenge == null )
         return false;
      return MessageDigest.isEqual(challenge, clientChallenge);
   }
}