package org.jboss.web.tomcat.tc5.sso;
import java.io.Serializable;
import java.security.Principal;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Set;
import javax.management.MBeanServer;
import javax.management.ObjectName;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.transaction.UserTransaction;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.LifecycleListener;
import org.apache.catalina.Session;
import org.apache.catalina.util.LifecycleSupport;
import org.jboss.cache.Fqn;
import org.jboss.cache.TreeCache;
import org.jboss.cache.TreeCacheListener;
import org.jboss.logging.Logger;
import org.jboss.mx.util.MBeanServerLocator;
import org.jboss.web.tomcat.tc5.Tomcat5;
import org.jgroups.View;
/**
 * {@link SSOClusterManager} implementation that keeps single sign-on data in
 * a JBoss {@link TreeCache}, accessed through the JMX {@link MBeanServer}.
 * <p>
 * Cache layout (each node stores its value under the attribute key
 * <code>"key"</code>):
 * <ul>
 * <li><code>/SSO/&lt;ssoId&gt;/credentials</code> -- an {@link SSOCredentials}
 * holding auth type, user name and password</li>
 * <li><code>/SSO/&lt;ssoId&gt;/sessions</code> -- a {@link Set} of the ids of
 * the sessions associated with the SSO</li>
 * </ul>
 * <p>
 * This class also registers itself as a {@link TreeCacheListener}. Removal
 * of an <code>/SSO/&lt;ssoId&gt;</code> node deregisters that SSO from the
 * local {@link ClusteredSingleSignOn} valve. Modification of a credentials
 * node queues a credentials refresh for the local entry. Changes made by
 * this instance are recorded in the <code>beingLocally*</code> lists so that
 * the notifications they trigger are ignored.
 * <p>
 * The three <code>being*</code> lists may be touched concurrently by callers
 * of the public methods and by cache listener callbacks. They are only ever
 * accessed while synchronized on the list itself.
 */
