package org.jboss.webservice.server;
import org.jboss.deployment.DeploymentInfo;
import org.jboss.logging.Logger;
import org.jboss.webservice.metadata.WebserviceDescriptionMetaData;
import org.w3c.dom.Attr;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import javax.wsdl.Definition;
import javax.wsdl.WSDLException;
import javax.wsdl.factory.WSDLFactory;
import javax.wsdl.xml.WSDLWriter;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.StringTokenizer;
/**
 * Serves the WSDL of a deployed web service, plus any wsdl/schema resources it
 * imports or includes. Import/include locations inside returned documents are
 * rewritten so they point back to this handler ({@code requestURI?wsdl&resource=...}).
 */
public class WSDLRequestHandler
{
   private static final Logger log = Logger.getLogger(WSDLRequestHandler.class);

   private final WebserviceDescriptionMetaData wsdMetaData;
   private final DeploymentInfo di;

   /**
    * @param wsdMetaData the webservice description whose WSDL is served
    * @param di          the deployment whose local class loader provides nested resources
    */
   public WSDLRequestHandler(WebserviceDescriptionMetaData wsdMetaData, DeploymentInfo di)
   {
      this.wsdMetaData = wsdMetaData;
      this.di = di;
   }

   /**
    * Returns the requested WSDL document with rewritten import/include locations.
    *
    * @param requestURI   the URI this handler was reached through; used as the prefix
    *                     for rewritten import locations
    * @param resourcePath {@code null} for the main WSDL, otherwise a path relative to
    *                     the directory that contains the main wsdl file
    * @return the (modified) DOM document
    * @throws SecurityException        if the resolved resource lies outside
    *                                  {@code WEB-INF/wsdl/} or {@code META-INF/wsdl/}
    * @throws IllegalStateException    if the resource cannot be found
    * @throws IllegalArgumentException if the resource cannot be read or parsed
    */
   public Document getDocumentForPath(String requestURI, String resourcePath)
   {
      Document wsdlDoc;
      if (resourcePath == null)
      {
         wsdlDoc = getWSDLDocument(wsdMetaData.getWsdlDefinition());
      }
      else
      {
         String wsdlFile = wsdMetaData.getWsdlFile();
         // Guard against a wsdl file path without any '/', which previously threw
         // StringIndexOutOfBoundsException from substring(0, -1).
         int lastSlash = wsdlFile.lastIndexOf('/');
         String rootDir = (lastSlash >= 0 ? wsdlFile.substring(0, lastSlash) : "");
         URLClassLoader cl = di.localCl;
         String resource = rootDir + "/" + resourcePath;
         resource = canonicalize(resource);

         // The resource path comes from the request: only allow access inside the wsdl roots.
         if (resource.startsWith("WEB-INF/wsdl/") == false && resource.startsWith("META-INF/wsdl/") == false)
            throw new SecurityException("Cannot access a resource below the wsdl root: " + resource);

         URL resURL = cl.findResource(resource);
         if (resURL == null)
            throw new IllegalStateException("Cannot obtain wsdl resource from: " + resource);

         wsdlDoc = parseDocument(resURL);
      }
      modifyImportLocations(requestURI, resourcePath, wsdlDoc.getDocumentElement());
      return wsdlDoc;
   }

   /**
    * Parses the XML document at the given URL (namespace aware, non-validating) and
    * always closes the underlying stream.
    *
    * @throws IllegalArgumentException if the document cannot be read or parsed;
    *                                  the original exception is kept as the cause
    */
   private Document parseDocument(URL resURL)
   {
      InputStream in = null;
      try
      {
         DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
         factory.setNamespaceAware(true);
         factory.setValidating(false);
         DocumentBuilder builder = factory.newDocumentBuilder();
         in = resURL.openStream();
         return builder.parse(in);
      }
      catch (Exception e)
      {
         IllegalArgumentException iae = new IllegalArgumentException("Cannot parse wsdl resource: " + resURL);
         iae.initCause(e);
         throw iae;
      }
      finally
      {
         if (in != null)
         {
            try
            {
               in.close();
            }
            catch (IOException ignored)
            {
               // Best effort: the document has already been parsed, or another
               // exception is already propagating.
            }
         }
      }
   }

