package org.jboss.security.srp.jaas;
import java.rmi.Naming;
import java.security.Principal;
import java.util.Hashtable;
import java.util.Map;
import java.util.Set;
import java.util.ArrayList;
import java.io.Serializable;
import javax.naming.InitialContext;
import javax.security.auth.Subject;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.TextInputCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.security.auth.spi.LoginModule;
import org.jboss.logging.Logger;
import org.jboss.security.Util;
import org.jboss.security.auth.callback.ByteArrayCallback;
import org.jboss.security.srp.SRPClientSession;
import org.jboss.security.srp.SRPParameters;
import org.jboss.security.srp.SRPServerInterface;
/**
 * A JAAS {@link LoginModule} that authenticates a user by running the SRP
 * exchange against an {@link SRPServerInterface}.
 *
 * <p>The server is located either through JNDI (option
 * <code>srpServerJndiName</code>) or through an RMI URL (option
 * <code>srpServerRmiUrl</code>); the JNDI name takes precedence when both are
 * given. The module recognizes these options:
 * <ul>
 * <li><code>principalClassName</code> - no longer used, only triggers a warning.</li>
 * <li><code>srpServerJndiName</code> - JNDI name of the SRPServerInterface.</li>
 * <li><code>srpServerRmiUrl</code> - RMI URL of the SRPServerInterface.</li>
 * <li><code>externalRandomA</code> - when true, the callback handler is asked
 * for the bytes passed to the SRPClientSession constructor.</li>
 * <li><code>multipleSessions</code> - when true, passed to
 * <code>getSRPParameters</code> to request a distinct session.</li>
 * <li><code>hasAuxChallenge</code> - when true, the callback handler is asked
 * for an auxiliary challenge that is sent along with the client challenge.</li>
 * </ul>
 * Every other option is handed to the JNDI InitialContext as environment.
 *
 * <p>On a successful commit the Subject receives an SRPPrincipal and, as
 * private credentials, the session key, the session id, the SRPParameters and
 * (when a cipher algorithm is configured) a secret key derived from the
 * session key.
 */
public class SRPLoginModule implements LoginModule
{
   /** Shared-state key under which the login name is published. */
   private static final String SHARED_NAME = "javax.security.auth.login.name";
   /** Shared-state key under which the login password is published. */
   private static final String SHARED_PASSWORD = "javax.security.auth.login.password";

   private Subject subject;
   private CallbackHandler handler;
   private Map sharedState;
   /** Module options minus the ones consumed here; used as JNDI environment. */
   private Hashtable jndiEnv;
   private String principalClassName;
   private String srpServerRmiUrl;
   private String srpServerJndiName;
   private String username;
   private char[] password;
   private SRPServerInterface srpServer;
   private SRPParameters params;
   private Principal userPrincipal;
   private Integer sessionID;
   private byte[] sessionKey;
   /** Secret key added to the Subject by commit(), remembered for logout(). */
   private Object secretKey;
   private byte[] abytes;
   private Object auxChallenge;
   private boolean externalRandomA;
   private boolean hasAuxChallenge;
   private boolean multipleSessions;
   private boolean loginFailed;
   private Logger log;

   /** Creates an uninitialized module; {@link #initialize} must be called before use. */
   public SRPLoginModule()
   {
   }

   /**
    * Stores the JAAS collaborators and reads the module options.
    *
    * @param subject the Subject to populate on commit
    * @param handler the handler used to prompt for user information
    * @param sharedState state shared with the other configured login modules
    * @param options the module options; see the class description
    */
   public void initialize(Subject subject, CallbackHandler handler, Map sharedState,
      Map options)
   {
      log = Logger.getLogger(getClass());
      this.jndiEnv = new Hashtable(options);
      this.subject = subject;
      this.handler = handler;
      this.sharedState = sharedState;

      principalClassName = (String) options.get("principalClassName");
      if( principalClassName != null )
         log.warn("The principalClassName is no longer used, its always SRPPrincipal");
      srpServerJndiName = (String) options.get("srpServerJndiName");
      srpServerRmiUrl = (String) options.get("srpServerRmiUrl");
      externalRandomA = getBooleanOption(options, "externalRandomA");
      multipleSessions = getBooleanOption(options, "multipleSessions");
      hasAuxChallenge = getBooleanOption(options, "hasAuxChallenge");

      // Whatever is left after removing our own options is the JNDI environment
      jndiEnv.remove("principalClassName");
      jndiEnv.remove("srpServerJndiName");
      jndiEnv.remove("srpServerRmiUrl");
      jndiEnv.remove("externalRandomA");
      jndiEnv.remove("multipleSessions");
      jndiEnv.remove("hasAuxChallenge");
   }

