package org.jboss.webservice.transport.jms;
import org.jboss.axis.AxisFault;
import org.jboss.axis.ConfigurationException;
import org.jboss.axis.EngineConfiguration;
import org.jboss.axis.MessageContext;
import org.jboss.axis.description.OperationDesc;
import org.jboss.axis.description.ServiceDesc;
import org.jboss.axis.server.AxisServer;
import org.jboss.logging.Logger;
import org.jboss.mx.util.MBeanServerLocator;
import org.jboss.util.NestedRuntimeException;
import org.jboss.webservice.AxisServiceMBean;
import javax.ejb.EJBException;
import javax.ejb.MessageDrivenBean;
import javax.ejb.MessageDrivenContext;
import javax.jms.BytesMessage;
import javax.jms.JMSException;
import javax.jms.Message;
import javax.jms.MessageListener;
import javax.jms.Queue;
import javax.jms.QueueConnection;
import javax.jms.QueueConnectionFactory;
import javax.jms.QueueSender;
import javax.jms.QueueSession;
import javax.jms.Session;
import javax.management.MBeanServer;
import javax.naming.InitialContext;
import javax.xml.namespace.QName;
import javax.xml.soap.SOAPElement;
import javax.xml.soap.SOAPException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Iterator;
/**
 * Message driven bean that receives SOAP requests as JMS {@link BytesMessage}s,
 * hands them to the Axis server for invocation and, when the request carries a
 * reply-to queue, sends the SOAP response back as a {@link BytesMessage}.
 *
 * Messages of any other JMS type are logged and ignored.
 */
public class DelegatingMessageDrivenBean implements MessageDrivenBean, MessageListener
{
   static final long serialVersionUID = -2841152046009542652L;

   protected Logger log = Logger.getLogger(DelegatingMessageDrivenBean.class);

   // Stored by setMessageDrivenContext(); not read anywhere in this class.
   private MessageDrivenContext mdbCtx;
   // Looked up in ejbCreate(); used by sendResponse() to open reply connections.
   private QueueConnectionFactory queueFactory;

   /**
    * Obtains the Axis server by reading the "AxisServer" attribute of the MBean
    * registered under {@link AxisServiceMBean#OBJECT_NAME}.
    *
    * @return the Axis server instance
    * @throws RuntimeException if the attribute cannot be read for any reason
    */
   protected AxisServer getAxisServer()
   {
      AxisServer axisServer;
      try
      {
         MBeanServer mbeanServer = MBeanServerLocator.locateJBoss();
         axisServer = (AxisServer)mbeanServer.getAttribute(AxisServiceMBean.OBJECT_NAME, "AxisServer");
         log.debug("got AxisServer: " + axisServer);
      }
      catch (Exception e)
      {
         throw new RuntimeException("Cannot obtain axis server", e);
      }
      return axisServer;
   }

   /**
    * JMS entry point. Delegates {@link BytesMessage}s to
    * {@link #processSOAPMessage(BytesMessage)}; every other message type is
    * logged at WARN level and dropped.
    *
    * @param message the incoming JMS message
    * @throws EJBException wrapping any exception raised while processing
    */
   public void onMessage(Message message)
   {
      try
      {
         if (message instanceof BytesMessage)
         {
            processSOAPMessage((BytesMessage)message);
         }
         else
         {
            log.warn("Ingnore message, because it is not a javax.jms.BytesMessage: " + message);
         }
      }
      catch (Exception e)
      {
         throw new EJBException(e);
      }
   }

