package org.jboss.test;
import java.io.Serializable;
import java.math.BigInteger;
import java.rmi.RemoteException;
import java.security.KeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import org.apache.log4j.Category;
import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.NDC;
import org.apache.log4j.PatternLayout;
import org.jboss.logging.XLevel;
import org.jboss.logging.Logger;
import org.jboss.security.Util;
import org.jboss.security.srp.SRPConf;
import org.jboss.security.srp.SRPServerInterface;
import org.jboss.security.srp.SRPClientSession;
import org.jboss.security.srp.SRPParameters;
import org.jboss.security.srp.SRPServerSession;
/**
 * Runs one complete SRP exchange between an {@link SRPClientSession} and an
 * in-process {@link TstImpl} server:
 * <ul>
 * <li>the client produces A and the server answers with B;</li>
 * <li>the client produces M1 and the server verifies it and replies with M2;</li>
 * <li>the client must then accept M2.</li>
 * </ul>
 * Log lines are prefixed with the NDC tag "C," (client) or "S," (server).
 */
public class TestProtocol extends junit.framework.TestCase
{
   static final Logger log = Logger.getLogger(TestProtocol.class);
   /** User name used both for the server verifier and the client session. */
   String username = "jduke";
   /** Password shared by the test client and the TstImpl server. */
   char[] password = "theduke".toCharArray();
   /** Server side of the exchange; created in {@link #setUp()}. */
   SRPServerInterface server;

   /**
    * Minimal in-process SRPServerInterface used by this test.
    * <p>
    * It serves one fixed SRPParameters instance and keeps only the most
    * recently created SRPServerSession, so it supports a single, sequential
    * exchange. The session-id / aux-challenge overloads are stubs that return
    * empty results.
    */
   static class TstImpl implements SRPServerInterface
   {
      /** Parameters built in the constructor from SRPConf defaults and the salt. */
      final SRPParameters params;
      /** Session created by the last call to init(String, byte[]); null before that. */
      SRPServerSession session;
      /** Password the verifier is derived from in init(String, byte[]). */
      final char[] password;

      /** Stub: always returns an empty array. */
      public Object[] getSRPParameters(String username, boolean multipleSessions)
         throws KeyException, RemoteException
      {
         return new Object[0];
      }

      /** Stub: always returns an empty array. */
      public byte[] init(String username, byte[] A, int sessionID) throws SecurityException,
         NoSuchAlgorithmException, RemoteException
      {
         return new byte[0];
      }

      /** Stub: always returns an empty array. */
      public byte[] verify(String username, byte[] M1, int sessionID)
         throws SecurityException, RemoteException
      {
         return new byte[0];
      }

      /** Stub: always returns an empty array. */
      public byte[] verify(String username, byte[] M1, Object auxChallenge)
         throws SecurityException, RemoteException
      {
         return new byte[0];
      }

      /** Stub: always returns an empty array. */
      public byte[] verify(String username, byte[] M1, Object auxChallenge, int sessionID)
         throws SecurityException, RemoteException
      {
         return new byte[0];
      }

      /** Stub: does nothing. */
      public void close(String username, int sessionID) throws SecurityException, RemoteException
      {
      }

      /**
       * Builds the server's parameters from SRPConf's default N and g plus the
       * given salt, and traces the values and their digests.
       *
       * @param password the password the verifier will be derived from
       * @param salt the salt string, decoded with Util.fromb64
       */
      TstImpl(char[] password, String salt)
      {
         BigInteger N = SRPConf.getDefaultParams().N();
         log.trace("N: "+Util.tob64(N.toByteArray()));
         BigInteger g = SRPConf.getDefaultParams().g();
         log.trace("g: "+Util.tob64(g.toByteArray()));
         byte[] Nb = SRPConf.getDefaultParams().Nbytes();
         byte[] gb = SRPConf.getDefaultParams().gbytes();
         byte[] sb = Util.fromb64(salt);
         this.password = password;
         // params must be created before the trace/digest lines below read
         // params.N and params.g; reading them earlier dereferences null.
         params = new SRPParameters(Nb, gb, sb);
         log.trace("N': "+Util.tob64(params.N));
         log.trace("g': "+Util.tob64(params.g));
         byte[] hn = Util.newDigest().digest(params.N);
         log.trace("H(N): "+Util.tob64(hn));
         byte[] hg = Util.newDigest().digest(params.g);
         log.trace("H(g): "+Util.tob64(hg));
      }

