package org.jboss.security.srp;
import java.io.Serializable;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import org.jboss.logging.Logger;
import org.jboss.security.Util;
/**
 * Server side of an SRP (Secure Remote Password) authentication exchange.
 *
 * <p>A session is created from the user's password verifier {@code v} and the
 * shared {@link SRPParameters}. The intended call sequence is:
 * <ol>
 * <li>{@link #exponential()} to produce the server public value {@code B},</li>
 * <li>{@link #buildSessionKey(byte[])} with the client public value {@code A},</li>
 * <li>{@link #verify(byte[])} with the client's proof {@code M1},</li>
 * <li>{@link #getServerResponse()} to obtain the server proof {@code M2}.</li>
 * </ol>
 *
 * <p>The values computed here are:
 * {@code B = (v + g^b) mod N},
 * {@code u} = the first 4 bytes of {@code H(B)},
 * {@code S = (A * v^u)^b mod N},
 * {@code K} = digest of {@code S} using {@code params.hashAlgorithm},
 * {@code M1 = H(H(N) xor H(g) | H(U) | s | A | B | K)} and
 * {@code M2 = H(A | M1 | K)}, where {@code H} is the digest returned by
 * {@code Util.newDigest()}.
 *
 * <p>NOTE(review): the running digests are transient, so a deserialized instance
 * has {@code null} digests and cannot continue an exchange. The class performs no
 * synchronization; presumably each instance is confined to one authentication
 * attempt — confirm with callers.
 */
public class SRPServerSession implements Serializable
{
   static final long serialVersionUID = -2448005747721323704L;

   /**
    * Bit length of the random private ephemeral {@code b}. RFC 5054 recommends
    * at least 256 bits. This value never leaves the server, so it has no
    * interoperability impact.
    */
   private static final int B_LEN = 256;
   private static final Logger log = Logger.getLogger(SRPServerSession.class);

   /** Shared SRP parameters (N, g, salt, hash algorithm). */
   private SRPParameters params;
   /** The modulus N as an unsigned integer. */
   private BigInteger N;
   /** The generator g as an unsigned integer. */
   private BigInteger g;
   /** The user's password verifier v. */
   private BigInteger v;
   /** The server's private ephemeral value; set by {@link #exponential()}. */
   private BigInteger b;
   /** The server's public value; set by {@link #exponential()}. */
   private BigInteger B;
   /** The session key; {@code null} until {@link #buildSessionKey(byte[])} succeeds. */
   private byte[] K;
   /** Running digest for the client proof M1. */
   private transient MessageDigest clientHash;
   /** The expected client proof; computed once by {@link #verify(byte[])}. */
   private byte[] M1;
   /** Running digest for the server proof M2. */
   private transient MessageDigest serverHash;
   /** The server proof; computed lazily by {@link #getServerResponse()}. */
   private byte[] M2;
   /** Whether M1 and K have already been fed into {@link #serverHash}. */
   private boolean verified;

