1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.channel.socket.http;
17
18 import java.io.EOFException;
19 import java.io.IOException;
20 import java.io.PushbackInputStream;
21 import java.net.SocketAddress;
22
23 import javax.servlet.ServletConfig;
24 import javax.servlet.ServletException;
25 import javax.servlet.ServletOutputStream;
26 import javax.servlet.http.HttpServlet;
27 import javax.servlet.http.HttpServletRequest;
28 import javax.servlet.http.HttpServletResponse;
29
30 import org.jboss.netty.buffer.ChannelBuffer;
31 import org.jboss.netty.buffer.ChannelBuffers;
32 import org.jboss.netty.channel.Channel;
33 import org.jboss.netty.channel.ChannelFactory;
34 import org.jboss.netty.channel.ChannelFuture;
35 import org.jboss.netty.channel.ChannelFutureListener;
36 import org.jboss.netty.channel.ChannelHandlerContext;
37 import org.jboss.netty.channel.ChannelPipeline;
38 import org.jboss.netty.channel.Channels;
39 import org.jboss.netty.channel.ExceptionEvent;
40 import org.jboss.netty.channel.MessageEvent;
41 import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
42 import org.jboss.netty.channel.local.DefaultLocalClientChannelFactory;
43 import org.jboss.netty.channel.local.LocalAddress;
44 import org.jboss.netty.handler.codec.http.HttpHeaders;
45 import org.jboss.netty.logging.InternalLogger;
46 import org.jboss.netty.logging.InternalLoggerFactory;
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61 public class HttpTunnelingServlet extends HttpServlet {
62
63 private static final long serialVersionUID = 4259910275899756070L;
64
65 private static final String ENDPOINT = "endpoint";
66
67 static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpTunnelingServlet.class);
68
69 private volatile SocketAddress remoteAddress;
70 private volatile ChannelFactory channelFactory;
71
72 @Override
73 public void init() throws ServletException {
74 ServletConfig config = getServletConfig();
75 String endpoint = config.getInitParameter(ENDPOINT);
76 if (endpoint == null) {
77 throw new ServletException("init-param '" + ENDPOINT + "' must be specified.");
78 }
79
80 try {
81 remoteAddress = parseEndpoint(endpoint.trim());
82 } catch (ServletException e) {
83 throw e;
84 } catch (Exception e) {
85 throw new ServletException("Failed to parse an endpoint.", e);
86 }
87
88 try {
89 channelFactory = createChannelFactory(remoteAddress);
90 } catch (ServletException e) {
91 throw e;
92 } catch (Exception e) {
93 throw new ServletException("Failed to create a channel factory.", e);
94 }
95
96
97
98
99
100
101 }
102
103 protected SocketAddress parseEndpoint(String endpoint) throws Exception {
104 if (endpoint.startsWith("local:")) {
105 return new LocalAddress(endpoint.substring(6).trim());
106 } else {
107 throw new ServletException(
108 "Invalid or unknown endpoint: " + endpoint);
109 }
110 }
111
112 protected ChannelFactory createChannelFactory(SocketAddress remoteAddress) throws Exception {
113 if (remoteAddress instanceof LocalAddress) {
114 return new DefaultLocalClientChannelFactory();
115 } else {
116 throw new ServletException(
117 "Unsupported remote address type: " +
118 remoteAddress.getClass().getName());
119 }
120 }
121
122 @Override
123 public void destroy() {
124 try {
125 destroyChannelFactory(channelFactory);
126 } catch (Exception e) {
127 logger.warn("Failed to destroy a channel factory.", e);
128 }
129 }
130
131 protected void destroyChannelFactory(ChannelFactory factory) throws Exception {
132 factory.releaseExternalResources();
133 }
134
135 @Override
136 protected void service(HttpServletRequest req, HttpServletResponse res)
137 throws ServletException, IOException {
138 if (!"POST".equalsIgnoreCase(req.getMethod())) {
139 logger.warn("Unallowed method: " + req.getMethod());
140 res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
141 return;
142 }
143
144 final ChannelPipeline pipeline = Channels.pipeline();
145 final ServletOutputStream out = res.getOutputStream();
146 final OutboundConnectionHandler handler = new OutboundConnectionHandler(out);
147 pipeline.addLast("handler", handler);
148
149 Channel channel = channelFactory.newChannel(pipeline);
150 ChannelFuture future = channel.connect(remoteAddress).awaitUninterruptibly();
151 if (!future.isSuccess()) {
152 Throwable cause = future.getCause();
153 logger.warn("Endpoint unavailable: " + cause.getMessage(), cause);
154 res.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
155 return;
156 }
157
158 ChannelFuture lastWriteFuture = null;
159 try {
160 res.setStatus(HttpServletResponse.SC_OK);
161 res.setHeader(HttpHeaders.Names.CONTENT_TYPE, "application/octet-stream");
162 res.setHeader(HttpHeaders.Names.CONTENT_TRANSFER_ENCODING, HttpHeaders.Values.BINARY);
163
164
165 out.flush();
166
167 PushbackInputStream in =
168 new PushbackInputStream(req.getInputStream());
169 while (channel.isConnected()) {
170 ChannelBuffer buffer;
171 try {
172 buffer = read(in);
173 } catch (EOFException e) {
174 break;
175 }
176 if (buffer == null) {
177 break;
178 }
179 lastWriteFuture = channel.write(buffer);
180 }
181 } finally {
182 if (lastWriteFuture == null) {
183 channel.close();
184 } else {
185 lastWriteFuture.addListener(ChannelFutureListener.CLOSE);
186 }
187 }
188 }
189
190 private static ChannelBuffer read(PushbackInputStream in) throws IOException {
191 byte[] buf;
192 int readBytes;
193
194 int bytesToRead = in.available();
195 if (bytesToRead > 0) {
196 buf = new byte[bytesToRead];
197 readBytes = in.read(buf);
198 } else if (bytesToRead == 0) {
199 int b = in.read();
200 if (b < 0 || in.available() < 0) {
201 return null;
202 }
203 in.unread(b);
204 bytesToRead = in.available();
205 buf = new byte[bytesToRead];
206 readBytes = in.read(buf);
207 } else {
208 return null;
209 }
210
211 assert readBytes > 0;
212
213 ChannelBuffer buffer;
214 if (readBytes == buf.length) {
215 buffer = ChannelBuffers.wrappedBuffer(buf);
216 } else {
217
218 buffer = ChannelBuffers.wrappedBuffer(buf, 0, readBytes);
219 }
220 return buffer;
221 }
222
223 private static final class OutboundConnectionHandler extends SimpleChannelUpstreamHandler {
224
225 private final ServletOutputStream out;
226
227 public OutboundConnectionHandler(ServletOutputStream out) {
228 this.out = out;
229 }
230
231 @Override
232 public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
233 ChannelBuffer buffer = (ChannelBuffer) e.getMessage();
234 synchronized (this) {
235 buffer.readBytes(out, buffer.readableBytes());
236 out.flush();
237 }
238 }
239
240 @Override
241 public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
242 logger.warn("Unexpected exception while HTTP tunneling", e.getCause());
243 e.getChannel().close();
244 }
245 }
246 }