package org.jboss.security.ssl;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.UnknownHostException;
import java.util.Arrays;
import javax.naming.InitialContext;
import javax.net.ServerSocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLServerSocket;
import org.jboss.logging.Logger;
import org.jboss.security.SecurityDomain;
/**
 * An SSLServerSocketFactory whose SSLContext is obtained from a
 * SecurityDomain via Context.forDomain. The context is created lazily on
 * first use and then reused for the lifetime of this factory.
 *
 * Instances are configured through bean-style properties (bind address,
 * security domain, client authentication mode) before sockets are created.
 * No synchronization is performed by this class.
 */
public class DomainServerSocketFactory extends SSLServerSocketFactory
{
   private static final Logger log = Logger.getLogger(DomainServerSocketFactory.class);

   /** Backlog used by createServerSocket(int), which has no backlog argument. */
   private static final int DEFAULT_BACKLOG = 50;

   private transient SecurityDomain securityDomain;
   private transient InetAddress bindAddress;
   /** Lazily created by initSSLContext(); null until first needed. */
   private transient SSLContext sslCtx = null;
   private boolean wantsClientAuth = true;
   private boolean needsClientAuth = false;

   /**
    * Creates a factory with no security domain. A domain is expected to be
    * supplied through setSecurityDomain before the factory is used.
    */
   public DomainServerSocketFactory()
   {
   }

   /**
    * Creates a factory bound to the given security domain.
    *
    * @param securityDomain the domain used to build the SSLContext
    * @throws IOException if securityDomain is null
    */
   public DomainServerSocketFactory(SecurityDomain securityDomain) throws IOException
   {
      if( securityDomain == null )
         throw new IOException("The securityDomain may not be null");
      this.securityDomain = securityDomain;
   }

   /**
    * @return the textual IP address of the configured bind address, or null
    *    if no bind address has been set
    */
   public String getBindAddress()
   {
      return bindAddress == null ? null : bindAddress.getHostAddress();
   }

   /**
    * Sets the address that the one- and two-argument createServerSocket
    * methods bind to.
    *
    * @param host host name or IP literal, resolved with InetAddress.getByName
    * @throws UnknownHostException if the host cannot be resolved
    */
   public void setBindAddress(String host) throws UnknownHostException
   {
      bindAddress = InetAddress.getByName(host);
   }

   /** @return the security domain currently associated with this factory */
   public SecurityDomain getSecurityDomain()
   {
      return securityDomain;
   }

   /**
    * Sets the security domain. This does not reset an SSLContext that has
    * already been created, so it only takes effect if called before the
    * first socket or cipher suite request.
    *
    * @param securityDomain the domain used to build the SSLContext
    */
   public void setSecurityDomain(SecurityDomain securityDomain)
   {
      this.securityDomain = securityDomain;
   }

   /** @return true if created sockets request (but do not require) client authentication */
   public boolean isWantsClientAuth()
   {
      return wantsClientAuth;
   }

   /**
    * @param wantsClientAuth true to request client authentication on
    *    created sockets; ignored while needsClientAuth is true
    */
   public void setWantsClientAuth(boolean wantsClientAuth)
   {
      this.wantsClientAuth = wantsClientAuth;
   }

   /** @return true if created sockets require client authentication */
   public boolean isNeedsClientAuth()
   {
      return needsClientAuth;
   }

   /**
    * @param needsClientAuth true to require client authentication on
    *    created sockets; takes precedence over wantsClientAuth
    */
   public void setNeedsClientAuth(boolean needsClientAuth)
   {
      this.needsClientAuth = needsClientAuth;
   }

   /**
    * Creates an SSL server socket on the given port using the default
    * backlog and the configured bind address.
    *
    * @param port the port to listen on
    * @return the configured server socket
    * @throws IOException if the SSLContext or the socket cannot be created
    */
   public ServerSocket createServerSocket(int port) throws IOException
   {
      return createServerSocket(port, DEFAULT_BACKLOG, bindAddress);
   }

   /**
    * Creates an SSL server socket on the given port using the configured
    * bind address.
    *
    * @param port the port to listen on
    * @param backlog the connection backlog
    * @return the configured server socket
    * @throws IOException if the SSLContext or the socket cannot be created
    */
   public ServerSocket createServerSocket(int port, int backlog)
      throws IOException
   {
      return createServerSocket(port, backlog, bindAddress);
   }

