/***************************************
 *                                     *
 *  JBoss: The OpenSource J2EE WebOS   *
 *                                     *
 *  Distributable under LGPL license.  *
 *  See terms of license at gnu.org.   *
 *                                     *
 ***************************************/
package org.jboss.remoting.transport.socket;

import org.jboss.remoting.CannotConnectException;
import org.jboss.remoting.ConnectionFailedException;
import org.jboss.remoting.InvokerLocator;
import org.jboss.remoting.RemoteClientInvoker;
import org.jboss.remoting.loading.ObjectInputStreamWithClassLoader;
import org.jboss.remoting.marshal.Marshaller;
import org.jboss.remoting.marshal.UnMarshaller;
import org.jboss.remoting.marshal.serializable.SerializableMarshaller;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.rmi.ConnectException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;

/**
 * SocketClientInvoker uses Sockets to remotely connect to a remote ServerInvoker, which
 * must be a SocketServerInvoker.
 * <p/>
 * Connections are pooled per {@link ServerAddress} in a static map shared by all invokers
 * in this VM, so invokers targeting the same address reuse each other's idle sockets.
 *
 * @author <a href="mailto:jhaynie@vocalocity.net">Jeff Haynie</a>
 * @author <a href="mailto:telrod@e2technologies.net">Tom Elrod</a>
 * @version $Revision: 1.8.8.2 $
 */
public class SocketClientInvoker extends RemoteClientInvoker
{
   private InetAddress addr;
   private int port;

   public static final String TCP_NODELAY_FLAG = "enableTcpNoDelay";
   public static final String MAX_POOL_SIZE_FLAG = "clientMaxPoolSize";
   public static final String SO_TIMEOUT_FLAG = "socketTimeout";

   public static final int SO_TIMEOUT_DEFAULT = 60000; // 60 seconds.
   public static final boolean TCP_NODELAY_DEFAULT = false;

   // Performance measurements: cumulative milliseconds across all invokers in this VM.
   // They are updated without synchronization, so values are approximate under concurrency.
   // serializeTime and deserializeTime are not updated by this class.
   public static long getSocketTime = 0;
   public static long readTime = 0;
   public static long writeTime = 0;
   public static long serializeTime = 0;
   public static long deserializeTime = 0;

   /**
    * If the TcpNoDelay option should be used on the socket.
    */
   protected boolean enableTcpNoDelay = TCP_NODELAY_DEFAULT;

   /**
    * Socket read timeout (SO_TIMEOUT) in milliseconds.
    */
   protected int timeout = SO_TIMEOUT_DEFAULT;

   /**
    * Default number of attempts made by {@link #getConnection()}.
    */
   public static final int MAX_RETRIES = 10;

   /**
    * Number of times a pooled connection was handed out (approximate; unsynchronized).
    */
   public static long usedPooled = 0;

   protected int numberOfRetries = MAX_RETRIES;

   /**
    * Pool for this invoker.  This is shared between all
    * instances of proxies attached to a specific invoker
    */
   protected LinkedList pool = null;

   /**
    * connection information
    */
   protected ServerAddress address;

   /**
    * All connection pools, keyed by {@link ServerAddress}; guarded by its own monitor.
    */
   protected static HashMap connectionPools = new HashMap();
   protected int maxPoolSize = 10;


   /**
    * Creates an invoker for the given locator and resolves its target address.
    *
    * @param locator the locator describing host, port and optional parameters
    * @throws IOException declared for compatibility; setup failures surface as RuntimeException
    */
   public SocketClientInvoker(InvokerLocator locator)
         throws IOException
   {
      super(locator);
      try
      {
         setup();
      }
      catch(Exception ex)
      {
         // Keep the original exception as the cause so its type and stack trace are not lost.
         throw new RuntimeException(ex.getMessage(), ex);
      }
   }

   /**
    * Resolves the locator's host, reads configuration parameters and builds the
    * {@link ServerAddress} used as the connection-pool key.
    *
    * @throws Exception if the host cannot be resolved
    */
   protected void setup()
         throws Exception
   {
      this.addr = InetAddress.getByName(locator.getHost());
      this.port = locator.getPort();

      configureParameters();

      address = new ServerAddress(addr.getHostAddress(), port, enableTcpNoDelay, timeout);
   }

