/*
 * JBoss, the OpenSource J2EE webOS
 *
 * Distributable under LGPL license.
 * See terms of license at gnu.org.
 */

package org.jboss.web.tomcat.tc5.session;

import java.util.Collection;
import java.util.Iterator;
import java.util.Random;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;

import org.jboss.logging.Logger;

/**
 * Unique session id generator.
 * <p>
 * Each id is produced by drawing {@link #SESSION_ID_BYTES} bytes from a
 * {@link SecureRandom}, hashing them with {@link #SESSION_ID_HASH_ALGORITHM}, and
 * rendering the hash with a Base64-like encoding over the configurable
 * {@link #getSessionIdAlphabet() session id alphabet}. A shared instance is
 * available through {@link #getInstance()}.
 *
 * @author Ben Wang
 */
public class SessionIDGenerator
{
   /** Number of random bytes fed into the digest for each session id. */
   protected final static int SESSION_ID_BYTES = 16; // We want 16 Bytes for the session-id
   /** Digest algorithm applied to the random bytes. */
   protected final static String SESSION_ID_HASH_ALGORITHM = "MD5";
   /** Preferred {@link SecureRandom} algorithm. */
   protected final static String SESSION_ID_RANDOM_ALGORITHM = "SHA1PRNG";
   /** {@link SecureRandom} algorithm tried when the preferred one is unavailable. */
   protected final static String SESSION_ID_RANDOM_ALGORITHM_ALT = "IBMSecureRandom";
   protected Logger log = Logger.getLogger(SessionIDGenerator.class);

   /** Lazily created digest; only accessed while holding this instance's lock. */
   protected MessageDigest digest = null;
   /** Lazily created random source; only accessed while holding this instance's lock. */
   protected Random random = null;
   /** Shared instance handed out by {@link #getInstance()}; guarded by the class lock. */
   protected static SessionIDGenerator s_;

   /**
    * The 65 distinct characters used by {@link #encode(byte[])}: indices 0-63 are
    * the encoding digits and index 64 is the padding character. Volatile because
    * the setter may run on a different thread than the one generating ids.
    */
   protected volatile String sessionIdAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-*";

   /**
    * Returns the shared generator, creating it on first use.
    * Synchronized so that concurrent first callers cannot create (and
    * subsequently use) different instances.
    *
    * @return the shared generator instance
    */
   public static synchronized SessionIDGenerator getInstance()
   {
      if (s_ == null)
      {
         s_ = new SessionIDGenerator();
      }
      return s_;
   }

   /**
    * The SessionIdAlphabet is the set of characters used to create a session Id.
    *
    * @param sessionIdAlphabet exactly 65 distinct characters; the first 64 are the
    *                          encoding digits and the last one is the padding character
    * @throws IllegalArgumentException if the alphabet is null, not exactly 65
    *                                  characters long, or contains duplicates
    */
   public void setSessionIdAlphabet(String sessionIdAlphabet)
   {
      if (sessionIdAlphabet == null)
      {
         throw new IllegalArgumentException("SessionIdAlphabet must not be null");
      }
      if (sessionIdAlphabet.length() != 65)
      {
         throw new IllegalArgumentException("SessionIdAlphabet must be exactly 65 characters long");
      }

      checkDuplicateChars(sessionIdAlphabet);

      this.sessionIdAlphabet = sessionIdAlphabet;
   }

   /**
    * Verifies that every character in the alphabet occurs exactly once.
    *
    * @param sessionIdAlphabet the candidate alphabet
    * @throws IllegalArgumentException if any character is repeated
    */
   protected void checkDuplicateChars(String sessionIdAlphabet)
   {
      for (char c : sessionIdAlphabet.toCharArray())
      {
         if (!uniqueChar(c, sessionIdAlphabet))
         {
            throw new IllegalArgumentException("All chars in SessionIdAlphabet must be unique");
         }
      }
   }

   /**
    * Tells whether a character appears in a string once and only once.
    *
    * @param c the character to look for
    * @param s the string to search
    * @return true if {@code c} occurs exactly once in {@code s}
    */
   protected boolean uniqueChar(char c, String s)
   {
      int firstIndex = s.indexOf(c);
      if (firstIndex == -1)
      {
         return false;
      }
      return s.indexOf(c, firstIndex + 1) == -1;
   }

   /**
    * The SessionIdAlphabet is the set of characters used to create a session Id.
    *
    * @return the current 65-character alphabet
    */
   public String getSessionIdAlphabet()
   {
      return this.sessionIdAlphabet;
   }

   /**
    * Generates a new session id, logging it at debug level.
    *
    * @return a freshly generated session id
    */
   public synchronized String getSessionId()
   {
      String id = generateSessionId();
      if (log.isDebugEnabled())
         log.debug("getSessionId called: " + id);
      return id;
   }