   /**
    * Serializes a WSDL4J definition into a DOM document.
    *
    * @throws RuntimeException wrapping any {@link WSDLException}
    */
   private Document getWSDLDocument(Definition wsdlDefinition)
   {
      try
      {
         WSDLFactory factory = WSDLFactory.newInstance();
         WSDLWriter wsdlWriter = factory.newWSDLWriter();
         return wsdlWriter.getDocument(wsdlDefinition);
      }
      catch (WSDLException e)
      {
         throw new RuntimeException(e);
      }
   }

   /**
    * Recursively rewrites the {@code schemaLocation} (preferred) or {@code location}
    * attribute of every {@code import}/{@code include} element below {@code element}.
    * Absolute http(s) locations, and locations that already start with
    * {@code requestURI}, are left untouched. Relative locations are resolved against
    * the directory of {@code resourcePath} (when it has one) and mapped to
    * {@code requestURI + "?wsdl&resource=" + path}.
    */
   private void modifyImportLocations(String requestURI, String resourcePath, Element element)
   {
      NodeList nlist = element.getChildNodes();
      for (int i = 0; i < nlist.getLength(); i++)
      {
         Node childNode = nlist.item(i);
         if (childNode.getNodeType() == Node.ELEMENT_NODE)
         {
            Element childElement = (Element)childNode;
            String nodeName = childElement.getLocalName();
            if ("import".equals(nodeName) || "include".equals(nodeName))
            {
               Attr locationAttr = childElement.getAttributeNode("schemaLocation");
               if (locationAttr == null)
                  locationAttr = childElement.getAttributeNode("location");

               if (locationAttr != null)
               {
                  String orgLocation = locationAttr.getNodeValue();
                  boolean isAbsolute = orgLocation.startsWith("http://") || orgLocation.startsWith("https://");
                  if (isAbsolute == false && orgLocation.startsWith(requestURI) == false)
                  {
                     String resource = orgLocation;
                     // Make the location relative to the directory of the current resource.
                     if (resourcePath != null && resourcePath.indexOf("/") > 0)
                     {
                        resource = resourcePath.substring(0, resourcePath.lastIndexOf("/") + 1);
                        resource = resource + orgLocation;
                     }
                     String newLocation = requestURI + "?wsdl&resource=" + resource;
                     locationAttr.setNodeValue(newLocation);
                     log.debug("Mapping import from '" + orgLocation + "' to '" + newLocation + "'");
                  }
               }
            }
            else
            {
               modifyImportLocations(requestURI, resourcePath, childElement);
            }
         }
      }
   }

   /**
    * Normalizes a '/'-separated path: drops empty and "." segments and resolves ".."
    * against the preceding segment. A leading ".." that cannot be resolved is kept,
    * and further ".." segments accumulate after it rather than cancelling it
    * (e.g. "../.." stays "../.."). The result has no leading or trailing '/'.
    */
   private String canonicalize(String path)
   {
      StringTokenizer tok = new StringTokenizer(path, "/");
      List parts = new ArrayList();
      while (tok.hasMoreTokens())
      {
         String t = tok.nextToken();
         if (".".equals(t))
         {
            // current directory: no-op
         }
         else if ("..".equals(t) && parts.size() > 0 && "..".equals(parts.get(parts.size() - 1)) == false)
         {
            parts.remove(parts.size() - 1);
         }
         else
         {
            parts.add(t);
         }
      }

      StringBuffer ret = new StringBuffer();
      for (Iterator iter = parts.iterator(); iter.hasNext();)
      {
         ret.append((String)iter.next());
         if (iter.hasNext())
            ret.append('/');
      }
      return ret.toString();
   }
}