   /**
    * Runs the SRP exchange with the server: fetches the parameters, exchanges
    * public keys, sends the client challenge and validates the server reply.
    * On success the SRPPrincipal and the client challenge are published in the
    * shared state as login name and password.
    *
    * @return true when the exchange succeeded
    * @throws LoginException if no server can be located, the user information
    *    cannot be obtained, the exchange fails, or the server reply does not
    *    validate
    */
   public boolean login() throws LoginException
   {
      boolean trace = log.isTraceEnabled();
      loginFailed = true;
      getUserInfo();

      // Locate the server, preferring JNDI over RMI
      if( srpServerJndiName != null )
      {
         srpServer = loadServerFromJndi(srpServerJndiName);
      }
      else if( srpServerRmiUrl != null )
      {
         srpServer = loadServer(srpServerRmiUrl);
      }
      else
      {
         throw new LoginException("No option specified to access a SRPServerInterface instance");
      }
      if( srpServer == null )
         throw new LoginException("Failed to access a SRPServerInterface instance");

      byte[] M1, M2;
      SRPClientSession client = null;
      try
      {
         if( trace )
            log.trace("Getting SRP parameters for username: "+username);
         Util.init();
         Object[] sessionInfo = srpServer.getSRPParameters(username, multipleSessions);
         params = (SRPParameters) sessionInfo[0];
         sessionID = (Integer) sessionInfo[1];
         if( sessionID == null )
            sessionID = new Integer(0);
         if( trace )
         {
            log.trace("SessionID: "+sessionID);
            log.trace("N: "+Util.tob64(params.N));
            log.trace("g: "+Util.tob64(params.g));
            log.trace("s: "+Util.tob64(params.s));
            log.trace("cipherAlgorithm: "+params.cipherAlgorithm);
            log.trace("hashAlgorithm: "+params.hashAlgorithm);
         }
         byte[] hn = Util.newDigest().digest(params.N);
         if( trace )
            log.trace("H(N): "+Util.tob64(hn));
         byte[] hg = Util.newDigest().digest(params.g);
         if( trace )
         {
            log.trace("H(g): "+Util.tob64(hg));
            log.trace("Creating SRPClientSession");
         }

         // Use the externally supplied bytes when the callback provided them
         if( abytes != null )
            client = new SRPClientSession(username, password, params, abytes);
         else
            client = new SRPClientSession(username, password, params);

         if( trace )
            log.trace("Generating client public key");
         byte[] A = client.exponential();
         if( trace )
            log.trace("Exchanging public keys");
         byte[] B = srpServer.init(username, A, sessionID.intValue());
         if( trace )
            log.trace("Generating server challenge");
         M1 = client.response(B);
         if( trace )
            log.trace("Exchanging challenges");
         sessionKey = client.getSessionKey();
         if( auxChallenge != null )
         {
            // Keep the sealed form in a local so the field still holds the
            // plain challenge should login() run again on this instance
            Object sealedChallenge = encryptAuxChallenge(auxChallenge,
               params.cipherAlgorithm, params.cipherIV, sessionKey);
            M2 = srpServer.verify(username, M1, sealedChallenge, sessionID.intValue());
         }
         else
         {
            M2 = srpServer.verify(username, M1, sessionID.intValue());
         }
      }
      catch(Exception e)
      {
         log.warn("Failed to complete SRP login", e);
         throw newLoginException("Failed to complete SRP login, msg="+e.getMessage(), e);
      }

      if( trace )
         log.trace("Verifying server response");
      if( client.verify(M2) == false )
         throw new LoginException("Failed to validate server reply");
      if( trace )
         log.trace("Login succeeded");

      userPrincipal = new SRPPrincipal(username, sessionID);
      sharedState.put(SHARED_NAME, userPrincipal);
      sharedState.put(SHARED_PASSWORD, M1);
      loginFailed = false;
      return true;
   }

