1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.handler.execution;
17
18 import java.lang.reflect.Method;
19 import java.util.concurrent.ConcurrentMap;
20 import java.util.concurrent.Executor;
21 import java.util.concurrent.Executors;
22 import java.util.concurrent.RejectedExecutionException;
23 import java.util.concurrent.RejectedExecutionHandler;
24 import java.util.concurrent.ThreadFactory;
25 import java.util.concurrent.ThreadPoolExecutor;
26 import java.util.concurrent.TimeUnit;
27 import java.util.concurrent.atomic.AtomicLong;
28
29 import org.jboss.netty.buffer.ChannelBuffer;
30 import org.jboss.netty.channel.Channel;
31 import org.jboss.netty.channel.ChannelEvent;
32 import org.jboss.netty.channel.ChannelHandlerContext;
33 import org.jboss.netty.channel.ChannelState;
34 import org.jboss.netty.channel.ChannelStateEvent;
35 import org.jboss.netty.channel.MessageEvent;
36 import org.jboss.netty.channel.WriteCompletionEvent;
37 import org.jboss.netty.logging.InternalLogger;
38 import org.jboss.netty.logging.InternalLoggerFactory;
39 import org.jboss.netty.util.DefaultObjectSizeEstimator;
40 import org.jboss.netty.util.ObjectSizeEstimator;
41 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
42 import org.jboss.netty.util.internal.LinkedTransferQueue;
43 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
139
140 private static final InternalLogger logger =
141 InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
142
143 private static final SharedResourceMisuseDetector misuseDetector =
144 new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
145
146 private volatile Settings settings;
147
148 private final ConcurrentMap<Channel, AtomicLong> channelCounters =
149 new ConcurrentIdentityHashMap<Channel, AtomicLong>();
150 private final Limiter totalLimiter;
151
152
153
154
155
156
157
158
159
160
161 public MemoryAwareThreadPoolExecutor(
162 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
163
164 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
165 }
166
167
168
169
170
171
172
173
174
175
176
177
178 public MemoryAwareThreadPoolExecutor(
179 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
180 long keepAliveTime, TimeUnit unit) {
181
182 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit, Executors.defaultThreadFactory());
183 }
184
185
186
187
188
189
190
191
192
193
194
195
196
197 public MemoryAwareThreadPoolExecutor(
198 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
199 long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
200
201 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit, new DefaultObjectSizeEstimator(), threadFactory);
202 }
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217 public MemoryAwareThreadPoolExecutor(
218 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
219 long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
220 ThreadFactory threadFactory) {
221
222 super(corePoolSize, corePoolSize, keepAliveTime, unit,
223 new LinkedTransferQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
224
225 if (objectSizeEstimator == null) {
226 throw new NullPointerException("objectSizeEstimator");
227 }
228 if (maxChannelMemorySize < 0) {
229 throw new IllegalArgumentException(
230 "maxChannelMemorySize: " + maxChannelMemorySize);
231 }
232 if (maxTotalMemorySize < 0) {
233 throw new IllegalArgumentException(
234 "maxTotalMemorySize: " + maxTotalMemorySize);
235 }
236
237
238
239 try {
240 Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
241 m.invoke(this, Boolean.TRUE);
242 } catch (Throwable t) {
243
244 logger.debug(
245 "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
246 "supported in this platform.");
247 }
248
249 settings = new Settings(
250 objectSizeEstimator, maxChannelMemorySize);
251
252 if (maxTotalMemorySize == 0) {
253 totalLimiter = null;
254 } else {
255 totalLimiter = new Limiter(maxTotalMemorySize);
256 }
257
258
259 misuseDetector.increase();
260 }
261
262 @Override
263 protected void terminated() {
264 super.terminated();
265 misuseDetector.decrease();
266 }
267
268
269
270
271 public ObjectSizeEstimator getObjectSizeEstimator() {
272 return settings.objectSizeEstimator;
273 }
274
275
276
277
278 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
279 if (objectSizeEstimator == null) {
280 throw new NullPointerException("objectSizeEstimator");
281 }
282
283 settings = new Settings(
284 objectSizeEstimator,
285 settings.maxChannelMemorySize);
286 }
287
288
289
290
291 public long getMaxChannelMemorySize() {
292 return settings.maxChannelMemorySize;
293 }
294
295
296
297
298
299 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
300 if (maxChannelMemorySize < 0) {
301 throw new IllegalArgumentException(
302 "maxChannelMemorySize: " + maxChannelMemorySize);
303 }
304
305 if (getTaskCount() > 0) {
306 throw new IllegalStateException(
307 "can't be changed after a task is executed");
308 }
309
310 settings = new Settings(
311 settings.objectSizeEstimator,
312 maxChannelMemorySize);
313 }
314
315
316
317
318 public long getMaxTotalMemorySize() {
319 return totalLimiter.limit;
320 }
321
322
323
324
325 @Deprecated
326 public void setMaxTotalMemorySize(long maxTotalMemorySize) {
327 if (maxTotalMemorySize < 0) {
328 throw new IllegalArgumentException(
329 "maxTotalMemorySize: " + maxTotalMemorySize);
330 }
331
332 if (getTaskCount() > 0) {
333 throw new IllegalStateException(
334 "can't be changed after a task is executed");
335 }
336 }
337
338 @Override
339 public void execute(Runnable command) {
340 if (!(command instanceof ChannelEventRunnable)) {
341 command = new MemoryAwareRunnable(command);
342 }
343
344 increaseCounter(command);
345 doExecute(command);
346 }
347
348
349
350
351
352 protected void doExecute(Runnable task) {
353 doUnorderedExecute(task);
354 }
355
356
357
358
359 protected final void doUnorderedExecute(Runnable task) {
360 super.execute(task);
361 }
362
363 @Override
364 public boolean remove(Runnable task) {
365 boolean removed = super.remove(task);
366 if (removed) {
367 decreaseCounter(task);
368 }
369 return removed;
370 }
371
372 @Override
373 protected void beforeExecute(Thread t, Runnable r) {
374 super.beforeExecute(t, r);
375 decreaseCounter(r);
376 }
377
378 protected void increaseCounter(Runnable task) {
379 if (!shouldCount(task)) {
380 return;
381 }
382
383 Settings settings = this.settings;
384 long maxChannelMemorySize = settings.maxChannelMemorySize;
385
386 int increment = settings.objectSizeEstimator.estimateSize(task);
387
388 if (task instanceof ChannelEventRunnable) {
389 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
390 eventTask.estimatedSize = increment;
391 Channel channel = eventTask.getEvent().getChannel();
392 long channelCounter = getChannelCounter(channel).addAndGet(increment);
393
394 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
395 if (channel.isReadable()) {
396
397 ChannelHandlerContext ctx = eventTask.getContext();
398 if (ctx.getHandler() instanceof ExecutionHandler) {
399
400 ctx.setAttachment(Boolean.TRUE);
401 }
402 channel.setReadable(false);
403 }
404 }
405 } else {
406 ((MemoryAwareRunnable) task).estimatedSize = increment;
407 }
408
409 if (totalLimiter != null) {
410 totalLimiter.increase(increment);
411 }
412 }
413
414 protected void decreaseCounter(Runnable task) {
415 if (!shouldCount(task)) {
416 return;
417 }
418
419 Settings settings = this.settings;
420 long maxChannelMemorySize = settings.maxChannelMemorySize;
421
422 int increment;
423 if (task instanceof ChannelEventRunnable) {
424 increment = ((ChannelEventRunnable) task).estimatedSize;
425 } else {
426 increment = ((MemoryAwareRunnable) task).estimatedSize;
427 }
428
429 if (totalLimiter != null) {
430 totalLimiter.decrease(increment);
431 }
432
433 if (task instanceof ChannelEventRunnable) {
434 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
435 Channel channel = eventTask.getEvent().getChannel();
436 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
437
438 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
439 if (!channel.isReadable()) {
440
441 ChannelHandlerContext ctx = eventTask.getContext();
442 if (ctx.getHandler() instanceof ExecutionHandler) {
443
444 ctx.setAttachment(null);
445 }
446 channel.setReadable(true);
447 }
448 }
449 }
450 }
451
452 private AtomicLong getChannelCounter(Channel channel) {
453 AtomicLong counter = channelCounters.get(channel);
454 if (counter == null) {
455 counter = new AtomicLong();
456 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
457 if (oldCounter != null) {
458 counter = oldCounter;
459 }
460 }
461
462
463 if (!channel.isOpen()) {
464 channelCounters.remove(channel);
465 }
466 return counter;
467 }
468
469
470
471
472
473
474
475 protected boolean shouldCount(Runnable task) {
476 if (task instanceof ChannelEventRunnable) {
477 ChannelEventRunnable r = (ChannelEventRunnable) task;
478 ChannelEvent e = r.getEvent();
479 if (e instanceof WriteCompletionEvent) {
480 return false;
481 } else if (e instanceof ChannelStateEvent) {
482 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
483 return false;
484 }
485 }
486 }
487 return true;
488 }
489
490 private static final class Settings {
491 final ObjectSizeEstimator objectSizeEstimator;
492 final long maxChannelMemorySize;
493
494 Settings(ObjectSizeEstimator objectSizeEstimator,
495 long maxChannelMemorySize) {
496 this.objectSizeEstimator = objectSizeEstimator;
497 this.maxChannelMemorySize = maxChannelMemorySize;
498 }
499 }
500
501 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
502 NewThreadRunsPolicy() {
503 super();
504 }
505
506 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
507 try {
508 final Thread t = new Thread(r, "Temporary task executor");
509 t.start();
510 } catch (Throwable e) {
511 throw new RejectedExecutionException(
512 "Failed to start a new thread", e);
513 }
514 }
515 }
516
517 private static final class MemoryAwareRunnable implements Runnable {
518 final Runnable task;
519 int estimatedSize;
520
521 MemoryAwareRunnable(Runnable task) {
522 this.task = task;
523 }
524
525 public void run() {
526 task.run();
527 }
528 }
529
530
531 private static class Limiter {
532
533 final long limit;
534 private long counter;
535 private int waiters;
536
537 Limiter(long limit) {
538 super();
539 this.limit = limit;
540 }
541
542 synchronized void increase(long amount) {
543 while (counter >= limit) {
544 waiters ++;
545 try {
546 wait();
547 } catch (InterruptedException e) {
548
549 } finally {
550 waiters --;
551 }
552 }
553 counter += amount;
554 }
555
556 synchronized void decrease(long amount) {
557 counter -= amount;
558 if (counter < limit && waiters > 0) {
559 notifyAll();
560 }
561 }
562 }
563 }