   /**
    * Creates a server session for one authentication attempt.
    *
    * <p>The constructor seeds the client proof digest with
    * {@code H(N) xor H(g) | H(U) | s}.
    *
    * @param username the user name U being authenticated
    * @param vb the password verifier v as an unsigned big-endian byte array
    * @param params the SRP parameters shared with the client
    */
   public SRPServerSession(String username, byte[] vb, SRPParameters params)
   {
      this.params = params;
      this.v = new BigInteger(1, vb);
      this.g = new BigInteger(1, params.g);
      this.N = new BigInteger(1, params.N);
      if( log.isTraceEnabled() )
      {
         log.trace("g: "+Util.tob64(params.g));
         log.trace("v: "+Util.tob64(vb));
      }
      serverHash = Util.newDigest();
      clientHash = Util.newDigest();

      byte[] hn = Util.newDigest().digest(params.N);
      if( log.isTraceEnabled() )
         log.trace("H(N): "+Util.tob64(hn));
      byte[] hg = Util.newDigest().digest(params.g);
      if( log.isTraceEnabled() )
         log.trace("H(g): "+Util.tob64(hg));
      // NOTE(review): 20 is presumably the output length of Util.newDigest()
      // (the SHA-1 size) — confirm if the digest algorithm ever changes.
      byte[] hxg = Util.xor(hn, hg, 20);
      if( log.isTraceEnabled() )
         log.trace("H(N) xor H(g): "+Util.tob64(hxg));
      clientHash.update(hxg);
      traceDigest("H[H(N) xor H(g)]: ", clientHash);

      // NOTE(review): getBytes() uses the platform default charset. The client
      // must hash U identically, so this is left as-is for interoperability.
      clientHash.update(Util.newDigest().digest(username.getBytes()));
      traceDigest("H[H(N) xor H(g) | H(U)]: ", clientHash);

      clientHash.update(params.s);
      traceDigest("H[H(N) xor H(g) | H(U) | s]: ", clientHash);
      K = null;
   }

   /**
    * Returns the SRP parameters this session was created with.
    *
    * @return the parameters passed to the constructor
    */
   public SRPParameters getParameters()
   {
      return params;
   }

   /**
    * Returns the server public value {@code B = (v + g^b) mod N}.
    *
    * <p>On the first call this generates the random private ephemeral {@code b}
    * (greater than 1) and computes B. Later calls return the same value.
    *
    * @return B as a trimmed big-endian byte array
    */
   public byte[] exponential()
   {
      if( B == null )
      {
         do
         {
            b = new BigInteger(B_LEN, Util.getPRNG());
         } while( b.compareTo(BigInteger.ONE) <= 0 );
         // Full reduction: equivalent to one conditional subtract when v < N,
         // and still correct if the verifier is not already reduced.
         B = v.add(g.modPow(b, N)).mod(N);
      }
      return Util.trim(B.toByteArray());
   }

   /**
    * Derives the session key K from the client public value A.
    *
    * <p>The method also extends the client proof digest with {@code A | B | K}
    * and the server proof digest with {@code A}.
    *
    * @param ab the client public value A as an unsigned big-endian byte array
    * @throws NoSuchAlgorithmException if {@code params.hashAlgorithm} is not available
    * @throws IllegalStateException if {@link #exponential()} has not been called yet
    * @throws SecurityException if {@code A mod N == 0}, which would let a client
    *         that does not know the password compute the session key
    */
   public void buildSessionKey(byte[] ab) throws NoSuchAlgorithmException
   {
      if( B == null )
         throw new IllegalStateException("exponential() must be called before buildSessionKey()");
      if( log.isTraceEnabled() )
         log.trace("A: "+Util.tob64(ab));
      BigInteger A = new BigInteger(1, ab);
      // SRP safety check: A == 0 (mod N) forces S == 0, making K and M1
      // computable without the password. Reject before touching any state.
      if( A.mod(N).signum() == 0 )
         throw new SecurityException("Invalid client public value A: A mod N == 0");

      byte[] nb = Util.trim(B.toByteArray());
      clientHash.update(ab);
      traceDigest("H[H(N) xor H(g) | H(U) | s | A]: ", clientHash);
      clientHash.update(nb);
      traceDigest("H[H(N) xor H(g) | H(U) | s | A | B]: ", clientHash);
      serverHash.update(ab);

      // u is the first 32 bits of H(B).
      byte[] hB = Util.newDigest().digest(nb);
      byte[] ub = {hB[0], hB[1], hB[2], hB[3]};
      BigInteger u = new BigInteger(1, ub);
      if( log.isTraceEnabled() )
      {
         log.trace("A: "+Util.tob64(A.toByteArray()));
         log.trace("B: "+Util.tob64(B.toByteArray()));
         log.trace("v: "+Util.tob64(v.toByteArray()));
         log.trace("u: "+Util.tob64(u.toByteArray()));
      }

      BigInteger A_v2u = A.multiply(v.modPow(u, N)).mod(N);
      if( log.isTraceEnabled() )
         log.trace("A * v^u: "+Util.tob64(A_v2u.toByteArray()));
      BigInteger S = A_v2u.modPow(b, N);
      if( log.isTraceEnabled() )
         log.trace("S: "+Util.tob64(S.toByteArray()));

      MessageDigest sessionDigest = MessageDigest.getInstance(params.hashAlgorithm);
      K = sessionDigest.digest(S.toByteArray());
      if( log.isTraceEnabled() )
         log.trace("K: "+Util.tob64(K));
      clientHash.update(K);
      traceDigest("H[H(N) xor H(g) | H(U) | s | A | B | K]: ", clientHash);
   }

