package org.jboss.invocation.http.interfaces;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectStreamException;
import java.io.OutputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.Authenticator;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.PrivilegedAction;
import java.security.AccessController;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLSocketFactory;
import org.jboss.invocation.Invocation;
import org.jboss.invocation.InvocationException;
import org.jboss.invocation.MarshalledValue;
import org.jboss.logging.Logger;
import org.jboss.security.SecurityAssociationAuthenticator;
import org.jboss.net.ssl.SSLSocketFactoryBuilder;
/**
 * Static helpers used by the HTTP invoker to POST a serialized
 * {@link Invocation} to a remote invoker URL and read back the
 * {@link MarshalledValue} response.
 *
 * <p>Class initialization installs a {@link SecurityAssociationAuthenticator}
 * as the JVM-wide default {@link Authenticator}. If the
 * {@value #SSL_FACTORY_BUILDER} system property names a class, initialization
 * also instantiates that class as the {@link SSLSocketFactoryBuilder} used by
 * {@link #configureSSLSocketFactory(HttpURLConnection)}.</p>
 */
public class Util
{
   /**
    * System property read by {@link #configureHttpsHostVerifier}.
    * When it is {@code "true"}, HTTPS connections are passed to
    * {@code AnyhostVerifier.setHostnameVerifier}.
    */
   public static final String IGNORE_HTTPS_HOST = "org.jboss.security.ignoreHttpsHost";
   /**
    * System property naming an {@link SSLSocketFactoryBuilder} implementation
    * class. The class is loaded and instantiated once, during class
    * initialization.
    */
   public static final String SSL_FACTORY_BUILDER = "org.jboss.security.httpInvoker.sslSocketFactoryBuilder";
   /** Content type sent with every invocation POST. */
   private static final String REQUEST_CONTENT_TYPE =
      "application/x-java-serialized-object; class=org.jboss.invocation.MarshalledInvocation";
   private static final Logger log = Logger.getLogger(Util.class);
   /** Builder configured via {@link #SSL_FACTORY_BUILDER}; null when unset or when loading failed. */
   private static SSLSocketFactoryBuilder sslSocketFactoryBuilder;

   /** Privileged action that installs a SecurityAssociationAuthenticator as the default Authenticator. */
   static class SetAuthenticator implements PrivilegedAction
   {
      public Object run()
      {
         Authenticator.setDefault(new SecurityAssociationAuthenticator());
         return null;
      }
   }

   /** Privileged action that returns the value of the {@link #SSL_FACTORY_BUILDER} system property. */
   static class ReadSSLBuilder implements PrivilegedAction
   {
      public Object run()
      {
         String value = System.getProperty(SSL_FACTORY_BUILDER);
         return value;
      }
   }

   static
   {
      // Same best-effort authenticator installation that callers can re-trigger via init().
      init();

      String factoryFactoryFQCN = null;
      try
      {
         ReadSSLBuilder action = new ReadSSLBuilder();
         factoryFactoryFQCN = (String) AccessController.doPrivileged(action);
      }
      catch(Exception e)
      {
         log.warn("Failed to read "+SSL_FACTORY_BUILDER, e);
      }

      if (factoryFactoryFQCN != null)
      {
         try
         {
            // The thread context class loader can be null (e.g. on some
            // native-attached threads); fall back to this class's own loader.
            ClassLoader loader = Thread.currentThread().getContextClassLoader();
            if (loader == null)
               loader = Util.class.getClassLoader();
            Class clazz = loader.loadClass(factoryFactoryFQCN);
            sslSocketFactoryBuilder = (SSLSocketFactoryBuilder) clazz.newInstance();
         }
         catch (Exception e)
         {
            log.warn("Could not instantiate SSLSocketFactoryBuilder: "+factoryFactoryFQCN, e);
         }
      }
   }

   /**
    * Installs a {@link SecurityAssociationAuthenticator} as the JVM-wide
    * default {@link Authenticator} inside a privileged block.
    * Failures are logged as warnings and otherwise ignored (best effort).
    */
   public static void init()
   {
      try
      {
         SetAuthenticator action = new SetAuthenticator();
         AccessController.doPrivileged(action);
      }
      catch(Exception e)
      {
         log.warn("Failed to install SecurityAssociationAuthenticator", e);
      }
   }

