View Javadoc

1   /*
2    * Copyright 2009 Red Hat, Inc.
3    *
4    * Red Hat licenses this file to you under the Apache License, version 2.0
5    * (the "License"); you may not use this file except in compliance with the
6    * License.  You may obtain a copy of the License at:
7    *
8    *    http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package org.jboss.netty.channel.socket.http;
17  
18  import java.io.EOFException;
19  import java.io.IOException;
20  import java.io.PushbackInputStream;
21  import java.net.SocketAddress;
22  
23  import javax.servlet.ServletConfig;
24  import javax.servlet.ServletException;
25  import javax.servlet.ServletOutputStream;
26  import javax.servlet.http.HttpServlet;
27  import javax.servlet.http.HttpServletRequest;
28  import javax.servlet.http.HttpServletResponse;
29  
30  import org.jboss.netty.buffer.ChannelBuffer;
31  import org.jboss.netty.buffer.ChannelBuffers;
32  import org.jboss.netty.channel.Channel;
33  import org.jboss.netty.channel.ChannelFactory;
34  import org.jboss.netty.channel.ChannelFuture;
35  import org.jboss.netty.channel.ChannelFutureListener;
36  import org.jboss.netty.channel.ChannelHandlerContext;
37  import org.jboss.netty.channel.ChannelPipeline;
38  import org.jboss.netty.channel.Channels;
39  import org.jboss.netty.channel.ExceptionEvent;
40  import org.jboss.netty.channel.MessageEvent;
41  import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
42  import org.jboss.netty.channel.local.DefaultLocalClientChannelFactory;
43  import org.jboss.netty.channel.local.LocalAddress;
44  import org.jboss.netty.handler.codec.http.HttpHeaders;
45  import org.jboss.netty.logging.InternalLogger;
46  import org.jboss.netty.logging.InternalLoggerFactory;
47  
48  /**
49   * An {@link HttpServlet} that proxies an incoming data to the actual server
50   * and vice versa.  Please refer to the
51   * <a href="package-summary.html#package_description">package summary</a> for
52   * the detailed usage.
53   *
54   * @author <a href="http://www.jboss.org/netty/">The Netty Project</a>
55   * @author Andy Taylor (andy.taylor@jboss.org)
56   * @author <a href="http://gleamynode.net/">Trustin Lee</a>
57   * @version $Rev: 2119 $, $Date: 2010-02-01 20:46:09 +0900 (Mon, 01 Feb 2010) $
58   *
59   * @apiviz.landmark
60   */
61  public class HttpTunnelingServlet extends HttpServlet {
62  
63      private static final long serialVersionUID = 4259910275899756070L;
64  
65      private static final String ENDPOINT = "endpoint";
66  
67      static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpTunnelingServlet.class);
68  
69      private volatile SocketAddress remoteAddress;
70      private volatile ChannelFactory channelFactory;
71  
72      @Override
73      public void init() throws ServletException {
74          ServletConfig config = getServletConfig();
75          String endpoint = config.getInitParameter(ENDPOINT);
76          if (endpoint == null) {
77              throw new ServletException("init-param '" + ENDPOINT + "' must be specified.");
78          }
79  
80          try {
81              remoteAddress = parseEndpoint(endpoint.trim());
82          } catch (ServletException e) {
83              throw e;
84          } catch (Exception e) {
85              throw new ServletException("Failed to parse an endpoint.", e);
86          }
87  
88          try {
89              channelFactory = createChannelFactory(remoteAddress);
90          } catch (ServletException e) {
91              throw e;
92          } catch (Exception e) {
93              throw new ServletException("Failed to create a channel factory.", e);
94          }
95  
96          // Stuff for testing purpose
97          //ServerBootstrap b = new ServerBootstrap(new DefaultLocalServerChannelFactory());
98          //b.getPipeline().addLast("logger", new LoggingHandler(getClass(), InternalLogLevel.INFO, true));
99          //b.getPipeline().addLast("handler", new EchoHandler());
100         //b.bind(remoteAddress);
101     }
102 
103     protected SocketAddress parseEndpoint(String endpoint) throws Exception {
104         if (endpoint.startsWith("local:")) {
105             return new LocalAddress(endpoint.substring(6).trim());
106         } else {
107             throw new ServletException(
108                     "Invalid or unknown endpoint: " + endpoint);
109         }
110     }
111 
112     protected ChannelFactory createChannelFactory(SocketAddress remoteAddress) throws Exception {
113         if (remoteAddress instanceof LocalAddress) {
114             return new DefaultLocalClientChannelFactory();
115         } else {
116             throw new ServletException(
117                     "Unsupported remote address type: " +
118                     remoteAddress.getClass().getName());
119         }
120     }
121 
122     @Override
123     public void destroy() {
124         try {
125             destroyChannelFactory(channelFactory);
126         } catch (Exception e) {
127             logger.warn("Failed to destroy a channel factory.", e);
128         }
129     }
130 
131     protected void destroyChannelFactory(ChannelFactory factory) throws Exception {
132         factory.releaseExternalResources();
133     }
134 
135     @Override
136     protected void service(HttpServletRequest req, HttpServletResponse res)
137             throws ServletException, IOException {
138         if (!"POST".equalsIgnoreCase(req.getMethod())) {
139             logger.warn("Unallowed method: " + req.getMethod());
140             res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
141             return;
142         }
143 
144         final ChannelPipeline pipeline = Channels.pipeline();
145         final ServletOutputStream out = res.getOutputStream();
146         final OutboundConnectionHandler handler = new OutboundConnectionHandler(out);
147         pipeline.addLast("handler", handler);
148 
149         Channel channel = channelFactory.newChannel(pipeline);
150         ChannelFuture future = channel.connect(remoteAddress).awaitUninterruptibly();
151         if (!future.isSuccess()) {
152             Throwable cause = future.getCause();
153             logger.warn("Endpoint unavailable: " + cause.getMessage(), cause);
154             res.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
155             return;
156         }
157 
158         ChannelFuture lastWriteFuture = null;
159         try {
160             res.setStatus(HttpServletResponse.SC_OK);
161             res.setHeader(HttpHeaders.Names.CONTENT_TYPE, "application/octet-stream");
162             res.setHeader(HttpHeaders.Names.CONTENT_TRANSFER_ENCODING, HttpHeaders.Values.BINARY);
163 
164             // Initiate chunked encoding by flushing the headers.
165             out.flush();
166 
167             PushbackInputStream in =
168                     new PushbackInputStream(req.getInputStream());
169             while (channel.isConnected()) {
170                 ChannelBuffer buffer;
171                 try {
172                     buffer = read(in);
173                 } catch (EOFException e) {
174                     break;
175                 }
176                 if (buffer == null) {
177                     break;
178                 }
179                 lastWriteFuture = channel.write(buffer);
180             }
181         } finally {
182             if (lastWriteFuture == null) {
183                 channel.close();
184             } else {
185                 lastWriteFuture.addListener(ChannelFutureListener.CLOSE);
186             }
187         }
188     }
189 
190     private static ChannelBuffer read(PushbackInputStream in) throws IOException {
191         byte[] buf;
192         int readBytes;
193 
194         int bytesToRead = in.available();
195         if (bytesToRead > 0) {
196             buf = new byte[bytesToRead];
197             readBytes = in.read(buf);
198         } else if (bytesToRead == 0) {
199             int b = in.read();
200             if (b < 0 || in.available() < 0) {
201                 return null;
202             }
203             in.unread(b);
204             bytesToRead = in.available();
205             buf = new byte[bytesToRead];
206             readBytes = in.read(buf);
207         } else {
208             return null;
209         }
210 
211         assert readBytes > 0;
212 
213         ChannelBuffer buffer;
214         if (readBytes == buf.length) {
215             buffer = ChannelBuffers.wrappedBuffer(buf);
216         } else {
217             // A rare case, but it sometimes happen.
218             buffer = ChannelBuffers.wrappedBuffer(buf, 0, readBytes);
219         }
220         return buffer;
221     }
222 
223     private static final class OutboundConnectionHandler extends SimpleChannelUpstreamHandler {
224 
225         private final ServletOutputStream out;
226 
227         public OutboundConnectionHandler(ServletOutputStream out) {
228             this.out = out;
229         }
230 
231         @Override
232         public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
233             ChannelBuffer buffer = (ChannelBuffer) e.getMessage();
234             synchronized (this) {
235                 buffer.readBytes(out, buffer.readableBytes());
236                 out.flush();
237             }
238         }
239 
240         @Override
241         public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
242             logger.warn("Unexpected exception while HTTP tunneling", e.getCause());
243             e.getChannel().close();
244         }
245     }
246 }