   /**
    * Reads enableTcpNoDelay, clientMaxPoolSize and socketTimeout from the locator parameters.
    * Values that cannot be parsed are logged and the current value is kept.
    */
   private void configureParameters()
   {
      Map params = locator.getParameters();
      if(params != null)
      {
         // look for enableTcpNoDelay param
         Object val = params.get(TCP_NODELAY_FLAG);
         if(val != null)
         {
            try
            {
               boolean bVal = Boolean.valueOf((String) val).booleanValue();
               enableTcpNoDelay = bVal;
               log.debug("Setting SocketClientInvoker::enableTcpNoDelay to: " + enableTcpNoDelay);
            }
            catch(Exception e)
            {
               log.warn("Could not convert " + TCP_NODELAY_FLAG + " value of " + val + " to a boolean value.");
            }
         }
         // look for maxPoolSize param
         val = params.get(MAX_POOL_SIZE_FLAG);
         if(val != null)
         {
            try
            {
               int nVal = Integer.valueOf((String) val).intValue();
               maxPoolSize = nVal;
               log.debug("Setting SocketClientInvoker::maxPoolSize to: " + maxPoolSize);
            }
            catch(Exception e)
            {
               log.warn("Could not convert " + MAX_POOL_SIZE_FLAG + " value of " + val + " to a int value.");
            }
         }
         // look for socketTimeout param
         val = params.get(SO_TIMEOUT_FLAG);
         if(val != null)
         {
            try
            {
               int nVal = Integer.valueOf((String) val).intValue();
               timeout = nVal;
               log.debug("Setting SocketClientInvoker::timeout to: " + timeout);
            }
            catch(Exception e)
            {
               log.warn("Could not convert " + SO_TIMEOUT_FLAG + " value of " + val + " to a int value.");
            }
         }
      }
   }

   protected void finalize() throws Throwable
   {
      disconnect();
      super.finalize();
   }

   /**
    * Attaches this invoker to the shared pool for its server address.
    */
   protected synchronized void handleConnect()
         throws ConnectionFailedException
   {
      initPool();
   }

   /**
    * Closes pooled sockets. Note that this clears the pools of ALL server addresses,
    * not only the one used by this invoker.
    */
   protected synchronized void handleDisconnect()
   {
      clearPools();
   }

   /**
    * Each implementation of the remote client invoker should have
    * a default data type that is uses in the case it is not specified
    * in the invoker locator uri.
    *
    * @return the serializable marshaller data type
    */
   protected String getDefaultDataType()
   {
      return SerializableMarshaller.DATATYPE;
   }

