1 /*
2  * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @bug 8246774
27  * @summary Tests for stream references
28  * @run testng StreamRefTest
29  */
30 package test.java.io.Serializable.records;
31 
32 import java.io.ByteArrayInputStream;
33 import java.io.ByteArrayOutputStream;
34 import java.io.DataInputStream;
35 import java.io.IOException;
36 import java.io.InvalidObjectException;
37 import java.io.ObjectInputStream;
38 import java.io.ObjectOutputStream;
39 import java.io.Serializable;
40 
41 import org.testng.annotations.Test;
42 import static java.lang.System.out;
43 import static org.testng.Assert.assertEquals;
44 import static org.testng.Assert.assertTrue;
45 import static org.testng.Assert.expectThrows;
46 
47 /**
48  * Tests for stream references.
49  */
50 public class StreamRefTest {
51 
52     record A (int x) implements Serializable {
A(int x)53         public A(int x) {
54             if (x < 0)
55                 throw new IllegalArgumentException("negative value for x:" + x);
56             this.x = x;
57         }
58     }
59 
60     static class B implements Serializable {
61         final A a ;
B(A a)62         B(A a) { this.a = a; }
63     }
64 
65     record C (B b) implements Serializable {
C(B b)66         public C(B b) { this.b = b; }
67     }
68 
69     static class D implements Serializable {
70         final C c ;
D(C c)71         D(C c) { this.c = c; }
72     }
73 
74     @Test
basicRef()75     public void basicRef() throws Exception {
76         out.println("\n---");
77         var a = new A(6);
78         var b = new B(a);
79         var c = new C(b);
80         var d = new D(c);
81 
82         var bytes = serialize(a, b, c, d);
83 
84         A a1 = (A)deserializeOne(bytes);
85         B b1 = (B)deserializeOne(bytes);
86         C c1 = (C)deserializeOne(bytes);
87         D d1 = (D)deserializeOne(bytes);
88 
89         assertTrue(a1.x == a.x);
90         assertTrue(a1 == b1.a);
91         assertTrue(b1 == c1.b);
92         assertTrue(c1 == d1.c);
93     }
94 
95     @Test
reverseBasicRef()96     public void reverseBasicRef() throws Exception {
97         out.println("\n---");
98         var a = new A(7);
99         var b = new B(a);
100         var c = new C(b);
101         var d = new D(c);
102 
103         var bytes = serialize(d, c, b, a);
104 
105         D d1 = (D)deserializeOne(bytes);
106         C c1 = (C)deserializeOne(bytes);
107         B b1 = (B)deserializeOne(bytes);
108         A a1 = (A)deserializeOne(bytes);
109 
110         assertTrue(a1 == b1.a);
111         assertTrue(b1 == c1.b);
112         assertTrue(c1 == d1.c);
113     }
114 
115     static final Class<InvalidObjectException> IOE = InvalidObjectException.class;
116 
117     @Test
basicRefWithInvalidA()118     public void basicRefWithInvalidA() throws Exception {
119         out.println("\n---");
120         var a = new A(3);
121         var b = new B(a);
122 
123         var bytes = serializeToBytes(a, b);
124         // injects a bad (negative) value for field x (of record A), in the stream
125         // Android-changed: package name of this class adds additional characters to the bytes.
126         // updateIntValue(3, -3, bytes, 40);
127         updateIntValue(3, -3, bytes, 74);
128         var byteStream = new ObjectInputStream(new ByteArrayInputStream(bytes));
129 
130         InvalidObjectException ioe = expectThrows(IOE, () -> deserializeOne(byteStream));
131         out.println("caught expected IOE: " + ioe);
132         Throwable t = ioe.getCause();
133         assertTrue(t instanceof IllegalArgumentException, "Expected IAE, got:" + t);
134         out.println("expected cause IAE: " + t);
135 
136         B b1 = (B)deserializeOne(byteStream);
137         assertEquals(b1.a, null);
138     }
139 
140     @Test
reverseBasicRefWithInvalidA()141     public void reverseBasicRefWithInvalidA() throws Exception {
142         out.println("\n---");
143         var a = new A(3);
144         var b = new B(a);
145 
146         var bytes = serializeToBytes(b, a);
147         // injects a bad (negative) value for field x (of record A), in the stream
148         // Android-changed: package name of this class adds additional characters to the bytes.
149         // updateIntValue(3, -3, bytes, 96);
150         updateIntValue(3, -3, bytes, 198);
151         var byteStream = new ObjectInputStream(new ByteArrayInputStream(bytes));
152 
153         InvalidObjectException ioe = expectThrows(IOE, () -> deserializeOne(byteStream));
154         out.println("caught expected IOE: " + ioe);
155         Throwable t = ioe.getCause();
156         assertTrue(t instanceof IllegalArgumentException, "Expected IAE, got:" + t);
157         out.println("expected cause IAE: " + t);
158 
159         A a1 = (A)deserializeOne(byteStream);
160         assertEquals(a1, null);
161     }
162 
163     // ---
164 
165 //    static class Y implements Serializable {
166 //        final int i = 10;
167 //        private void readObject(ObjectInputStream in)
168 //            throws IOException, ClassNotFoundException
169 //        {
170 //            in.defaultReadObject();
171 //            throw new IllegalArgumentException("dunno");
172 //        }
173 //    }
174 //
175 //    static class Z implements Serializable {
176 //        final Y y ;
177 //        Z(Y y) { this.y = y; }
178 //    }
179 //
180 //    static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
181 //
182 //    @Test
183 //    public void whatDoesPlainDeserializationDo() throws Exception {
184 //        out.println("\n---");
185 //        var y = new Y();
186 //        var z = new Z(y);
187 //
188 //        var byteStream = serialize(z, y);
189 //
190 //        IllegalArgumentException iae = expectThrows(IAE, () -> deserializeOne(byteStream));
191 //        out.println("caught expected IAE: " + iae);
192 //        iae.printStackTrace();
193 //
194 //        Y y1 = (Y)deserializeOne(byteStream);
195 //        assertEquals(y1.i, 0);
196 //    }
197 //
198 //    @Test
199 //    public void reverseWhatDoesPlainDeserializationDo() throws Exception {
200 //        out.println("\n---");
201 //        var y = new Y();
202 //        var z = new Z(y);
203 //
204 //        var byteStream = serialize(y, z);
205 //
206 //        IllegalArgumentException iae = expectThrows(IAE, () -> deserializeOne(byteStream));
207 //        out.println("caught expected IAE: " + iae);
208 //        //iae.printStackTrace();
209 //
210 //        Z z1 = (Z)deserializeOne(byteStream);
211 //        assertEquals(z1.y, null);
212 //    }
213 
214     // ---
215 
assertExpectedIntValue(int expectedValue, byte[] bytes, int offset)216     static void assertExpectedIntValue(int expectedValue, byte[] bytes, int offset)
217         throws IOException {
218         ByteArrayInputStream bais = new ByteArrayInputStream(bytes, offset, 4);
219         DataInputStream dis = new DataInputStream(bais);
220         assertEquals(dis.readInt(), expectedValue);
221     }
222 
updateIntValue(int expectedValue, int newValue, byte[] bytes, int offset)223     static void updateIntValue(int expectedValue, int newValue, byte[] bytes, int offset)
224         throws IOException
225     {
226         assertExpectedIntValue(expectedValue, bytes, offset);
227         bytes[offset + 0] = (byte)((newValue >>> 24) & 0xFF);
228         bytes[offset + 1] = (byte)((newValue >>> 16) & 0xFF);
229         bytes[offset + 2] = (byte)((newValue >>>  8) & 0xFF);
230         bytes[offset + 3] = (byte)((newValue >>>  0) & 0xFF);
231         assertExpectedIntValue(newValue, bytes, offset);
232     }
233 
serialize(Object... objs)234     static ObjectInputStream serialize(Object... objs) throws IOException {
235         return new ObjectInputStream(new ByteArrayInputStream(serializeToBytes(objs)));
236     }
237 
serializeToBytes(Object... objs)238     static byte[] serializeToBytes(Object... objs) throws IOException {
239         ByteArrayOutputStream baos = new ByteArrayOutputStream();
240         ObjectOutputStream oos = new ObjectOutputStream(baos);
241         for (Object obj : objs)
242             oos.writeObject(obj);
243         oos.close();
244         return baos.toByteArray();
245     }
246 
247     @SuppressWarnings("unchecked")
deserializeOne(ObjectInputStream ois)248     static Object deserializeOne(ObjectInputStream ois)
249         throws IOException, ClassNotFoundException
250     {
251         return ois.readObject();
252     }
253 }
254