package org.jboss.security.srp;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import org.jboss.logging.Logger;
import org.jboss.security.Util;
/**
 * Client side of an SRP (Secure Remote Password) exchange.
 *
 * <p>Call sequence, as implied by the methods of this class:
 * <ol>
 *   <li>construct with the user name, password and the {@link SRPParameters}
 *       (N, g, salt s, session hash algorithm);</li>
 *   <li>{@link #exponential()} to obtain the client public value A;</li>
 *   <li>{@link #response(byte[])} with the server public value B to obtain the
 *       client proof M1;</li>
 *   <li>{@link #verify(byte[])} with the server proof M2;</li>
 *   <li>{@link #getSessionKey()} to read the shared session key K.</li>
 * </ol>
 *
 * <p>An instance holds mutable per-session state and performs no
 * synchronization.
 */
public class SRPClientSession
{
    private static final Logger log = Logger.getLogger(SRPClientSession.class);
    /** Bit length of the client private exponent a (abytes must be A_LEN/8 bytes). */
    private static final int A_LEN = 64;

    /** The SRP parameters supplied at construction (N, g, s, hashAlgorithm). */
    private SRPParameters params;
    /** The modulus N as a positive integer. */
    private BigInteger N;
    /** The base g as a positive integer. */
    private BigInteger g;
    /** Password-derived exponent x from Util.calculatePasswordHash. */
    private BigInteger x;
    /** Verifier v = g^x mod N. */
    private BigInteger v;
    /** Client private exponent; null until supplied via abytes or generated by exponential(). */
    private BigInteger a;
    /** Client public value A = g^a mod N; null until exponential() runs. */
    private BigInteger A;
    /** Session key K = H(S); null until response() runs. */
    private byte[] K;
    /** Running digest for the client proof M1 = H[H(N) xor H(g) | H(U) | s | A | B | K]. */
    private MessageDigest clientHash;
    /** Running digest for the expected server proof M2 = H[A | M1 | K]. */
    private MessageDigest serverHash;

    /**
     * Creates a session that generates a random private exponent in
     * {@link #exponential()}.
     *
     * @param username the user name hashed into the client proof
     * @param password the user password used to derive x
     * @param params the SRP parameters
     */
    public SRPClientSession(String username, char[] password, SRPParameters params)
    {
        this(username, password, params, null);
    }

    /**
     * Creates a session, optionally with a fixed private exponent.
     *
     * @param username the user name hashed into the client proof
     * @param password the user password used to derive x
     * @param params the SRP parameters
     * @param abytes optional big-endian, unsigned private exponent of exactly
     *        A_LEN/8 bytes encoding a value greater than 1; null to have
     *        {@link #exponential()} generate a random one
     * @throws IllegalArgumentException if abytes has the wrong length or encodes
     *         a value less than or equal to 1
     */
    public SRPClientSession(String username, char[] password, SRPParameters params,
        byte[] abytes)
    {
        try
        {
            Util.init();
        }
        catch(NoSuchAlgorithmException e)
        {
            // Construction stays best-effort as before, but the failure is no longer invisible.
            log.warn("Util.init() failed", e);
        }
        this.params = params;
        this.g = new BigInteger(1, params.g);
        this.N = new BigInteger(1, params.N);
        if( abytes != null )
        {
            if( 8*abytes.length != A_LEN )
                throw new IllegalArgumentException("The abytes param must be "
                    +(A_LEN/8)+" in length, abytes.length="+abytes.length);
            // Read as an unsigned magnitude: a signed read makes a negative whenever the
            // high bit is set. Reject a <= 1 exactly as the random path in exponential() does.
            BigInteger fixedA = new BigInteger(1, abytes);
            if( fixedA.compareTo(BigInteger.ONE) <= 0 )
                throw new IllegalArgumentException("The abytes param must encode a value greater than 1");
            this.a = fixedA;
        }
        trace("g: ", params.g);
        byte[] xb = Util.calculatePasswordHash(username, password, params.s);
        trace("x: ", xb);
        this.x = new BigInteger(1, xb);
        this.v = g.modPow(x, N);
        trace("v: ", v);
        serverHash = Util.newDigest();
        clientHash = Util.newDigest();
        byte[] hn = Util.newDigest().digest(params.N);
        trace("H(N): ", hn);
        byte[] hg = Util.newDigest().digest(params.g);
        trace("H(g): ", hg);
        // NOTE(review): 20 is the SHA-1 digest length; assumes Util.newDigest() is SHA-1 — confirm.
        byte[] hxg = Util.xor(hn, hg, 20);
        trace("H(N) xor H(g): ", hxg);
        clientHash.update(hxg);
        traceDigest("H[H(N) xor H(g)]: ", clientHash);
        // NOTE(review): getBytes() uses the platform default charset; the server must
        // hash the user name with the same encoding for M1 to match.
        clientHash.update(Util.newDigest().digest(username.getBytes()));
        traceDigest("H[H(N) xor H(g) | H(U)]: ", clientHash);
        clientHash.update(params.s);
        traceDigest("H[H(N) xor H(g) | H(U) | s]: ", clientHash);
        K = null;
    }