   /**
    * Creates an SSL server socket, enables every protocol the socket
    * reports as supported, and applies the client authentication mode.
    *
    * @param port the port to listen on
    * @param backlog the connection backlog
    * @param ifAddress the local address to bind to; passed through to the
    *    underlying factory unchanged
    * @return the configured server socket
    * @throws IOException if the SSLContext or the socket cannot be created
    */
   public ServerSocket createServerSocket(int port, int backlog, InetAddress ifAddress)
      throws IOException
   {
      initSSLContext();
      SSLServerSocketFactory factory = sslCtx.getServerSocketFactory();
      SSLServerSocket socket = (SSLServerSocket) factory.createServerSocket(port, backlog, ifAddress);
      boolean configured = false;
      try
      {
         String[] supportedProtocols = socket.getSupportedProtocols();
         log.debug("Supported protocols: " + Arrays.asList(supportedProtocols));
         // NOTE(review): enabling everything the provider supports may include
         // legacy protocol versions, depending on the JSSE provider - confirm
         // this is intended.
         socket.setEnabledProtocols(supportedProtocols);
         applyClientAuth(socket);
         configured = true;
         return socket;
      }
      finally
      {
         // The socket is already bound at this point; release the port if
         // configuration failed so the caller is not left with a leaked socket.
         if( configured == false )
         {
            try
            {
               socket.close();
            }
            catch(IOException ignored)
            {
               // Best-effort cleanup; the original failure is what propagates.
            }
         }
      }
   }

   /**
    * Applies exactly one client authentication mode to the socket.
    * setNeedClientAuth and setWantClientAuth each override the other, so
    * calling both in sequence would let the "want" flag silently replace
    * (or clear) a "need" setting.
    */
   private void applyClientAuth(SSLServerSocket socket)
   {
      if( needsClientAuth )
         socket.setNeedClientAuth(true);
      else
         socket.setWantClientAuth(wantsClientAuth);
   }

   /**
    * Creates a factory for the security domain bound in JNDI under
    * "java:/jaas/other".
    *
    * @return the factory, or null if the lookup or construction fails (the
    *    failure is logged)
    */
   public static ServerSocketFactory getDefault()
   {
      DomainServerSocketFactory ssf = null;
      try
      {
         InitialContext iniCtx = new InitialContext();
         SecurityDomain sd = (SecurityDomain) iniCtx.lookup("java:/jaas/other");
         ssf = new DomainServerSocketFactory(sd);
      }
      catch(Exception e)
      {
         log.error("Failed to create default ServerSocketFactory", e);
      }
      return ssf;
   }

   /**
    * @return the default cipher suites of the underlying factory, or an
    *    empty array if the SSLContext could not be initialized (the failure
    *    is logged)
    */
   public String[] getDefaultCipherSuites()
   {
      String[] cipherSuites = {};
      try
      {
         initSSLContext();
         SSLServerSocketFactory factory = sslCtx.getServerSocketFactory();
         cipherSuites = factory.getDefaultCipherSuites();
      }
      catch(IOException e)
      {
         log.error("Failed to get default SSLServerSocketFactory", e);
      }
      return cipherSuites;
   }

   /**
    * @return the supported cipher suites of the underlying factory, or an
    *    empty array if the SSLContext could not be initialized (the failure
    *    is logged)
    */
   public String[] getSupportedCipherSuites()
   {
      String[] cipherSuites = {};
      try
      {
         initSSLContext();
         SSLServerSocketFactory factory = sslCtx.getServerSocketFactory();
         cipherSuites = factory.getSupportedCipherSuites();
      }
      catch(IOException e)
      {
         log.error("Failed to get default SSLServerSocketFactory", e);
      }
      return cipherSuites;
   }

   /**
    * Creates the SSLContext from the security domain on first call; later
    * calls are no-ops.
    *
    * @throws IOException if Context.forDomain fails
    */
   private void initSSLContext()
      throws IOException
   {
      if( sslCtx == null )
         sslCtx = Context.forDomain(securityDomain);
   }
}