1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.handler.codec.serialization;
17
18 import java.io.EOFException;
19 import java.io.IOException;
20 import java.io.InputStream;
21 import java.io.ObjectInputStream;
22 import java.io.ObjectStreamClass;
23 import java.io.StreamCorruptedException;
24 import java.util.HashMap;
25 import java.util.Map;
26
27
28
29
30
31
32
33
34 class CompactObjectInputStream extends ObjectInputStream {
35
36 private final Map<String, Class<?>> classCache = new HashMap<String, Class<?>>();
37 private final ClassLoader classLoader;
38
39 CompactObjectInputStream(InputStream in) throws IOException {
40 this(in, null);
41 }
42
43 CompactObjectInputStream(InputStream in, ClassLoader classLoader) throws IOException {
44 super(in);
45 this.classLoader = classLoader;
46 }
47
48 @Override
49 protected void readStreamHeader() throws IOException,
50 StreamCorruptedException {
51 int version = readByte() & 0xFF;
52 if (version != STREAM_VERSION) {
53 throw new StreamCorruptedException(
54 "Unsupported version: " + version);
55 }
56 }
57
58 @Override
59 protected ObjectStreamClass readClassDescriptor()
60 throws IOException, ClassNotFoundException {
61 int type = read();
62 if (type < 0) {
63 throw new EOFException();
64 }
65 switch (type) {
66 case CompactObjectOutputStream.TYPE_FAT_DESCRIPTOR:
67 return super.readClassDescriptor();
68 case CompactObjectOutputStream.TYPE_THIN_DESCRIPTOR:
69 String className = readUTF();
70 Class<?> clazz = loadClass(className);
71 return ObjectStreamClass.lookupAny(clazz);
72 default:
73 throw new StreamCorruptedException(
74 "Unexpected class descriptor type: " + type);
75 }
76 }
77
78 @Override
79 protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
80
81 String className = desc.getName();
82 Class<?> clazz = classCache.get(className);
83 if (clazz != null) {
84 return clazz;
85 }
86
87
88 try {
89 clazz = loadClass(className);
90 } catch (ClassNotFoundException ex) {
91 clazz = super.resolveClass(desc);
92 }
93
94 classCache.put(className, clazz);
95 return clazz;
96 }
97
98 protected Class<?> loadClass(String className) throws ClassNotFoundException {
99
100 Class<?> clazz;
101 clazz = classCache.get(className);
102 if (clazz != null) {
103 return clazz;
104 }
105
106
107 ClassLoader classLoader = this.classLoader;
108 if (classLoader == null) {
109 classLoader = Thread.currentThread().getContextClassLoader();
110 }
111
112 if (classLoader != null) {
113 clazz = classLoader.loadClass(className);
114 } else {
115 clazz = Class.forName(className);
116 }
117
118 classCache.put(className, clazz);
119 return clazz;
120 }
121 }