   /**
    * Reads the SOAP request from the given message, invokes the Axis server and
    * sends the response to the reply queue, if there is one.
    *
    * If the message bytes cannot be read, the error is logged and the method
    * returns without invoking anything. Faults raised during invocation are
    * converted into a fault response message rather than propagated.
    *
    * @param message the JMS message holding the raw SOAP request bytes
    * @throws Exception if the request cannot be inspected or the response cannot be sent
    */
   protected void processSOAPMessage(BytesMessage message) throws Exception
   {
      byte[] payload;
      try
      {
         payload = readPayload(message);
      }
      catch (Exception e)
      {
         log.error("Cannot get bytes from message", e);
         return;
      }

      // Only build the (potentially large) string when it will actually be logged.
      if (log.isDebugEnabled())
         log.debug("onMessage: " + new String(payload));

      InputStream in = new ByteArrayInputStream(payload);

      AxisServer axisServer = getAxisServer();
      org.jboss.axis.Message axisRequest = new org.jboss.axis.Message(in);
      MessageContext msgContext = new MessageContext(axisServer);
      msgContext.setRequestMessage(axisRequest);

      String targetService = findTargetService(axisRequest, getServiceOperationsMap(axisServer));
      if (targetService != null)
      {
         log.debug("setTargetService: " + targetService);
         msgContext.setTargetService(targetService);
      }

      org.jboss.axis.Message axisResponse = null;
      try
      {
         axisServer.invoke(msgContext);
         axisResponse = msgContext.getResponseMessage();
      }
      catch (AxisFault af)
      {
         axisResponse = new org.jboss.axis.Message(af);
         axisResponse.setMessageContext(msgContext);
      }
      catch (Exception e)
      {
         // The fault below only carries e.toString(); log the full stack trace here.
         log.error("Unexpected exception while invoking axis server", e);
         axisResponse = new org.jboss.axis.Message(new AxisFault(e.toString()));
         axisResponse.setMessageContext(msgContext);
      }

      Queue replyQueue = getReplyQueue(message);
      if (replyQueue == null)
         return;

      if (axisResponse == null)
      {
         // Nothing to serialize; avoid a NullPointerException in sendResponse().
         log.debug("No response message available, nothing sent to: " + replyQueue);
         return;
      }

      sendResponse(replyQueue, axisResponse);
   }

   /**
    * Copies the complete body of the given bytes message into a byte array.
    */
   private byte[] readPayload(BytesMessage message) throws JMSException
   {
      byte[] buffer = new byte[8 * 1024];
      ByteArrayOutputStream out = new ByteArrayOutputStream(buffer.length);
      int read = message.readBytes(buffer);
      while (read != -1)
      {
         out.write(buffer, 0, read);
         read = message.readBytes(buffer);
      }
      return out.toByteArray();
   }

   /**
    * Returns the service name mapped to the first SOAP body child element whose
    * qualified name is a key in the given map, or null if no child matches.
    */
   private String findTargetService(org.jboss.axis.Message axisRequest, HashMap serviceOperations) throws Exception
   {
      Iterator it = axisRequest.getSOAPEnvelope().getBody().getChildElements();
      while (it.hasNext())
      {
         Object child = it.next();
         // The iterator may also yield non-element nodes (e.g. whitespace text); skip them.
         if (!(child instanceof SOAPElement))
            continue;

         SOAPElement soapElement = (SOAPElement)child;
         String namespace = soapElement.getElementName().getURI();
         String localName = soapElement.getElementName().getLocalName();
         QName qname = new QName(namespace, localName);
         log.debug("maybe operation: " + qname);

         String service = (String)serviceOperations.get(qname);
         if (service != null)
         {
            // Stop at the first match so a later, unrelated element cannot reset it.
            return service;
         }
      }
      return null;
   }

   /**
    * Builds a map from operation element QName to the name of the service that
    * declares the operation, covering all services deployed in the server's
    * engine configuration.
    *
    * If the configuration cannot be read, the error is logged and the entries
    * collected so far (possibly none) are returned.
    *
    * @param server the Axis server whose configuration is inspected
    * @return map of {@link QName} to service name; never null
    */
   private HashMap getServiceOperationsMap(AxisServer server)
   {
      HashMap serviceOperations = new HashMap();
      try
      {
         EngineConfiguration config = server.getConfig();
         Iterator it = config.getDeployedServices();
         while (it.hasNext())
         {
            ServiceDesc service = (ServiceDesc)it.next();
            log.debug("service: [name=" + service.getName() + ",ns=" + service.getDefaultNamespace() + "]");
            Iterator opit = service.getOperations().iterator();
            while (opit.hasNext())
            {
               OperationDesc operation = (OperationDesc)opit.next();
               QName qname = operation.getElementQName();
               log.debug(" operation: [qname=" + qname + "]");
               serviceOperations.put(qname, service.getName());
            }
         }
      }
      catch (ConfigurationException e)
      {
         log.error("Cannot map service operations", e);
      }
      return serviceOperations;
   }