    /**
     * Computes the client public value A = g^a mod N and appends it to both the
     * client and server proof digests.
     *
     * <p>If no private exponent was supplied, a random one of A_LEN bits and
     * greater than 1 is generated from Util.getPRNG().
     *
     * @return the encoding of A produced by Util.trim(A.toByteArray()), or null
     *         if A was already computed by an earlier call
     */
    public byte[] exponential()
    {
        if( A != null )
            return null; // A has already been produced and hashed.
        if( a == null )
        {
            do
            {
                a = new BigInteger(A_LEN, Util.getPRNG());
            } while( a.compareTo(BigInteger.ONE) <= 0 );
        }
        A = g.modPow(a, N);
        byte[] Abytes = Util.trim(A.toByteArray());
        clientHash.update(Abytes);
        traceDigest("H[H(N) xor H(g) | H(U) | s | A]: ", clientHash);
        serverHash.update(Abytes);
        return Abytes;
    }

    /**
     * Processes the server public value B and produces the client proof M1.
     *
     * <p>The scrambler u is the first 32 bits of H(B). B is raised by N when it is
     * less than v. The shared secret is S = (B - v)^(a + u*x) mod N, and the
     * session key is K = H(S) using params.hashAlgorithm.
     *
     * @param Bbytes the server public value B, big-endian unsigned
     * @return the client proof M1 = H[H(N) xor H(g) | H(U) | s | A | B | K]
     * @throws NoSuchAlgorithmException if params.hashAlgorithm is unavailable
     * @throws IllegalStateException if {@link #exponential()} has not been called
     * @throws IllegalArgumentException if B % N == 0 (RFC 2945 requires the client to abort)
     */
    public byte[] response(byte[] Bbytes) throws NoSuchAlgorithmException
    {
        if( A == null )
            throw new IllegalStateException("exponential() must be called before response()");
        BigInteger B = new BigInteger(1, Bbytes);
        // Validate before touching any digest state so a rejected B leaves the session unchanged.
        if( B.mod(N).signum() == 0 )
            throw new IllegalArgumentException("Invalid server public value: B % N == 0");
        clientHash.update(Bbytes);
        traceDigest("H[H(N) xor H(g) | H(U) | s | A | B]: ", clientHash);
        byte[] hB = Util.newDigest().digest(Bbytes);
        byte[] ub = {hB[0], hB[1], hB[2], hB[3]};
        trace("B: ", B);
        if( B.compareTo(v) < 0 )
            B = B.add(N);
        trace("B': ", B);
        trace("v: ", v);
        BigInteger u = new BigInteger(1, ub);
        trace("u: ", u);
        BigInteger B_v = B.subtract(v);
        trace("B - v: ", B_v);
        BigInteger a_ux = a.add(u.multiply(x));
        trace("a + u * x: ", a_ux);
        BigInteger S = B_v.modPow(a_ux, N);
        trace("S: ", S);
        MessageDigest sessionDigest = MessageDigest.getInstance(params.hashAlgorithm);
        K = sessionDigest.digest(S.toByteArray());
        trace("K: ", K);
        clientHash.update(K);
        byte[] M1 = clientHash.digest();
        trace("M1: H[H(N) xor H(g) | H(U) | s | A | B | K]: ", M1);
        serverHash.update(M1);
        serverHash.update(K);
        traceDigest("H[A | M1 | K]: ", serverHash);
        return M1;
    }

    /**
     * Checks the server proof against the locally computed M2 = H[A | M1 | K].
     *
     * <p>Calling digest() resets the server digest, so this is meaningful only
     * once and only after {@link #response(byte[])}.
     *
     * @param M2 the proof received from the server
     * @return true if M2 matches the expected value
     */
    public boolean verify(byte[] M2)
    {
        byte[] myM2 = serverHash.digest();
        // Constant-time comparison, so timing does not reveal how much of M2 matched.
        boolean valid = M2 != null && MessageDigest.isEqual(M2, myM2);
        if( log.isTraceEnabled() )
        {
            log.trace("verify serverM2: "+Util.tob64(M2));
            log.trace("verify M2: "+Util.tob64(myM2));
        }
        return valid;
    }

    /**
     * Returns the session key K computed by {@link #response(byte[])}.
     *
     * <p>When a SecurityManager is installed, the caller must hold
     * SRPPermission("getSessionKey").
     *
     * @return a copy of K, or null if response() has not run yet
     * @throws SecurityException if the permission check fails
     */
    public byte[] getSessionKey() throws SecurityException
    {
        SecurityManager sm = System.getSecurityManager();
        if( sm != null )
        {
            SRPPermission p = new SRPPermission("getSessionKey");
            sm.checkPermission(p);
        }
        // Return a copy so holders of the key cannot mutate this session's state.
        return K == null ? null : K.clone();
    }

    /** Logs label + base64(value) at trace level. */
    private static void trace(String label, byte[] value)
    {
        if( log.isTraceEnabled() )
            log.trace(label+Util.tob64(value));
    }

    /** Logs label + base64(value.toByteArray()) at trace level. */
    private static void trace(String label, BigInteger value)
    {
        if( log.isTraceEnabled() )
            log.trace(label+Util.tob64(value.toByteArray()));
    }

    /**
     * Logs the current value of a running digest at trace level. It digests
     * Util.copy(md) rather than md so the running digest is not finalized.
     */
    private static void traceDigest(String label, MessageDigest md)
    {
        if( log.isTraceEnabled() )
        {
            MessageDigest tmp = Util.copy(md);
            log.trace(label+Util.tob64(tmp.digest()));
        }
    }
}