   /**
    * Adds the principal and the private credentials produced by login() to
    * the Subject.
    *
    * @return false when login() did not succeed, true otherwise
    * @throws LoginException if the secret key cannot be created
    */
   public boolean commit() throws LoginException
   {
      if( loginFailed == true )
         return false;

      subject.getPrincipals().add(userPrincipal);
      Set privateCredentials = subject.getPrivateCredentials();
      privateCredentials.add(sessionKey);
      if( sessionID != null )
         privateCredentials.add(sessionID);
      secretKey = null;
      if( params.cipherAlgorithm != null )
      {
         secretKey = createSecretKey(params.cipherAlgorithm, sessionKey);
         privateCredentials.add(secretKey);
      }
      privateCredentials.add(params);
      return true;
   }

   /**
    * Drops the references to the username and password.
    *
    * @return always true
    */
   public boolean abort() throws LoginException
   {
      username = null;
      password = null;
      return true;
   }

   /**
    * Removes from the Subject what commit() added and closes the server side
    * session. Safe to call when login never succeeded.
    *
    * @return always true
    * @throws LoginException if the Subject cannot be updated or the server
    *    call fails
    */
   public boolean logout() throws LoginException
   {
      try
      {
         if( subject.isReadOnly() == false )
         {
            // Subject.getPrincipals(Class) returns a detached copy, so the
            // principal has to be removed from the backing set instead
            if( userPrincipal != null )
               subject.getPrincipals().remove(userPrincipal);
            Set privateCredentials = subject.getPrivateCredentials();
            if( sessionKey != null )
               privateCredentials.remove(sessionKey);
            if( sessionID != null )
               privateCredentials.remove(sessionID);
            if( secretKey != null )
               privateCredentials.remove(secretKey);
            if( params != null )
               privateCredentials.remove(params);
         }
         // Without a session id there is no server session to close
         if( srpServer != null && sessionID != null )
         {
            srpServer.close(username, sessionID.intValue());
         }
      }
      catch(Exception e)
      {
         throw newLoginException("Failed to remove user principal, "+e.getMessage(), e);
      }
      return true;
   }

   /**
    * Obtains the username and password, first from the shared state and, when
    * either is missing there, from the callback handler. Depending on the
    * options the handler is also asked for the external random bytes and the
    * auxiliary challenge.
    *
    * @throws LoginException if no handler is available or the handler fails
    */
   private void getUserInfo() throws LoginException
   {
      // login() stores an SRPPrincipal under this key, so the value is not
      // necessarily a String
      Object sharedName = sharedState.get(SHARED_NAME);
      String _username = null;
      if( sharedName instanceof Principal )
         _username = ((Principal) sharedName).getName();
      else if( sharedName != null )
         _username = sharedName.toString();

      char[] _password = null;
      if( _username != null )
      {
         Object pw = sharedState.get(SHARED_PASSWORD);
         if( pw instanceof char[] )
            _password = (char[]) pw;
         else if( pw != null )
            _password = pw.toString().toCharArray();
      }
      if( _username != null && _password != null )
      {
         username = _username;
         password = _password;
         return;
      }

      // Nothing usable in the shared state, prompt through the handler
      if( handler == null )
         throw new LoginException("No CallbackHandler provied to SRPLoginModule");

      NameCallback nc = new NameCallback("Username: ", "guest");
      PasswordCallback pc = new PasswordCallback("Password: ", false);
      ByteArrayCallback bac = new ByteArrayCallback("Public key random number: ");
      TextInputCallback tic = new TextInputCallback("Auxillary challenge token: ");
      ArrayList tmpList = new ArrayList();
      tmpList.add(nc);
      tmpList.add(pc);
      if( externalRandomA == true )
         tmpList.add(bac);
      if( hasAuxChallenge == true )
         tmpList.add(tic);
      Callback[] callbacks = new Callback[tmpList.size()];
      tmpList.toArray(callbacks);
      try
      {
         handler.handle(callbacks);
         username = nc.getName();
         _password = pc.getPassword();
         if( _password != null )
            password = _password;
         pc.clearPassword();
         if( externalRandomA == true )
            abytes = bac.getByteArray();
         if( hasAuxChallenge == true )
            this.auxChallenge = tic.getText();
      }
      catch(java.io.IOException e)
      {
         throw newLoginException(e.toString(), e);
      }
      catch(UnsupportedCallbackException uce)
      {
         throw newLoginException("UnsupportedCallback: " + uce.getCallback().toString(), uce);
      }
   }

