View Javadoc

1   /*
2    * Copyright 2009 Red Hat, Inc.
3    *
4    * Red Hat licenses this file to you under the Apache License, version 2.0
5    * (the "License"); you may not use this file except in compliance with the
6    * License.  You may obtain a copy of the License at:
7    *
8    *    http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package org.jboss.netty.handler.execution;
17  
18  import java.lang.reflect.Method;
19  import java.util.concurrent.ConcurrentMap;
20  import java.util.concurrent.Executor;
21  import java.util.concurrent.Executors;
22  import java.util.concurrent.RejectedExecutionException;
23  import java.util.concurrent.RejectedExecutionHandler;
24  import java.util.concurrent.ThreadFactory;
25  import java.util.concurrent.ThreadPoolExecutor;
26  import java.util.concurrent.TimeUnit;
27  import java.util.concurrent.atomic.AtomicLong;
28  
29  import org.jboss.netty.buffer.ChannelBuffer;
30  import org.jboss.netty.channel.Channel;
31  import org.jboss.netty.channel.ChannelEvent;
32  import org.jboss.netty.channel.ChannelHandlerContext;
33  import org.jboss.netty.channel.ChannelState;
34  import org.jboss.netty.channel.ChannelStateEvent;
35  import org.jboss.netty.channel.MessageEvent;
36  import org.jboss.netty.channel.WriteCompletionEvent;
37  import org.jboss.netty.logging.InternalLogger;
38  import org.jboss.netty.logging.InternalLoggerFactory;
39  import org.jboss.netty.util.DefaultObjectSizeEstimator;
40  import org.jboss.netty.util.ObjectSizeEstimator;
41  import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
42  import org.jboss.netty.util.internal.LinkedTransferQueue;
43  import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
44  
45  /**
 * A {@link ThreadPoolExecutor} which blocks the task submission when there are
47   * too many tasks in the queue.  Both per-{@link Channel} and per-{@link Executor}
48   * limitation can be applied.
49   * <p>
50   * When a task (i.e. {@link Runnable}) is submitted,
51   * {@link MemoryAwareThreadPoolExecutor} calls {@link ObjectSizeEstimator#estimateSize(Object)}
52   * to get the estimated size of the task in bytes to calculate the amount of
53   * memory occupied by the unprocessed tasks.
54   * <p>
55   * If the total size of the unprocessed tasks exceeds either per-{@link Channel}
56   * or per-{@link Executor} threshold, any further {@link #execute(Runnable)}
57   * call will block until the tasks in the queue are processed so that the total
58   * size goes under the threshold.
59   *
60   * <h3>Using an alternative task size estimation strategy</h3>
61   *
62   * Although the default implementation does its best to guess the size of an
 * object of unknown type, it is always a good idea to use an alternative
64   * {@link ObjectSizeEstimator} implementation instead of the
65   * {@link DefaultObjectSizeEstimator} to avoid incorrect task size calculation,
66   * especially when:
67   * <ul>
68   *   <li>you are using {@link MemoryAwareThreadPoolExecutor} independently from
69   *       {@link ExecutionHandler},</li>
70   *   <li>you are submitting a task whose type is not {@link ChannelEventRunnable}, or</li>
71   *   <li>the message type of the {@link MessageEvent} in the {@link ChannelEventRunnable}
72   *       is not {@link ChannelBuffer}.</li>
73   * </ul>
74   * Here is an example that demonstrates how to implement an {@link ObjectSizeEstimator}
75   * which understands a user-defined object:
76   * <pre>
77   * public class MyRunnable implements {@link Runnable} {
78   *
79   *     <b>private final byte[] data;</b>
80   *
81   *     public MyRunnable(byte[] data) {
82   *         this.data = data;
83   *     }
84   *
85   *     public void run() {
86   *         // Process 'data' ..
87   *     }
88   * }
89   *
90   * public class MyObjectSizeEstimator extends {@link DefaultObjectSizeEstimator} {
91   *
92   *     {@literal @Override}
93   *     public int estimateSize(Object o) {
94   *         if (<b>o instanceof MyRunnable</b>) {
95   *             <b>return ((MyRunnable) o).data.length + 8;</b>
96   *         }
97   *         return super.estimateSize(o);
98   *     }
99   * }
100  *
101  * {@link ThreadPoolExecutor} pool = new {@link MemoryAwareThreadPoolExecutor}(
102  *         16, 65536, 1048576, 30, {@link TimeUnit}.SECONDS,
103  *         <b>new MyObjectSizeEstimator()</b>,
104  *         {@link Executors}.defaultThreadFactory());
105  *
106  * <b>pool.execute(new MyRunnable(data));</b>
107  * </pre>
108  *
109  * <h3>Event execution order</h3>
110  *
111  * Please note that this executor does not maintain the order of the
112  * {@link ChannelEvent}s for the same {@link Channel}.  For example,
113  * you can even receive a {@code "channelClosed"} event before a
 * {@code "messageReceived"} event, as depicted by the following diagram:
117  *
118  * <pre>
119  *           --------------------------------&gt; Timeline --------------------------------&gt;
120  *
121  * Thread X: --- Channel A (Event 2) --- Channel A (Event 1) ---------------------------&gt;
122  *
123  * Thread Y: --- Channel A (Event 3) --- Channel B (Event 2) --- Channel B (Event 3) ---&gt;
124  *
125  * Thread Z: --- Channel B (Event 1) --- Channel B (Event 4) --- Channel A (Event 4) ---&gt;
126  * </pre>
127  *
128  * To maintain the event order, you must use {@link OrderedMemoryAwareThreadPoolExecutor}.
129  *
130  * @author <a href="http://www.jboss.org/netty/">The Netty Project</a>
131  * @author <a href="http://gleamynode.net/">Trustin Lee</a>
132  *
133  * @version $Rev: 2351 $, $Date: 2010-08-26 11:55:10 +0900 (Thu, 26 Aug 2010) $
134  *
135  * @apiviz.has org.jboss.netty.util.ObjectSizeEstimator oneway - -
136  * @apiviz.has org.jboss.netty.handler.execution.ChannelEventRunnable oneway - - executes
137  */
138 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
139 
140     private static final InternalLogger logger =
141         InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
142 
143     private static final SharedResourceMisuseDetector misuseDetector =
144         new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
145 
146     private volatile Settings settings;
147 
148     private final ConcurrentMap<Channel, AtomicLong> channelCounters =
149         new ConcurrentIdentityHashMap<Channel, AtomicLong>();
150     private final Limiter totalLimiter;
151 
152     /**
153      * Creates a new instance.
154      *
155      * @param corePoolSize          the maximum number of active threads
156      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
157      *                              Specify {@code 0} to disable.
158      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
159      *                              Specify {@code 0} to disable.
160      */
161     public MemoryAwareThreadPoolExecutor(
162             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
163 
164         this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
165     }
166 
167     /**
168      * Creates a new instance.
169      *
170      * @param corePoolSize          the maximum number of active threads
171      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
172      *                              Specify {@code 0} to disable.
173      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
174      *                              Specify {@code 0} to disable.
175      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
176      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
177      */
178     public MemoryAwareThreadPoolExecutor(
179             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
180             long keepAliveTime, TimeUnit unit) {
181 
182         this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit, Executors.defaultThreadFactory());
183     }
184 
185     /**
186      * Creates a new instance.
187      *
188      * @param corePoolSize          the maximum number of active threads
189      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
190      *                              Specify {@code 0} to disable.
191      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
192      *                              Specify {@code 0} to disable.
193      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
194      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
195      * @param threadFactory         the {@link ThreadFactory} of this pool
196      */
197     public MemoryAwareThreadPoolExecutor(
198             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
199             long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
200 
201         this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit, new DefaultObjectSizeEstimator(), threadFactory);
202     }
203 
204     /**
205      * Creates a new instance.
206      *
207      * @param corePoolSize          the maximum number of active threads
208      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
209      *                              Specify {@code 0} to disable.
210      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
211      *                              Specify {@code 0} to disable.
212      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
213      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
214      * @param threadFactory         the {@link ThreadFactory} of this pool
215      * @param objectSizeEstimator   the {@link ObjectSizeEstimator} of this pool
216      */
217     public MemoryAwareThreadPoolExecutor(
218             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
219             long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
220             ThreadFactory threadFactory) {
221 
222         super(corePoolSize, corePoolSize, keepAliveTime, unit,
223               new LinkedTransferQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
224 
225         if (objectSizeEstimator == null) {
226             throw new NullPointerException("objectSizeEstimator");
227         }
228         if (maxChannelMemorySize < 0) {
229             throw new IllegalArgumentException(
230                     "maxChannelMemorySize: " + maxChannelMemorySize);
231         }
232         if (maxTotalMemorySize < 0) {
233             throw new IllegalArgumentException(
234                     "maxTotalMemorySize: " + maxTotalMemorySize);
235         }
236 
237         // Call allowCoreThreadTimeOut(true) using reflection
238         // because it is not supported in Java 5.
239         try {
240             Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
241             m.invoke(this, Boolean.TRUE);
242         } catch (Throwable t) {
243             // Java 5
244             logger.debug(
245                     "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
246                     "supported in this platform.");
247         }
248 
249         settings = new Settings(
250                 objectSizeEstimator, maxChannelMemorySize);
251 
252         if (maxTotalMemorySize == 0) {
253             totalLimiter = null;
254         } else {
255             totalLimiter = new Limiter(maxTotalMemorySize);
256         }
257 
258         // Misuse check
259         misuseDetector.increase();
260     }
261 
262     @Override
263     protected void terminated() {
264         super.terminated();
265         misuseDetector.decrease();
266     }
267 
268     /**
269      * Returns the {@link ObjectSizeEstimator} of this pool.
270      */
271     public ObjectSizeEstimator getObjectSizeEstimator() {
272         return settings.objectSizeEstimator;
273     }
274 
275     /**
276      * Sets the {@link ObjectSizeEstimator} of this pool.
277      */
278     public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
279         if (objectSizeEstimator == null) {
280             throw new NullPointerException("objectSizeEstimator");
281         }
282 
283         settings = new Settings(
284                 objectSizeEstimator,
285                 settings.maxChannelMemorySize);
286     }
287 
288     /**
289      * Returns the maximum total size of the queued events per channel.
290      */
291     public long getMaxChannelMemorySize() {
292         return settings.maxChannelMemorySize;
293     }
294 
295     /**
296      * Sets the maximum total size of the queued events per channel.
297      * Specify {@code 0} to disable.
298      */
299     public void setMaxChannelMemorySize(long maxChannelMemorySize) {
300         if (maxChannelMemorySize < 0) {
301             throw new IllegalArgumentException(
302                     "maxChannelMemorySize: " + maxChannelMemorySize);
303         }
304 
305         if (getTaskCount() > 0) {
306             throw new IllegalStateException(
307                     "can't be changed after a task is executed");
308         }
309 
310         settings = new Settings(
311                 settings.objectSizeEstimator,
312                 maxChannelMemorySize);
313     }
314 
315     /**
316      * Returns the maximum total size of the queued events for this pool.
317      */
318     public long getMaxTotalMemorySize() {
319         return totalLimiter.limit;
320     }
321 
322     /**
323      * @deprecated <tt>maxTotalMemorySize</tt> is not modifiable anymore.
324      */
325     @Deprecated
326     public void setMaxTotalMemorySize(long maxTotalMemorySize) {
327         if (maxTotalMemorySize < 0) {
328             throw new IllegalArgumentException(
329                     "maxTotalMemorySize: " + maxTotalMemorySize);
330         }
331 
332         if (getTaskCount() > 0) {
333             throw new IllegalStateException(
334                     "can't be changed after a task is executed");
335         }
336     }
337 
338     @Override
339     public void execute(Runnable command) {
340         if (!(command instanceof ChannelEventRunnable)) {
341             command = new MemoryAwareRunnable(command);
342         }
343 
344         increaseCounter(command);
345         doExecute(command);
346     }
347 
348     /**
349      * Put the actual execution logic here.  The default implementation simply
350      * calls {@link #doUnorderedExecute(Runnable)}.
351      */
352     protected void doExecute(Runnable task) {
353         doUnorderedExecute(task);
354     }
355 
356     /**
357      * Executes the specified task without maintaining the event order.
358      */
359     protected final void doUnorderedExecute(Runnable task) {
360         super.execute(task);
361     }
362 
363     @Override
364     public boolean remove(Runnable task) {
365         boolean removed = super.remove(task);
366         if (removed) {
367             decreaseCounter(task);
368         }
369         return removed;
370     }
371 
372     @Override
373     protected void beforeExecute(Thread t, Runnable r) {
374         super.beforeExecute(t, r);
375         decreaseCounter(r);
376     }
377 
378     protected void increaseCounter(Runnable task) {
379         if (!shouldCount(task)) {
380             return;
381         }
382 
383         Settings settings = this.settings;
384         long maxChannelMemorySize = settings.maxChannelMemorySize;
385 
386         int increment = settings.objectSizeEstimator.estimateSize(task);
387 
388         if (task instanceof ChannelEventRunnable) {
389             ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
390             eventTask.estimatedSize = increment;
391             Channel channel = eventTask.getEvent().getChannel();
392             long channelCounter = getChannelCounter(channel).addAndGet(increment);
393             //System.out.println("IC: " + channelCounter + ", " + increment);
394             if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
395                 if (channel.isReadable()) {
396                     //System.out.println("UNREADABLE");
397                     ChannelHandlerContext ctx = eventTask.getContext();
398                     if (ctx.getHandler() instanceof ExecutionHandler) {
399                         // readSuspended = true;
400                         ctx.setAttachment(Boolean.TRUE);
401                     }
402                     channel.setReadable(false);
403                 }
404             }
405         } else {
406             ((MemoryAwareRunnable) task).estimatedSize = increment;
407         }
408 
409         if (totalLimiter != null) {
410             totalLimiter.increase(increment);
411         }
412     }
413 
414     protected void decreaseCounter(Runnable task) {
415         if (!shouldCount(task)) {
416             return;
417         }
418 
419         Settings settings = this.settings;
420         long maxChannelMemorySize = settings.maxChannelMemorySize;
421 
422         int increment;
423         if (task instanceof ChannelEventRunnable) {
424             increment = ((ChannelEventRunnable) task).estimatedSize;
425         } else {
426             increment = ((MemoryAwareRunnable) task).estimatedSize;
427         }
428 
429         if (totalLimiter != null) {
430             totalLimiter.decrease(increment);
431         }
432 
433         if (task instanceof ChannelEventRunnable) {
434             ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
435             Channel channel = eventTask.getEvent().getChannel();
436             long channelCounter = getChannelCounter(channel).addAndGet(-increment);
437             //System.out.println("DC: " + channelCounter + ", " + increment);
438             if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
439                 if (!channel.isReadable()) {
440                     //System.out.println("READABLE");
441                     ChannelHandlerContext ctx = eventTask.getContext();
442                     if (ctx.getHandler() instanceof ExecutionHandler) {
443                         // readSuspended = false;
444                         ctx.setAttachment(null);
445                     }
446                     channel.setReadable(true);
447                 }
448             }
449         }
450     }
451 
452     private AtomicLong getChannelCounter(Channel channel) {
453         AtomicLong counter = channelCounters.get(channel);
454         if (counter == null) {
455             counter = new AtomicLong();
456             AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
457             if (oldCounter != null) {
458                 counter = oldCounter;
459             }
460         }
461 
462         // Remove the entry when the channel closes.
463         if (!channel.isOpen()) {
464             channelCounters.remove(channel);
465         }
466         return counter;
467     }
468 
469     /**
470      * Returns {@code true} if and only if the specified {@code task} should
471      * be counted to limit the global and per-channel memory consumption.
472      * To override this method, you must call {@code super.shouldCount()} to
473      * make sure important tasks are not counted.
474      */
475     protected boolean shouldCount(Runnable task) {
476         if (task instanceof ChannelEventRunnable) {
477             ChannelEventRunnable r = (ChannelEventRunnable) task;
478             ChannelEvent e = r.getEvent();
479             if (e instanceof WriteCompletionEvent) {
480                 return false;
481             } else if (e instanceof ChannelStateEvent) {
482                 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
483                     return false;
484                 }
485             }
486         }
487         return true;
488     }
489 
490     private static final class Settings {
491         final ObjectSizeEstimator objectSizeEstimator;
492         final long maxChannelMemorySize;
493 
494         Settings(ObjectSizeEstimator objectSizeEstimator,
495                  long maxChannelMemorySize) {
496             this.objectSizeEstimator = objectSizeEstimator;
497             this.maxChannelMemorySize = maxChannelMemorySize;
498         }
499     }
500 
501     private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
502         NewThreadRunsPolicy() {
503             super();
504         }
505 
506         public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
507             try {
508                 final Thread t = new Thread(r, "Temporary task executor");
509                 t.start();
510             } catch (Throwable e) {
511                 throw new RejectedExecutionException(
512                         "Failed to start a new thread", e);
513             }
514         }
515     }
516 
517     private static final class MemoryAwareRunnable implements Runnable {
518         final Runnable task;
519         int estimatedSize;
520 
521         MemoryAwareRunnable(Runnable task) {
522             this.task = task;
523         }
524 
525         public void run() {
526             task.run();
527         }
528     }
529 
530 
531     private static class Limiter {
532 
533         final long limit;
534         private long counter;
535         private int waiters;
536 
537         Limiter(long limit) {
538             super();
539             this.limit = limit;
540         }
541 
542         synchronized void increase(long amount) {
543             while (counter >= limit) {
544                 waiters ++;
545                 try {
546                     wait();
547                 } catch (InterruptedException e) {
548                     // Ignore
549                 } finally {
550                     waiters --;
551                 }
552             }
553             counter += amount;
554         }
555 
556         synchronized void decrease(long amount) {
557             counter -= amount;
558             if (counter < limit && waiters > 0) {
559                 notifyAll();
560             }
561         }
562     }
563 }