   /**
    * Serializes {@code mi} and POSTs it to {@code externalURL}, then
    * deserializes the {@link MarshalledValue} reply.
    *
    * @param externalURL the invoker URL; must open an {@link HttpURLConnection}
    * @param mi the invocation to send
    * @return the unmarshalled result value
    * @throws InvocationException if {@code mi} cannot be serialized
    * @throws Exception the unmarshalled result itself when it is an Exception,
    *    or any I/O / deserialization failure
    */
   public static Object invoke(URL externalURL, Invocation mi)
      throws Exception
   {
      if( log.isTraceEnabled() )
         log.trace("invoke, externalURL="+externalURL);
      HttpURLConnection conn = (HttpURLConnection) externalURL.openConnection();
      configureHttpsHostVerifier(conn);
      conn.setDoInput(true);
      conn.setDoOutput(true);
      // "Content-Type" is the real HTTP header; without it HttpURLConnection
      // defaults a POST body to application/x-www-form-urlencoded. The legacy
      // misspelled "ContentType" header is kept for peers that may read it.
      conn.setRequestProperty("Content-Type", REQUEST_CONTENT_TYPE);
      conn.setRequestProperty("ContentType", REQUEST_CONTENT_TYPE);
      conn.setRequestMethod("POST");

      MarshalledValue mv;
      ObjectOutputStream oos = new ObjectOutputStream(conn.getOutputStream());
      try
      {
         try
         {
            oos.writeObject(mi);
            oos.flush();
         }
         catch (ObjectStreamException e)
         {
            // Serialization failure of the invocation itself, not a transport error.
            throw new InvocationException(e);
         }

         ObjectInputStream ois = new ObjectInputStream(conn.getInputStream());
         try
         {
            mv = (MarshalledValue) ois.readObject();
            // Reads one byte past the response object; presumably drains the
            // stream so the underlying connection can be reused - TODO confirm.
            ois.read();
         }
         finally
         {
            ois.close();
         }
      }
      finally
      {
         oos.close();
      }

      // A server-side Exception is shipped back as the value; rethrow it here.
      Object value = mv.get();
      if( value instanceof Exception )
      {
         throw (Exception) value;
      }
      return value;
   }

   /**
    * If {@code conn} is an HTTPS connection and the {@link #IGNORE_HTTPS_HOST}
    * system property is {@code true}, passes it to
    * {@code AnyhostVerifier.setHostnameVerifier}.
    * Plain HTTP connections are left untouched.
    *
    * @param conn the connection to configure
    */
   public static void configureHttpsHostVerifier(HttpURLConnection conn)
   {
      if ( conn instanceof HttpsURLConnection && Boolean.getBoolean(IGNORE_HTTPS_HOST) )
      {
         AnyhostVerifier.setHostnameVerifier(conn);
      }
   }

   /**
    * Sets the socket factory produced by the configured
    * {@link SSLSocketFactoryBuilder} on {@code conn}.
    * Nothing happens unless {@code conn} is an {@link HttpsURLConnection} and
    * a builder was configured via {@link #SSL_FACTORY_BUILDER}.
    *
    * @param conn the connection to configure
    * @throws InvocationTargetException wrapping any failure to obtain or set the factory
    */
   public static void configureSSLSocketFactory(HttpURLConnection conn)
      throws InvocationTargetException
   {
      if ( conn instanceof HttpsURLConnection && sslSocketFactoryBuilder != null)
      {
         try
         {
            SSLSocketFactory socketFactory = sslSocketFactoryBuilder.getSocketFactory();
            // Direct call: the instanceof check guarantees the method exists, and
            // reflective lookup on conn.getClass() can land on a JDK-internal
            // implementation class that strong encapsulation forbids invoking.
            ((HttpsURLConnection) conn).setSSLSocketFactory(socketFactory);
            log.trace("Socket factory set on connection");
         }
         catch(Exception e)
         {
            throw new InvocationTargetException(e);
         }
      }
   }

   /**
    * Resolves {@code urlValue} as a URL, or, if it is not a valid URL, as the
    * name of a system property whose value is a URL.
    *
    * @param urlValue a URL string or a system property name; may be null
    * @return the resolved URL, or null when {@code urlValue} is null
    * @throws MalformedURLException if {@code urlValue} is neither a valid URL nor
    *    the name of a set property holding a valid URL
    */
   public static URL resolveURL(String urlValue) throws MalformedURLException
   {
      if( urlValue == null )
         return null;
      try
      {
         return new URL(urlValue);
      }
      catch(MalformedURLException e)
      {
         // System.getProperty("") throws IllegalArgumentException; an empty
         // string cannot name a property, so report the parse failure instead.
         if( urlValue.length() == 0 )
            throw e;
         String urlProperty = System.getProperty(urlValue);
         if( urlProperty == null )
            throw e;
         return new URL(urlProperty);
      }
   }
}