   /**
    * Generate a session-id that is not guessable.
    * The digest and random source are created lazily on first use.
    *
    * @return generated session-id
    * @throws IllegalStateException if no digest or random source can be obtained
    */
   protected synchronized String generateSessionId()
   {
      if (this.digest == null)
      {
         this.digest = getDigest();
         if (this.digest == null)
         {
            // Fail with a descriptive error instead of an NPE further down.
            throw new IllegalStateException("No MessageDigest available for session-id hashing ("
                  + SESSION_ID_HASH_ALGORITHM + ")");
         }
      }

      if (this.random == null)
      {
         this.random = getRandom();
         if (this.random == null)
         {
            // Only reachable if a subclass overrides getRandom() and returns null.
            throw new IllegalStateException("No random source available for session-id generation");
         }
      }

      byte[] bytes = new byte[SESSION_ID_BYTES];

      // get random bytes
      this.random.nextBytes(bytes);

      // Hash the random bytes
      bytes = this.digest.digest(bytes);

      // Render the hash with the Base64-like encoding over the session id alphabet
      return encode(bytes);
   }

   /**
    * Encode the bytes into a String with a slightly modified Base64-algorithm.
    * This code was written by Kevin Kelley <kelley@ruralnet.net>
    * and adapted by Thomas Peuss <jboss@peuss.de>
    * <p>
    * Every 3 input bytes become 4 output characters taken from alphabet indices
    * 0-63; a trailing group of 1 or 2 bytes is padded with the character at index 64.
    *
    * @param data The bytes you want to encode
    * @return the encoded String, whose length is a multiple of 4
    */
   protected String encode(byte[] data)
   {
      char[] out = new char[((data.length + 2) / 3) * 4];
      // Read the alphabet once so one id is never encoded with two alphabets.
      char[] alphabet = this.sessionIdAlphabet.toCharArray();

      //
      // 3 bytes encode to 4 chars.  Output is always an even
      // multiple of 4 characters.
      //
      for (int i = 0, index = 0; i < data.length; i += 3, index += 4)
      {
         boolean quad = false;
         boolean trip = false;

         int val = (0xFF & (int) data[i]);
         val <<= 8;
         if ((i + 1) < data.length)
         {
            val |= (0xFF & (int) data[i + 1]);
            trip = true;
         }
         val <<= 8;
         if ((i + 2) < data.length)
         {
            val |= (0xFF & (int) data[i + 2]);
            quad = true;
         }
         // Index 64 is the padding character for missing input bytes.
         out[index + 3] = alphabet[(quad ? (val & 0x3F) : 64)];
         val >>= 6;
         out[index + 2] = alphabet[(trip ? (val & 0x3F) : 64)];
         val >>= 6;
         out[index + 1] = alphabet[val & 0x3F];
         val >>= 6;
         out[index + 0] = alphabet[val & 0x3F];
      }
      return new String(out);
   }

   /**
    * Get a random-number generator.
    * <p>
    * Tries {@link #SESSION_ID_RANDOM_ALGORITHM}, then
    * {@link #SESSION_ID_RANDOM_ALGORITHM_ALT}, and finally the platform's default
    * {@link SecureRandom}, so this method never returns null.
    *
    * @return a self-seeding secure random-number generator
    */
   protected synchronized Random getRandom()
   {
      SecureRandom secureRandom;

      try
      {
         secureRandom = SecureRandom.getInstance(SESSION_ID_RANDOM_ALGORITHM);
      }
      catch (NoSuchAlgorithmException e)
      {
         try
         {
            secureRandom = SecureRandom.getInstance(SESSION_ID_RANDOM_ALGORITHM_ALT);
         }
         catch (NoSuchAlgorithmException e_alt)
         {
            log.warn("Could not generate SecureRandom for session-id randomness", e);
            log.warn("Could not generate SecureRandom for session-id randomness", e_alt);
            // Use the default implementation rather than returning null, which
            // would make every session-id request fail.
            secureRandom = new SecureRandom();
         }
      }

      // Deliberately no setSeed(...) here: seeding a fresh SHA1PRNG before its
      // first use replaces its self-seeding, and the previous seed
      // (currentTimeMillis ^ freeMemory) was low-entropy and guessable, which
      // made the generated session ids predictable. Without an explicit seed the
      // generator seeds itself from the platform entropy source on first use.
      return secureRandom;
   }

   /**
    * Get a MessageDigest hash-generator.
    *
    * @return a {@link #SESSION_ID_HASH_ALGORITHM} digest, or null if the algorithm
    *         is unavailable (the failure is logged)
    */
   protected synchronized MessageDigest getDigest()
   {
      try
      {
         return MessageDigest.getInstance(SESSION_ID_HASH_ALGORITHM);
      }
      catch (NoSuchAlgorithmException e)
      {
         log.error("Could not generate MessageDigest for session-id hashing", e);
         return null;
      }
   }

}