      /** Returns the single parameter set built in the constructor. */
      public SRPParameters getSRPParameters(String username) throws KeyException, RemoteException
      {
         return params;
      }

      /**
       * Starts a new server session for the client's value A. It replaces any
       * previous session.
       * <p>
       * The session is built from the verifier that Util.calculateVerifier
       * derives from username, the stored password and params. A is then used
       * to build the session key.
       *
       * @return the value B produced by the new session's exponential()
       */
      public byte[] init(String username,byte[] A) throws SecurityException,
         NoSuchAlgorithmException, RemoteException
      {
         byte[] v = Util.calculateVerifier(username, password, params.s, params.N, params.g);
         session = new SRPServerSession(username, v, params);
         byte[] B = session.exponential();
         session.buildSessionKey(A);
         return B;
      }

      /**
       * Verifies the client's M1 against the current session.
       *
       * @return the session's server response, which the client checks
       * @throws SecurityException if init has not been called, or M1 does not verify
       */
      public byte[] verify(String username, byte[] M1) throws SecurityException, RemoteException
      {
         if( session == null )
            throw new SecurityException("init must be called before verify");
         if( !session.verify(M1) )
            throw new SecurityException("Failed to verify M1");
         return session.getServerResponse();
      }

      /** Stub: does nothing. */
      public void close(String username) throws SecurityException, RemoteException
      {
      }
   }

   public TestProtocol(String name)
   {
      super(name);
   }

   /**
    * Sets the root logger to TRACE and adds a console appender that prints the
    * NDC tag before each message. Then initialises Util and creates the server
    * under the "S," tag.
    */
   protected void setUp() throws Exception
   {
      Category root = Category.getRoot();
      root.setLevel(XLevel.TRACE);
      root.addAppender(new ConsoleAppender(new PatternLayout("%x%m%n")));
      Util.init();
      NDC.push("S,");
      try
      {
         server = new TstImpl(password, "123456");
      }
      finally
      {
         // Always restore the NDC, even if server construction fails.
         NDC.pop();
         NDC.remove();
      }
   }

   /**
    * Drives the four-message exchange between the client and the server, and
    * fails if the client rejects the server's final reply.
    */
   public void testProtocol() throws Exception
   {
      try
      {
         SRPParameters params = server.getSRPParameters(username);
         NDC.push("C,");
         SRPClientSession client = new SRPClientSession(username, password, params);
         byte[] A = client.exponential();
         NDC.pop();
         NDC.push("S,");
         byte[] B = server.init(username, A);
         NDC.pop();
         NDC.push("C,");
         byte[] M1 = client.response(B);
         NDC.pop();
         NDC.push("S,");
         byte[] M2 = server.verify(username, M1);
         NDC.pop();
         NDC.push("C,");
         if( !client.verify(M2) )
            throw new SecurityException("Failed to validate server reply");
         NDC.pop();
      }
      finally
      {
         // Drop this thread's diagnostic context, including any tag left pushed
         // because a step above threw before its matching pop().
         NDC.remove();
      }
   }

   /** Runs setUp and testProtocol outside a JUnit runner, and prints the elapsed time. */
   public static void main(String[] args)
   {
      long start = System.currentTimeMillis();
      try
      {
         TestProtocol tst = new TestProtocol("main");
         tst.setUp();
         tst.testProtocol();
      }
      catch(Exception e)
      {
         e.printStackTrace(System.out);
      }
      finally
      {
         long end = System.currentTimeMillis();
         System.out.println("Elapsed time = "+(end - start));
      }
   }
}