   /**
    * Returns the queue the response should be sent to.
    *
    * @param message the request message
    * @return the JMSReplyTo destination if it is a {@link Queue}; null if no
    *         reply destination is set or if it is not a queue
    * @throws JMSException if the reply destination cannot be read
    */
   protected Queue getReplyQueue(BytesMessage message)
      throws JMSException
   {
      Object replyTo = message.getJMSReplyTo();
      if (replyTo == null)
         return null;

      if (!(replyTo instanceof Queue))
      {
         // sendResponse() can only address queues; a blind cast would fail with ClassCastException.
         log.warn("Ignore reply destination, because it is not a javax.jms.Queue: " + replyTo);
         return null;
      }
      return (Queue)replyTo;
   }

   /**
    * Serializes the Axis response and sends it as a {@link BytesMessage} to the
    * given queue over a freshly created, non-transacted, auto-acknowledge session.
    * Sender, session and connection are closed before returning, whether or
    * not the send succeeded.
    *
    * @param replyQueue the queue to send the response to
    * @param axisResponse the response message to serialize
    * @throws SOAPException if the response cannot be serialized
    * @throws IOException if the response cannot be serialized
    * @throws JMSException if the JMS resources cannot be created or the send fails
    */
   protected void sendResponse(Queue replyQueue, org.jboss.axis.Message axisResponse)
      throws SOAPException, IOException, JMSException
   {
      ByteArrayOutputStream out = new ByteArrayOutputStream(8 * 1024);
      axisResponse.writeTo(out);

      QueueConnection qc = queueFactory.createQueueConnection();
      QueueSession session = null;
      QueueSender sender = null;
      try
      {
         // Created inside the try so the connection is closed even if this fails.
         session = qc.createQueueSession(false, Session.AUTO_ACKNOWLEDGE);
         sender = session.createSender(replyQueue);
         BytesMessage responseMessage = session.createBytesMessage();
         responseMessage.writeBytes(out.toByteArray());
         sender.send(responseMessage);
         log.info("Sent response");
      }
      finally
      {
         // Null checks keep a failed creation from being masked by a NullPointerException.
         if (sender != null)
         {
            try
            {
               sender.close();
            }
            catch (JMSException ignored)
            {
               log.debug("Ignored exception while closing sender", ignored);
            }
         }
         if (session != null)
         {
            try
            {
               session.close();
            }
            catch (JMSException ignored)
            {
               log.debug("Ignored exception while closing session", ignored);
            }
         }
         try
         {
            qc.close();
         }
         catch (JMSException ignored)
         {
            log.debug("Ignored exception while closing connection", ignored);
         }
      }
   }

   /**
    * Looks up the queue connection factory bound at "java:/ConnectionFactory".
    *
    * @throws NestedRuntimeException if the lookup fails
    */
   public void ejbCreate()
   {
      try
      {
         InitialContext ctx = new InitialContext();
         queueFactory = (QueueConnectionFactory)ctx.lookup("java:/ConnectionFactory");
      }
      catch (Exception e)
      {
         throw new NestedRuntimeException(e);
      }
   }

   /**
    * Container callback on bean removal; this bean holds nothing to release.
    */
   public void ejbRemove() throws EJBException
   {
   }

   /**
    * Container callback that stores the message driven context.
    *
    * @param ctx the context supplied by the container
    */
   public void setMessageDrivenContext(MessageDrivenContext ctx) throws EJBException
   {
      this.mdbCtx = ctx;
   }
}