public final class TreeCacheSSOClusterManager
   implements SSOClusterManager, TreeCacheListener
{
   // -------------------------------------------------------------- Constants

   /** Name of the child node of an SSO that holds its credentials. */
   private static final String CREDENTIALS = "credentials";

   /** Name of the root node under which all SSO data is stored. */
   private static final String SSO = "SSO";

   /** Name of the child node of an SSO that holds its session id set. */
   private static final String SESSIONS = "sessions";

   /** Attribute key under which every node's value is stored. */
   private static final String KEY = "key";

   /** Default JMX ObjectName of the TreeCache to use. */
   public static final String DEFAULT_GLOBAL_CACHE_NAME =
      Tomcat5.DEFAULT_CACHE_NAME;

   /** JMX signature of the TreeCache <code>get(Fqn, Object)</code> operation. */
   private static final String[] GET_SIGNATURE =
      {Fqn.class.getName(), Object.class.getName()};

   /** JMX signature of the TreeCache <code>put(Fqn, Object, Object)</code> operation. */
   private static final String[] PUT_SIGNATURE =
      {Fqn.class.getName(), Object.class.getName(), Object.class.getName()};

   /** JMX signature of the TreeCache <code>remove(Fqn)</code> operation. */
   private static final String[] REMOVE_SIGNATURE = {Fqn.class.getName()};

   // ----------------------------------------------------- Instance Variables

   /** Ids of SSOs whose credentials this instance is currently writing. */
   private final LinkedList beingLocallyAdded = new LinkedList();

   /** Ids of SSOs this instance is currently removing from the cache. */
   private final LinkedList beingLocallyRemoved = new LinkedList();

   /** Ids of SSOs currently being deregistered because of a removal event. */
   private final LinkedList beingRemotelyRemoved = new LinkedList();

   /** JMX name of the TreeCache, or null if not configured. */
   private ObjectName cacheObjectName = null;

   /** Canonical string form of {@link #cacheObjectName}. */
   private String cacheName = null;

   /**
    * Background refresher for credentials changed in the cache. It is
    * created by start() and cleared by stop(). The field is volatile because
    * it is read from cache listener callbacks.
    */
   private volatile CredentialUpdater credentialUpdater = null;

   /** Lazily created JNDI context used to look up the UserTransaction. */
   private InitialContext initialContext = null;

   private final LifecycleSupport lifecycle = new LifecycleSupport(this);

   private final Logger log = Logger.getLogger(getClass().getName());

   /** Whether addTreeCacheListener has succeeded and not been undone. */
   private boolean registeredAsListener = false;

   private final MBeanServer server;

   /** The local valve notified of cluster-wide logouts and credential changes. */
   private ClusteredSingleSignOn ssoValve = null;

   private boolean started = false;

   /** Cached result of the last {@link #isTreeCacheAvailable(boolean)} check. */
   private boolean treeCacheAvailable = false;

   /** Whether the "missing cache" error has already been logged at ERROR level. */
   private boolean missingCacheErrorLogged = false;

   // ----------------------------------------------------------- Constructors

   /**
    * Creates a manager bound to the MBeanServer returned by
    * {@link MBeanServerLocator#locate()}. No cache is configured yet.
    * Callers must use {@link #setCacheName(String)} or
    * {@link #setCacheObjectName(ObjectName)} before the manager can operate.
    */
   public TreeCacheSSOClusterManager()
   {
      server = MBeanServerLocator.locate();
   }

   // ------------------------------------------------------------- Properties

   /**
    * Returns the canonical JMX name of the TreeCache in use.
    *
    * @return the name, or <code>null</code> if none has been configured
    */
   public String getCacheName()
   {
      return cacheName;
   }

   /**
    * Sets the TreeCache to use, given its JMX object name as a string.
    * Nothing happens if the name equals the currently configured canonical
    * name.
    *
    * @param objectName the JMX name, or <code>null</code> to clear it
    * @throws Exception if the name is malformed or
    *                   {@link #setCacheObjectName(ObjectName)} fails
    */
   public void setCacheName(String objectName)
      throws Exception
   {
      if (objectName == null)
      {
         setCacheObjectName(null);
      }
      else if (objectName.equals(cacheName) == false)
      {
         setCacheObjectName(new ObjectName(objectName));
      }
   }

   /**
    * Returns the JMX name of the TreeCache in use.
    *
    * @return the name, or <code>null</code> if none has been configured
    */
   public ObjectName getCacheObjectName()
   {
      return cacheObjectName;
   }

   /**
    * Switches to the TreeCache with the given JMX name. Steps:
    * <ol>
    * <li>Unregisters this object as a listener from the previous cache.</li>
    * <li>Records the new name.</li>
    * <li>Immediately checks whether the new cache is available; if it is,
    * this object registers as its listener.</li>
    * </ol>
    * If the cache is not available, an error is logged when already
    * started; otherwise an informational message is logged.
    *
    * @param objectName the JMX name, or <code>null</code> to clear it
    * @throws Exception if unregistering from the previous cache fails
    */
   public void setCacheObjectName(ObjectName objectName)
      throws Exception
   {
      boolean unchanged = (objectName == null)
         ? cacheObjectName == null
         : objectName.equals(cacheObjectName);
      if (unchanged)
      {
         return;
      }

      removeAsTreeCacheListener(cacheObjectName);
      this.cacheObjectName = objectName;
      this.cacheName = (objectName == null
         ? null
         : objectName.getCanonicalName());

      if (false == isTreeCacheAvailable(true))
      {
         if (started)
         {
            logMissingCacheError();
         }
         else
         {
            log.info("Cannot find TreeCache using " + cacheName + " -- tree" +
               "CacheName must be set to point to a running TreeCache " +
               "before ClusteredSingleSignOn can handle requests");
         }
      }
   }

   /**
    * Returns the local valve that receives cluster notifications.
    *
    * @return the valve, or <code>null</code> if not yet set
    */
   public ClusteredSingleSignOn getSingleSignOnValve()
   {
      return ssoValve;
   }

   /**
    * Sets the local valve that receives cluster notifications.
    *
    * @param valve the valve
    */
   public void setSingleSignOnValve(ClusteredSingleSignOn valve)
   {
      ssoValve = valve;
   }

   // ------------------------------------------------ SSOClusterManager API

   /**
    * Adds the id of <code>session</code> to the clustered set of sessions
    * for the SSO <code>ssoId</code>. The set is created if absent. The read
    * and the write of the set happen inside one UserTransaction.
    * <p>
    * Does nothing if either argument is <code>null</code>. Logs an error and
    * returns if the TreeCache is unavailable. Failures are logged, not
    * thrown.
    *
    * @param ssoId   id of the SSO
    * @param session session to associate with the SSO
    */
   public void addSession(String ssoId, Session session)
   {
      if (ssoId == null || session == null)
      {
         return;
      }

      if (false == isTreeCacheAvailable(false))
      {
         logMissingCacheError();
         return;
      }

      String sessionId = session.getId();
      if (log.isTraceEnabled())
      {
         log.trace("addSession(): adding Session " + sessionId +
            " to cached session set for SSO " + ssoId);
      }

      Fqn fqn = getSessionsFqn(ssoId);
      UserTransaction tx = null;
      try
      {
         tx = getNewTransaction();
         tx.begin();
         Set sessions = getSessionSet(fqn, true);
         sessions.add(sessionId);
         putInTreeCache(fqn, sessions);
         tx.commit();
      }
      catch (Exception e)
      {
         rollbackQuietly(tx);
         log.error("caught exception adding session " + sessionId +
            " to SSO id " + ssoId, e);
      }
   }

   /**
    * Removes the whole <code>/SSO/&lt;ssoId&gt;</code> subtree from the
    * cache. The id is recorded in <code>beingLocallyRemoved</code> for the
    * duration, so that:
    * <ul>
    * <li>a concurrent logout of the same SSO returns immediately, and</li>
    * <li>{@link #nodeRemoved(Fqn)} ignores the resulting event.</li>
    * </ul>
    * Failures are logged, not thrown.
    *
    * @param ssoId id of the SSO being logged out
    */
   public void logout(String ssoId)
   {
      if (false == isTreeCacheAvailable(false))
      {
         logMissingCacheError();
         return;
      }

      // Check-and-add must be atomic so two concurrent logouts of the same
      // SSO cannot both proceed.
      synchronized (beingLocallyRemoved)
      {
         if (beingLocallyRemoved.contains(ssoId))
         {
            return;
         }
         beingLocallyRemoved.add(ssoId);
      }

      if (log.isTraceEnabled())
      {
         log.trace("Registering logout of SSO " + ssoId +
            " in clustered cache");
      }

      Fqn fqn = getSingleSignOnFqn(ssoId);
      try
      {
         removeFromTreeCache(fqn);
      }
      catch (Exception e)
      {
         log.error("Exception attempting to remove node " +
            fqn.toString() + " from TreeCache", e);
      }
      finally
      {
         synchronized (beingLocallyRemoved)
         {
            beingLocallyRemoved.remove(ssoId);
         }
      }
   }

   /**
    * Builds a {@link SingleSignOnEntry} from the credentials cached for
    * <code>ssoId</code>. The entry's principal is <code>null</code>.
    *
    * @param ssoId id of the SSO
    * @return the entry, or <code>null</code> if the cache is unavailable, no
    *         credentials are stored, or the lookup failed (failures are
    *         logged)
    */
   public SingleSignOnEntry lookup(String ssoId)
   {
      if (false == isTreeCacheAvailable(false))
      {
         logMissingCacheError();
         return null;
      }

      SingleSignOnEntry entry = null;
      Fqn fqn = getCredentialsFqn(ssoId);
      try
      {
         SSOCredentials data = (SSOCredentials) getFromTreeCache(fqn);
         if (data != null)
         {
            entry = new SingleSignOnEntry(null,
               data.getAuthType(),
               data.getUsername(),
               data.getPassword());
         }
      }
      catch (Exception e)
      {
         log.error("caught exception looking up SSOCredentials for SSO id " +
            ssoId, e);
      }
      return entry;
   }

   /**
    * Stores the credentials of a newly created SSO in the cache.
    *
    * @param ssoId    id of the SSO
    * @param authType authentication type
    * @param username user name
    * @param password password
    */
   public void register(String ssoId, String authType,
                        String username, String password)
   {
      if (false == isTreeCacheAvailable(false))
      {
         logMissingCacheError();
         return;
      }

      if (log.isTraceEnabled())
      {
         log.trace("Registering SSO " + ssoId + " in clustered cache");
      }

      storeSSOData(ssoId, authType, username, password);
   }

   /**
    * Removes the id of <code>session</code> from the clustered session set
    * of <code>ssoId</code>, inside one UserTransaction. If that leaves the
    * set empty, the whole <code>/SSO/&lt;ssoId&gt;</code> subtree is removed
    * instead of the set being written back.
    * <p>
    * Does nothing if either argument is <code>null</code>. Also does nothing
    * while {@link #nodeRemoved(Fqn)} is deregistering this SSO. Failures are
    * logged, not thrown.
    *
    * @param ssoId   id of the SSO
    * @param session session to disassociate from the SSO
    */
   public void removeSession(String ssoId, Session session)
   {
      if (ssoId == null || session == null)
      {
         return;
      }

      if (false == isTreeCacheAvailable(false))
      {
         logMissingCacheError();
         return;
      }

      // nodeRemoved() holds the id here while it calls ssoValve.deregister();
      // the SSO is already gone from the cache, so there is nothing to do.
      synchronized (beingRemotelyRemoved)
      {
         if (beingRemotelyRemoved.contains(ssoId))
         {
            return;
         }
      }

      String sessionId = session.getId();
      if (log.isTraceEnabled())
      {
         log.trace("removeSession(): removing Session " + sessionId +
            " from cached session set for SSO " + ssoId);
      }

      Fqn fqn = getSessionsFqn(ssoId);
      UserTransaction tx = null;
      boolean removing = false;
      try
      {
         tx = getNewTransaction();
         tx.begin();
         Set sessions = getSessionSet(fqn, false);
         if (sessions != null)
         {
            sessions.remove(sessionId);
            if (sessions.isEmpty())
            {
               // Mark the removal as ours so nodeRemoved() ignores it.
               synchronized (beingLocallyRemoved)
               {
                  beingLocallyRemoved.add(ssoId);
               }
               removing = true;
               removeFromTreeCache(getSingleSignOnFqn(ssoId));
            }
            else
            {
               putInTreeCache(fqn, sessions);
            }
         }
         tx.commit();
      }
      catch (Exception e)
      {
         rollbackQuietly(tx);
         log.error("caught exception removing session " + sessionId +
            " from SSO id " + ssoId, e);
      }
      finally
      {
         if (removing)
         {
            synchronized (beingLocallyRemoved)
            {
               beingLocallyRemoved.remove(ssoId);
            }
         }
      }
   }

   /**
    * Replaces the cached credentials of an existing SSO.
    *
    * @param ssoId    id of the SSO
    * @param authType authentication type
    * @param username user name
    * @param password password
    */
   public void updateCredentials(String ssoId, String authType,
                                 String username, String password)
   {
      if (false == isTreeCacheAvailable(false))
      {
         logMissingCacheError();
         return;
      }

      if (log.isTraceEnabled())
      {
         log.trace("Updating credentials for SSO " + ssoId +
            " in clustered cache");
      }

      storeSSOData(ssoId, authType, username, password);
   }

   // ---------------------------------------------------- TreeCacheListener

   /** Ignored. */
   public void nodeCreated(Fqn fqn)
   {
   }

   /** Ignored. */
   public void nodeLoaded(Fqn fqn)
   {
   }

   /** Ignored. */
   public void nodeVisited(Fqn fqn)
   {
   }

   /** Ignored. */
   public void cacheStarted(TreeCache cache)
   {
   }

   /** Ignored. */
   public void cacheStopped(TreeCache cache)
   {
   }

   /**
    * Handles removal of an <code>/SSO/&lt;ssoId&gt;</code> node by calling
    * {@link ClusteredSingleSignOn#deregister(String)} on the local valve.
    * <p>
    * The event is ignored if:
    * <ul>
    * <li>the Fqn does not have that shape (the cache may be shared with
    * other data),</li>
    * <li>the removal was made by this instance, or</li>
    * <li>no valve has been set.</li>
    * </ul>
    *
    * @param fqn the removed node
    */
   public void nodeRemoved(Fqn fqn)
   {
      if (isSingleSignOnFqn(fqn) == false)
      {
         return;
      }

      ClusteredSingleSignOn valve = ssoValve;
      if (valve == null)
      {
         return;
      }

      String ssoId = getIdFromFqn(fqn);
      synchronized (beingLocallyRemoved)
      {
         if (beingLocallyRemoved.contains(ssoId))
         {
            return;
         }
      }

      // Lets removeSession() skip any callbacks triggered by deregister().
      synchronized (beingRemotelyRemoved)
      {
         beingRemotelyRemoved.add(ssoId);
      }

      try
      {
         if (log.isTraceEnabled())
         {
            log.trace("received a node removed message for SSO " + ssoId);
         }
         valve.deregister(ssoId);
      }
      finally
      {
         synchronized (beingRemotelyRemoved)
         {
            beingRemotelyRemoved.remove(ssoId);
         }
      }
   }

   /**
    * Handles modification of an <code>/SSO/&lt;ssoId&gt;/credentials</code>
    * node by queuing the matching local entry for a credentials refresh.
    * <p>
    * The event is ignored if:
    * <ul>
    * <li>the Fqn has a different shape,</li>
    * <li>this instance wrote the credentials,</li>
    * <li>no valve is set or the manager is not started,</li>
    * <li>there is no local entry, or</li>
    * <li>the entry reports <code>getCanReauthenticate()</code>.</li>
    * </ul>
    *
    * @param fqn the modified node
    */
   public void nodeModified(Fqn fqn)
   {
      if (isCredentialsFqn(fqn) == false)
      {
         return;
      }

      String ssoId = getIdFromFqn(fqn);
      synchronized (beingLocallyAdded)
      {
         if (beingLocallyAdded.contains(ssoId))
         {
            return;
         }
      }

      // Read once: the listener may be registered before start() or the
      // valve setter run, and stop() clears the updater.
      ClusteredSingleSignOn valve = ssoValve;
      CredentialUpdater updater = credentialUpdater;
      if (valve == null || updater == null)
      {
         return;
      }

      SingleSignOnEntry sso = valve.localLookup(ssoId);
      if (sso == null || sso.getCanReauthenticate())
      {
         return;
      }

      if (log.isTraceEnabled())
      {
         log.trace("received a credentials modified message for SSO " + ssoId);
      }

      updater.enqueue(sso, ssoId);
   }

   /** Ignored. */
   public void viewChange(View new_view)
   {
   }

   /** Ignored. */
   public void nodeEvicted(Fqn fqn)
   {
   }

   // ------------------------------------------------------------- Lifecycle

   public void addLifecycleListener(LifecycleListener listener)
   {
      lifecycle.addLifecycleListener(listener);
   }

   public LifecycleListener[] findLifecycleListeners()
   {
      return lifecycle.findLifecycleListeners();
   }

   public void removeLifecycleListener(LifecycleListener listener)
   {
      lifecycle.removeLifecycleListener(listener);
   }

   /**
    * Starts the background credentials updater and fires START_EVENT.
    *
    * @throws LifecycleException if already started
    */
   public void start() throws LifecycleException
   {
      if (started)
      {
         throw new LifecycleException
            ("TreeCacheSSOClusterManager already Started");
      }

      credentialUpdater = new CredentialUpdater();
      started = true;
      lifecycle.fireLifecycleEvent(START_EVENT, null);
   }

   /**
    * Stops and discards the background credentials updater, then fires
    * STOP_EVENT.
    *
    * @throws LifecycleException if not started
    */
   public void stop() throws LifecycleException
   {
      if (!started)
      {
         throw new LifecycleException
            ("TreeCacheSSOClusterManager not Started");
      }

      CredentialUpdater updater = credentialUpdater;
      credentialUpdater = null;
      updater.stop();
      started = false;
      lifecycle.fireLifecycleEvent(STOP_EVENT, null);
   }

   // -------------------------------------------------------- Private helpers

   /** Reads the value stored under {@link #KEY} at <code>fqn</code> via JMX. */
   private Object getFromTreeCache(Fqn fqn) throws Exception
   {
      Object[] args = new Object[]{fqn, KEY};
      return server.invoke(getCacheObjectName(), "get", args, GET_SIGNATURE);
   }

   /** Returns <code>/SSO/&lt;ssoid&gt;/credentials</code>. */
   private Fqn getCredentialsFqn(String ssoid)
   {
      Object[] objs = new Object[]{SSO, ssoid, CREDENTIALS};
      return new Fqn(objs);
   }

   /** Returns <code>/SSO/&lt;ssoid&gt;/sessions</code>. */
   private Fqn getSessionsFqn(String ssoid)
   {
      Object[] objs = new Object[]{SSO, ssoid, SESSIONS};
      return new Fqn(objs);
   }

   /** Returns <code>/SSO/&lt;ssoid&gt;</code>. */
   private Fqn getSingleSignOnFqn(String ssoid)
   {
      Object[] objs = new Object[]{SSO, ssoid};
      return new Fqn(objs);
   }

   /** Returns the SSO id (second element) of an SSO Fqn. */
   private String getIdFromFqn(Fqn fqn)
   {
      return (String) fqn.get(1);
   }

   /** True if <code>fqn</code> has the form <code>/SSO/&lt;String id&gt;</code>. */
   private boolean isSingleSignOnFqn(Fqn fqn)
   {
      return fqn != null
         && fqn.size() == 2
         && SSO.equals(fqn.get(0))
         && fqn.get(1) instanceof String;
   }

   /** True if <code>fqn</code> has the form <code>/SSO/&lt;String id&gt;/credentials</code>. */
   private boolean isCredentialsFqn(Fqn fqn)
   {
      return fqn != null
         && fqn.size() == 3
         && SSO.equals(fqn.get(0))
         && fqn.get(1) instanceof String
         && CREDENTIALS.equals(fqn.get(2));
   }

   /** Returns the cached InitialContext, creating it on first use. */
   private InitialContext getInitialContext() throws NamingException
   {
      if (initialContext == null)
      {
         initialContext = new InitialContext();
      }
      return initialContext;
   }

   /**
    * Reads the session id set stored at <code>fqn</code>.
    *
    * @param create if true, a new empty set is returned when none is stored
    * @return the set; <code>null</code> only when none is stored and
    *         <code>create</code> is false
    */
   private Set getSessionSet(Fqn fqn, boolean create)
      throws Exception
   {
      Set sessions = (Set) getFromTreeCache(fqn);
      if (create && sessions == null)
      {
         sessions = new HashSet();
      }
      return sessions;
   }

   /**
    * Looks up "UserTransaction" in JNDI. On a NamingException the cached
    * InitialContext is discarded, so the next call builds a fresh one.
    */
   private UserTransaction getNewTransaction() throws NamingException
   {
      try
      {
         return (UserTransaction) getInitialContext().lookup("UserTransaction");
      }
      catch (NamingException n)
      {
         initialContext = null;
         throw n;
      }
   }

   /**
    * Rolls back <code>tx</code> (if non-null) after a failed operation. A
    * failure to roll back is logged at debug level and not propagated, so
    * the caller can still report the original exception.
    */
   private void rollbackQuietly(UserTransaction tx)
   {
      if (tx == null)
      {
         return;
      }
      try
      {
         tx.rollback();
      }
      catch (Exception x)
      {
         log.debug("Failed to roll back transaction", x);
      }
   }

   /**
    * Reports whether the configured TreeCache MBean is registered.
    * <p>
    * A new check is made only when <code>forceCheck</code> is true or the
    * last check failed; otherwise the cached result is returned. When the
    * MBean is found, this object registers as its listener. If that
    * registration throws, the cache is treated as unavailable.
    */
   private synchronized boolean isTreeCacheAvailable(boolean forceCheck)
   {
      if (forceCheck || treeCacheAvailable == false)
      {
         boolean available = (cacheObjectName != null);
         if (available)
         {
            Set s = server.queryMBeans(cacheObjectName, null);
            available = s.size() > 0;
            if (available)
            {
               try
               {
                  registerAsTreeCacheListener(cacheObjectName);
                  setMissingCacheErrorLogged(false);
               }
               catch (Exception e)
               {
                  log.error("Caught exception registering as listener to " +
                     cacheObjectName, e);
                  available = false;
               }
            }
         }
         treeCacheAvailable = available;
      }
      return treeCacheAvailable;
   }

   /** Stores <code>data</code> under {@link #KEY} at <code>fqn</code> via JMX. */
   private void putInTreeCache(Fqn fqn, Object data) throws Exception
   {
      Object[] args = new Object[]{fqn, KEY, data};
      server.invoke(getCacheObjectName(), "put", args, PUT_SIGNATURE);
   }

   /** Invokes addTreeCacheListener(this) on the given cache MBean. */
   private void registerAsTreeCacheListener(ObjectName listenTo)
      throws Exception
   {
      server.invoke(listenTo, "addTreeCacheListener",
         new Object[]{this},
         new String[]{TreeCacheListener.class.getName()});
      registeredAsListener = true;
   }

   /**
    * Invokes removeTreeCacheListener(this) on the given cache MBean, if
    * registered, and clears the registration flag once that succeeds.
    */
   private void removeAsTreeCacheListener(ObjectName removeFrom)
      throws Exception
   {
      if (registeredAsListener && removeFrom != null)
      {
         server.invoke(removeFrom, "removeTreeCacheListener",
            new Object[]{this},
            new String[]{TreeCacheListener.class.getName()});
         registeredAsListener = false;
      }
   }

   /** Removes the node at <code>fqn</code> (and its subtree) via JMX. */
   private void removeFromTreeCache(Fqn fqn) throws Exception
   {
      server.invoke(getCacheObjectName(), "remove",
         new Object[]{fqn},
         REMOVE_SIGNATURE);
   }

   /**
    * Writes an {@link SSOCredentials} for <code>ssoId</code>. The id is
    * recorded in <code>beingLocallyAdded</code> for the duration so that
    * {@link #nodeModified(Fqn)} ignores the resulting event. Failures are
    * logged, not thrown.
    */
   private void storeSSOData(String ssoId, String authType, String username,
                             String password)
   {
      SSOCredentials data = new SSOCredentials(authType, username, password);

      synchronized (beingLocallyAdded)
      {
         beingLocallyAdded.add(ssoId);
      }
      try
      {
         putInTreeCache(getCredentialsFqn(ssoId), data);
      }
      catch (Exception e)
      {
         log.error("Exception attempting to add TreeCache nodes for SSO " +
            ssoId, e);
      }
      finally
      {
         synchronized (beingLocallyAdded)
         {
            beingLocallyAdded.remove(ssoId);
         }
      }
   }

   private boolean isMissingCacheErrorLogged()
   {
      return missingCacheErrorLogged;
   }

   private void setMissingCacheErrorLogged(boolean missingCacheErrorLogged)
   {
      this.missingCacheErrorLogged = missingCacheErrorLogged;
   }

   /**
    * Logs that the TreeCache cannot be found. The first report since the
    * cache was last found is logged at ERROR; later reports are logged at
    * WARN.
    */
   private void logMissingCacheError()
   {
      StringBuffer msg = new StringBuffer("Cannot find TreeCache using ");
      msg.append(getCacheName());
      msg.append(" -- TreeCache must be started before ClusteredSingleSignOn ");
      msg.append("can handle requests");

      if (isMissingCacheErrorLogged())
      {
         log.warn(msg);
      }
      else
      {
         log.error(msg);
         setMissingCacheErrorLogged(true);
      }
   }

   // --------------------------------------------------------- Nested classes

   /**
    * Daemon thread that refreshes local {@link SingleSignOnEntry} objects
    * from the credentials stored in the cache. Work arrives via
    * {@link #enqueue}. The thread waits on <code>awaitingUpdate</code> for
    * new work; the wait times out after 30 s.
    */
   private class CredentialUpdater
      implements Runnable
   {
      /** Pending updates; also the monitor the update thread waits on. */
      private final HashSet awaitingUpdate = new HashSet();

      private final Thread updateThread;

      /** Set by stop(); volatile so the update thread observes it. */
      private volatile boolean stopped = false;

      private CredentialUpdater()
      {
         updateThread =
            new Thread(this, "SSOClusterManager.CredentialUpdater");
         updateThread.setDaemon(true);
         updateThread.start();
      }

      /** Drains and processes pending updates until {@link #stop()} is called. */
      public void run()
      {
         while (!stopped)
         {
            try
            {
               SSOWrapper[] ssos;
               synchronized (awaitingUpdate)
               {
                  ssos = (SSOWrapper[]) awaitingUpdate.toArray(
                     new SSOWrapper[awaitingUpdate.size()]);
                  awaitingUpdate.clear();
               }

               for (int i = 0; i < ssos.length; i++)
               {
                  processUpdate(ssos[i]);
               }

               // The emptiness check and the wait happen under the same lock
               // that enqueue()/stop() use to notify, so no wakeup is lost.
               synchronized (awaitingUpdate)
               {
                  if (awaitingUpdate.isEmpty())
                  {
                     if (!stopped)
                     {
                        awaitingUpdate.wait(30000);
                     }
                  }
                  else if (log.isTraceEnabled())
                  {
                     log.trace("CredentialUpdater: more updates added while " +
                        "handling existing updates");
                  }
               }
            }
            catch (InterruptedException e)
            {
               // Treated as a wake-up; the loop re-checks 'stopped'.
               if (log.isTraceEnabled())
               {
                  log.trace("CredentialUpdater: interrupted");
               }
            }
            catch (Exception e)
            {
               log.error("CredentialUpdater thread caught an exception", e);
            }
         }
      }

      /** Queues <code>sso</code> for a refresh and wakes the update thread. */
      private void enqueue(SingleSignOnEntry sso, String ssoId)
      {
         synchronized (awaitingUpdate)
         {
            awaitingUpdate.add(new SSOWrapper(sso, ssoId));
            awaitingUpdate.notify();
         }
      }

      /**
       * Copies the cached credentials of <code>wrapper.id</code> into
       * <code>wrapper.sso</code>, keeping its current principal. Entries
       * that report getCanReauthenticate() are skipped. Failures are logged.
       */
      private void processUpdate(SSOWrapper wrapper)
      {
         if (wrapper.sso.getCanReauthenticate())
         {
            return;
         }

         Fqn fqn = getCredentialsFqn(wrapper.id);
         try
         {
            SSOCredentials data = (SSOCredentials) getFromTreeCache(fqn);
            if (data != null)
            {
               String authType = data.getAuthType();
               String username = data.getUsername();
               String password = data.getPassword();

               if (log.isTraceEnabled())
               {
                  log.trace("CredentialUpdater: Updating credentials for SSO " +
                     wrapper.sso);
               }

               synchronized (wrapper.sso)
               {
                  Principal p = wrapper.sso.getPrincipal();
                  wrapper.sso.updateCredentials(p, authType, username, password);
               }
            }
         }
         catch (Exception e)
         {
            log.error("Exception attempting to get SSOCredentials from " +
               "TreeCache node " + fqn.toString(), e);
         }
      }

      /** Asks the update thread to exit and wakes it if it is waiting. */
      private void stop()
      {
         stopped = true;
         synchronized (awaitingUpdate)
         {
            awaitingUpdate.notify();
         }
      }
   }

   /** Pairs a local SSO entry with its id while it waits for a refresh. */
   private static class SSOWrapper
   {
      private final SingleSignOnEntry sso;
      private final String id;

      private SSOWrapper(SingleSignOnEntry entry, String ssoId)
      {
         this.sso = entry;
         this.id = ssoId;
      }
   }

   /**
    * Serializable value stored at <code>/SSO/&lt;ssoId&gt;/credentials</code>.
    */
   public static class SSOCredentials
      implements Serializable
   {
      static final long serialVersionUID = 5704877226920571663L;

      private String authType = null;
      private String password = null;
      private String username = null;

      private SSOCredentials(String authType, String username, String password)
      {
         this.authType = authType;
         this.username = username;
         this.password = password;
      }

      /** Returns the user name. */
      public String getUsername()
      {
         return username;
      }

      /** Returns the authentication type. */
      public String getAuthType()
      {
         return authType;
      }

      /** Returns the password; kept private so it is only read by the manager. */
      private String getPassword()
      {
         return password;
      }
   }
}