   /**
    * Looks up the server in JNDI using the leftover module options as the
    * InitialContext environment.
    *
    * @param jndiName the name to look up
    * @return the server, or null when the lookup fails (the failure is logged)
    */
   private SRPServerInterface loadServerFromJndi(String jndiName)
   {
      SRPServerInterface server = null;
      try
      {
         InitialContext ctx = new InitialContext(jndiEnv);
         server = (SRPServerInterface) ctx.lookup(jndiName);
      }
      catch(Exception e)
      {
         log.error("Failed to lookup("+jndiName+")", e);
      }
      return server;
   }

   /**
    * Looks up the server through the RMI registry.
    *
    * @param rmiUrl the URL handed to Naming.lookup
    * @return the server, or null when the lookup fails (the failure is logged)
    */
   private SRPServerInterface loadServer(String rmiUrl)
   {
      SRPServerInterface server = null;
      try
      {
         server = (SRPServerInterface) Naming.lookup(rmiUrl);
      }
      catch(Exception e)
      {
         log.error("Failed to lookup("+rmiUrl+")", e);
      }
      return server;
   }

   /**
    * Seals the auxiliary challenge with a key created from the given key
    * material.
    *
    * @param challenge the challenge; must be Serializable when a cipher is used
    * @param cipherAlgorithm the cipher to use, or null to skip encryption
    * @param cipherIV the IV passed to Util.createSealedObject
    * @param key the key material passed to Util.createSecretKey
    * @return the sealed object, or the challenge itself when cipherAlgorithm
    *    is null
    * @throws LoginException if the challenge cannot be sealed
    */
   private Object encryptAuxChallenge(Object challenge, String cipherAlgorithm,
      byte[] cipherIV, Object key)
      throws LoginException
   {
      if( cipherAlgorithm == null )
         return challenge;

      Object sealedObject = null;
      try
      {
         Serializable data = (Serializable) challenge;
         Object tmpKey = Util.createSecretKey(cipherAlgorithm, key);
         sealedObject = Util.createSealedObject(cipherAlgorithm, tmpKey, cipherIV, data);
      }
      catch(Exception e)
      {
         log.error("Failed to encrypt aux challenge", e);
         throw newLoginException("Failed to encrypt aux challenge", e);
      }
      return sealedObject;
   }

   /**
    * Creates a secret key through Util.createSecretKey, mapping any failure
    * to a LoginException.
    *
    * @param cipherAlgorithm the algorithm of the key
    * @param key the key material
    * @return the key object returned by Util.createSecretKey
    * @throws LoginException if the key cannot be created
    */
   private Object createSecretKey(String cipherAlgorithm, Object key) throws LoginException
   {
      Object key2 = null;
      try
      {
         key2 = Util.createSecretKey(cipherAlgorithm, key);
      }
      catch(Exception e)
      {
         log.error("Failed to create SecretKey", e);
         throw newLoginException("Failed to create SecretKey", e);
      }
      return key2;
   }

   /**
    * Reads a boolean module option.
    *
    * @param options the module options
    * @param name the option name
    * @return the parsed value, or false when the option is absent
    */
   private static boolean getBooleanOption(Map options, String name)
   {
      String value = (String) options.get(name);
      return value != null && Boolean.valueOf(value).booleanValue();
   }

   /**
    * Builds a LoginException that keeps the original failure as its cause,
    * since LoginException has no constructor taking one.
    *
    * @param msg the exception message
    * @param cause the underlying failure
    * @return the new exception, ready to be thrown
    */
   private static LoginException newLoginException(String msg, Throwable cause)
   {
      LoginException le = new LoginException(msg);
      le.initCause(cause);
      return le;
   }
}