   /**
    * Returns the session key K, subject to an {@code SRPPermission("getSessionKey")}
    * check when a SecurityManager is installed.
    *
    * @return the session key, or {@code null} if it has not been derived yet
    * @throws SecurityException if the caller lacks the required permission
    */
   public byte[] getSessionKey() throws SecurityException
   {
      SecurityManager sm = System.getSecurityManager();
      if( sm != null )
      {
         SRPPermission p = new SRPPermission("getSessionKey");
         sm.checkPermission(p);
      }
      return K;
   }

   /**
    * Returns the server proof {@code M2 = H(A | M1 | K)}.
    *
    * <p>The digest is computed on the first call and cached. It only covers M1
    * and K if {@link #verify(byte[])} has already succeeded. A call made earlier
    * caches a digest over A alone.
    *
    * @return the server proof
    */
   public byte[] getServerResponse()
   {
      if( M2 == null )
         M2 = serverHash.digest();
      return M2;
   }

   /**
    * Returns the expected client proof M1 computed by {@link #verify(byte[])}.
    *
    * @return M1, or {@code null} if verify has not computed it yet
    */
   public byte[] getClientResponse()
   {
      return M1;
   }

   /**
    * Checks the client's proof against the expected M1.
    *
    * <p>On the first successful verification, M1 and K are appended to the
    * server proof digest.
    *
    * @param clientM1 the proof sent by the client
    * @return {@code true} if {@code clientM1} equals the expected M1;
    *         {@code false} otherwise, including when
    *         {@link #buildSessionKey(byte[])} has not been run yet
    */
   public boolean verify(byte[] clientM1)
   {
      if( K == null )
      {
         // Without A, B and K in clientHash the expected digest would be
         // computable by anyone, so it must never be accepted.
         log.trace("verify called before buildSessionKey, rejecting");
         return false;
      }
      // Finalize only once: digest() resets clientHash, so a second call would
      // otherwise compare against the (public) hash of empty input.
      if( M1 == null )
         M1 = clientHash.digest();
      if( log.isTraceEnabled() )
      {
         log.trace("verify M1: "+Util.tob64(M1));
         log.trace("verify clientM1: "+Util.tob64(clientM1));
      }
      // MessageDigest.isEqual does not short-circuit on the first mismatch.
      // The explicit null check guards older isEqual versions that throw on null.
      boolean valid = clientM1 != null && MessageDigest.isEqual(clientM1, M1);
      if( valid && !verified )
      {
         serverHash.update(M1);
         serverHash.update(K);
         verified = true;
         traceDigest("H(A | M1 | K)", serverHash);
      }
      return valid;
   }

   /**
    * Logs, at trace level, {@code label} followed by the base64 digest of a copy
    * of {@code md}, so {@code md} itself is not reset.
    */
   private static void traceDigest(String label, MessageDigest md)
   {
      if( log.isTraceEnabled() )
      {
         MessageDigest tmp = Util.copy(md);
         log.trace(label+Util.tob64(tmp.digest()));
      }
   }
}