   /**
    * Sends one invocation over a (possibly pooled) socket and reads the response.
    * On success the socket is returned to the pool if there is room, otherwise closed.
    * On any marshalling failure the socket is closed and not reused.
    *
    * @param sessionId    the session id (not used by this transport)
    * @param invocation   the object to marshal to the server
    * @param metadata     invocation metadata (not used by this transport)
    * @param marshaller   writes the invocation to the socket's object output stream
    * @param unmarshaller reads the response from the socket's object input stream
    * @return the unmarshalled response
    * @throws java.io.IOException if communication fails (wrapped as a ConnectException)
    * @throws org.jboss.remoting.ConnectionFailedException
    * @throws ClassNotFoundException if the response references a class that cannot be loaded
    */
   protected Object transport(String sessionId, Object invocation, Map metadata,
                              Marshaller marshaller, UnMarshaller unmarshaller)
         throws IOException, ConnectionFailedException, ClassNotFoundException
   {

      /**
       * //TODO: -TME Need to fully javadoc.
       * //TODO: -TME Need to think more on what the signature should be for this.  Have to pass the marshaller
       * to the actual transport implementation as will be up to the transport to figure out how
       * it should get the data from the marshaller (streamed or one chunck).  Am passing the invocation
       * at this point, because don't know if will need to do anything with it later on down the call stack
       * and currently, the marshaller does not have a reference to it (so HAVE to pass to marshaller at
       * some point).  Could change to pass invocation to MarshalFactory when getting Marshaller and
       * add it to Marshaller constructor, but then not part of it's interface.
       */

      Object response = null;
      long start = System.currentTimeMillis();
      ClientSocket socket = null;
      try
      {
         socket = getConnection();
      }
      catch(Exception e)
      {
         throw new CannotConnectException("Can not get connection to server.  Problem establishing socket connection.", e);
      }
      long end = System.currentTimeMillis() - start;
      getSocketTime += end;
      try
      {
         // Restart the clock so writeTime measures only the write, not socket acquisition.
         start = System.currentTimeMillis();

         marshaller.write(invocation, socket.oout);

         end = System.currentTimeMillis() - start;
         writeTime += end;
         start = System.currentTimeMillis();

         response = unmarshaller.read(socket.oin, null);

         end = System.currentTimeMillis() - start;
         readTime += end;
      }
      catch(Exception ex)
      {
         // The stream state is unknown after a failure, so the socket must not be pooled.
         try
         {
            socket.socket.close();
         }
         catch(Exception ignored)
         {
         }
         log.error("Got marshalling exception, exiting", ex);
         if(ex instanceof ClassNotFoundException)
         {
            //TODO: -TME Add better exception handling for class not found exception
            log.error("Error loading classes from remote call result.", ex);
            throw (ClassNotFoundException) ex;
         }

         throw new ConnectException("Failed to communicate.  Problem during marshalling/unmarshalling", ex);
      }

      // Put socket back in pool for reuse
      synchronized(pool)
      {
         if(pool.size() < maxPoolSize)
         {
            pool.add(socket);
         }
         else
         {
            try
            {
               socket.socket.close();
            }
            catch(Exception ignored)
            {
            }
         }
      }

      // Return response
      if(log.isDebugEnabled())
      {
         log.debug("Response: " + response);
      }

      //TODO: -TME Important to replicate this behavior if required for J2EE.  Need to ask Bill about this.
      // Sun's RMI implementation wraps NoSuchObjectException (and TransactionRolledbackException)
      // in a ServerException; to comply with the spec these may need to be unwrapped here.
      return response;
   }

   /**
    * Close all sockets in a specific pool.
    *
    * @param sa the server address whose pool should be emptied
    */
   public static void clearPool(ServerAddress sa)
   {
      try
      {
         LinkedList thepool;
         // connectionPools is a plain HashMap mutated under its own monitor in initPool().
         synchronized(connectionPools)
         {
            thepool = (LinkedList) connectionPools.get(sa);
         }
         if(thepool == null)
         {
            return;
         }
         synchronized(thepool)
         {
            while(!thepool.isEmpty())
            {
               ClientSocket socket = (ClientSocket) thepool.removeFirst();
               try
               {
                  socket.socket.close();
                  socket.socket = null;
               }
               catch(Exception ignored)
               {
               }
            }
         }
      }
      catch(Exception ex)
      {
         // ignored: clearing pools is best-effort cleanup
      }
   }

   /**
    * Close all sockets in all pools
    */
   public static void clearPools()
   {
      synchronized(connectionPools)
      {
         Iterator it = connectionPools.keySet().iterator();
         while(it.hasNext())
         {
            ServerAddress sa = (ServerAddress) it.next();
            clearPool(sa);
         }
      }
   }

   /**
    * Looks up (or creates) the shared pool for this invoker's server address.
    */
   protected void initPool()
   {
      synchronized(connectionPools)
      {
         pool = (LinkedList) connectionPools.get(address);
         if(pool == null)
         {
            pool = new LinkedList();
            connectionPools.put(address, pool);
         }
      }
   }

   /**
    * Sets the number of retries to get a socket connection.
    *
    * @param numberOfRetries Must be a number greater than 0; smaller values reset to MAX_RETRIES
    */
   public void setNumberOfRetries(int numberOfRetries)
   {
      if(numberOfRetries < 1)
      {
         this.numberOfRetries = MAX_RETRIES;
      }
      else
      {
         this.numberOfRetries = numberOfRetries;
      }
   }

   public int getNumberOfRetries()
   {
      return numberOfRetries;
   }

   /**
    * Returns a live pooled connection if one is available, otherwise opens a new socket,
    * making up to {@link #getNumberOfRetries()} attempts.
    *
    * @return a connected ClientSocket
    * @throws Exception the last socket-creation failure if every attempt failed, or any
    *                   failure while configuring the new socket's streams
    */
   protected ClientSocket getConnection() throws Exception
   {
      Exception failed = null;
      Socket socket = null;


      //
      // Need to retry a few times
      // on socket connection because, at least on Windoze,
      // if too many concurrent threads try to connect
      // at same time, you get ConnectionRefused
      //
      // Retrying seems to be the most performant.
      //
      // This problem always happens with RMI and seems to
      // have nothing to do with backlog or number of threads
      // waiting in accept() on the server.
      //
      for(int i = 0; i < numberOfRetries; i++)
      {
         synchronized(pool)
         {
            if(pool.size() > 0)
            {
               ClientSocket pooled = getPooledConnection();
               if(pooled != null)
               {
                  usedPooled++;
                  return pooled;
               }
            }
         }

         try
         {
            socket = new Socket(address.address, address.port);
            failed = null;
            break;
         }
         catch(Exception ex)
         {
            failed = ex;
            // Bound by the configured retry count (not MAX_RETRIES) so the final
            // failure is rethrown instead of leaving socket null.
            if(i + 1 < numberOfRetries)
            {
               Thread.sleep(1);
            }
         }
      }

      if(socket == null)
      {
         if(failed != null)
         {
            throw failed;
         }
         // Only reachable if numberOfRetries was set below 1 directly on the field.
         throw new ConnectException("Could not create socket connection to " + address.address + ":" + address.port);
      }

      try
      {
         socket.setTcpNoDelay(address.enableTcpNoDelay);
         return new ClientSocket(socket, address.timeout);
      }
      catch(Exception ex)
      {
         // Do not leak the freshly opened socket if configuring it or its streams fails.
         try
         {
            socket.close();
         }
         catch(Exception ignored)
         {
         }
         throw ex;
      }
   }

   /**
    * Removes sockets from the pool until one answers a one-byte ping, closing dead ones.
    * Caller must hold the pool's monitor.
    *
    * @return a live pooled socket, or null if the pool has none
    */
   protected ClientSocket getPooledConnection()
   {
      ClientSocket socket = null;
      while(pool.size() > 0)
      {
         socket = (ClientSocket) pool.removeFirst();
         try
         {
            // Test to see if socket is alive by send ACK message
            final byte ACK = 1;

            socket.oout.reset();
            socket.oout.writeByte(ACK);
            socket.oout.flush();
            socket.oin.readByte();

            return socket;
         }
         catch(Exception ex)
         {
            try
            {
               socket.socket.close();
            }
            catch(Exception ignored)
            {
            }
         }
      }
      return null;
   }


   /**
    * A socket together with the buffered object streams layered on top of it.
    */
   protected static class ClientSocket
   {
      public BufferedOutputStream out;
      public BufferedInputStream in;
      public ObjectOutputStream oout;
      public ObjectInputStream oin;
      public Socket socket;
      public int timeout;

      /**
       * @param socket  a connected socket
       * @param timeout SO_TIMEOUT in milliseconds applied to the socket
       * @throws Exception if the streams cannot be created
       */
      public ClientSocket(Socket socket, int timeout) throws Exception
      {
         this.socket = socket;
         socket.setSoTimeout(timeout);
         this.timeout = timeout;
         out = new BufferedOutputStream(socket.getOutputStream());
         in = new BufferedInputStream(socket.getInputStream());
         oout = new ObjectOutputStream(out);
         // NOTE(review): presumably ObjectInputStreamWithClassLoader, like ObjectInputStream,
         // reads the peer's stream header in its constructor (bounded by SO_TIMEOUT) — confirm.
         oin = new ObjectInputStreamWithClassLoader(in, null);
      }

      protected void finalize()
      {
         if(socket != null)
         {
            try
            {
               socket.close();
            }
            catch(Exception ignored)
            {
            }
         }
      }
   }

   /**
    * The name of of the server.
    *
    * @return the resolved host address this invoker connects to
    */
   public String getServerHostName() throws Exception
   {
      return address.address;
   }


}