1 /*
2  * Copyright (c) 2012, 2013, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 package org.openjdk.tests.java.lang.invoke;
24 
25 import java.io.ByteArrayInputStream;
26 import java.io.ByteArrayOutputStream;
27 import java.io.IOException;
28 import java.io.NotSerializableException;
29 import java.io.ObjectInputStream;
30 import java.io.ObjectOutputStream;
31 import java.io.Serializable;
32 import java.lang.invoke.CallSite;
33 import java.lang.invoke.LambdaMetafactory;
34 import java.lang.invoke.MethodHandle;
35 import java.lang.invoke.MethodHandles;
36 import java.lang.invoke.MethodType;
37 import java.util.ArrayList;
38 import java.util.List;
39 import java.util.concurrent.atomic.AtomicLong;
40 import java.util.function.BiPredicate;
41 import java.util.function.Consumer;
42 import java.util.function.LongConsumer;
43 import java.util.function.Predicate;
44 import java.util.function.Supplier;
45 
46 import org.testng.annotations.Test;
47 
48 import static org.testng.Assert.assertFalse;
49 import static org.testng.Assert.assertTrue;
50 import static org.testng.Assert.fail;
51 
52 /**
53  * SerializedLambdaTest
54  *
55  * @author Brian Goetz
56  */
57 @Test
58 public class SerializedLambdaTest {
59     public static final int REPS = 50;
60 
61     @SuppressWarnings("unchecked")
assertSerial(T p, Consumer<T> asserter)62     private<T> void assertSerial(T p, Consumer<T> asserter) throws IOException, ClassNotFoundException {
63         asserter.accept(p);
64 
65         for (int i=0; i<REPS; i++) {
66             byte[] bytes = serialize(p);
67             assertTrue(bytes.length > 0);
68 
69             asserter.accept((T) deserialize(bytes));
70         }
71     }
72 
assertNotSerial(Predicate<String> p, Consumer<Predicate<String>> asserter)73     private void assertNotSerial(Predicate<String> p, Consumer<Predicate<String>> asserter)
74             throws IOException, ClassNotFoundException {
75         asserter.accept(p);
76         try {
77             byte[] bytes = serialize(p);
78             fail("Expected serialization failure");
79         }
80         catch (NotSerializableException e) {
81             // success
82         }
83     }
84 
serialize(Object o)85     private byte[] serialize(Object o) throws IOException {
86         ByteArrayOutputStream bos = new ByteArrayOutputStream();
87         ObjectOutputStream oos = new ObjectOutputStream(bos);
88         oos.writeObject(o);
89         oos.close();
90         return bos.toByteArray();
91     }
92 
deserialize(byte[] bytes)93     private Object deserialize(byte[] bytes) throws IOException, ClassNotFoundException {
94         try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
95             return ois.readObject();
96         }
97     }
98 
99     // Test instantiating against intersection type
testSimpleSerializedInstantiation()100     public void testSimpleSerializedInstantiation() throws IOException, ClassNotFoundException {
101         @SuppressWarnings("unchecked")
102         Predicate<String> pred = (Predicate<String> & Serializable) s -> true;
103         assertSerial(pred,
104                      p -> {
105                          assertTrue(p instanceof Predicate);
106                          assertTrue(p instanceof Serializable);
107                          assertTrue(p.test(""));
108                      });
109     }
110 
111     interface SerPredicate<T> extends Predicate<T>, Serializable { }
112 
113     // Test instantiating against derived type
testSimpleSerializedInstantiation2()114     public void testSimpleSerializedInstantiation2() throws IOException, ClassNotFoundException {
115         SerPredicate<String> serPred = (SerPredicate<String>) s -> true;
116         assertSerial(serPred,
117                      p -> {
118                          assertTrue(p instanceof Predicate);
119                          assertTrue(p instanceof Serializable);
120                          assertTrue(p instanceof SerPredicate);
121                          assertTrue(p.test(""));
122                      });
123     }
124 
125     // Negative test: non-serializable lambdas are in fact not serializable
testNonserializableInstantiation()126     public void testNonserializableInstantiation() throws IOException, ClassNotFoundException {
127         @SuppressWarnings("unchecked")
128         Predicate<String> pred = (Predicate<String>) s -> true;
129         assertNotSerial(pred,
130                         p -> {
131                             assertTrue(p instanceof Predicate);
132                             assertFalse(p instanceof Serializable);
133                             assertTrue(p.test(""));
134                         });
135     }
136 
137     // Test lambda capturing int
testSerializeCapturingInt()138     public void testSerializeCapturingInt() throws IOException, ClassNotFoundException {
139         class Moo {
140             @SuppressWarnings("unchecked")
141             Predicate<String> foo(int x) {
142                 return (Predicate<String> & Serializable) s -> s.length() >= x;
143             }
144         }
145         Predicate<String> pred = new Moo().foo(3);
146         assertSerial(pred, p -> {
147             assertTrue(p.test("yada"));
148             assertFalse(p.test("no"));
149         });
150     }
151 
152     // Test lambda capturing String
testSerializeCapturingString()153     public void testSerializeCapturingString() throws IOException, ClassNotFoundException {
154         class Moo {
155             @SuppressWarnings("unchecked")
156             Predicate<String> foo(String t) {
157                 return (Predicate<String> & Serializable) s -> s.equals(t);
158             }
159         }
160         Predicate<String> pred = new Moo().foo("goo");
161         assertSerial(pred, p -> {
162             assertTrue(p.test("goo"));
163             assertFalse(p.test("foo"));
164         });
165     }
166 
167     // Negative test: lambdas that capture a non-serializable var
testSerializeCapturingNonSerializable()168     public void testSerializeCapturingNonSerializable() throws IOException, ClassNotFoundException {
169         class Box {
170             String s;
171 
172             Box(String s) { this.s = s; }
173         }
174         class Moo {
175             @SuppressWarnings("unchecked")
176             Predicate<String> foo(Box b) {
177                 return (Predicate<String> & Serializable) s -> s.equals(b.s);
178             }
179         }
180         Predicate<String> pred = new Moo().foo(new Box("goo"));
181         assertNotSerial(pred, p -> {
182             assertTrue(p.test("goo"));
183             assertFalse(p.test("foo"));
184         });
185     }
186 
startsWithA(String s)187     static boolean startsWithA(String s) {
188         return s.startsWith("a");
189     }
190 
191     // Test static method ref
testStaticMR()192     public void testStaticMR() throws IOException, ClassNotFoundException {
193         @SuppressWarnings("unchecked")
194         Predicate<String> mh1 = (Predicate<String> & Serializable) SerializedLambdaTest::startsWithA;
195         @SuppressWarnings("unchecked")
196         Predicate<String> mh2 = (SerPredicate<String>) SerializedLambdaTest::startsWithA;
197         Consumer<Predicate<String>> b = p -> {
198             assertTrue(p instanceof Serializable);
199             assertTrue(p.test("arf"));
200             assertFalse(p.test("barf"));
201         };
202         assertSerial(mh1, b);
203         assertSerial(mh2, b);
204     }
205 
206     // Test unbound method ref of nonserializable class -- should still succeed
testUnboundMR()207     public void testUnboundMR() throws IOException, ClassNotFoundException {
208         class Moo {
209             public boolean startsWithB(String s) {
210                 return s.startsWith("b");
211             }
212         }
213         @SuppressWarnings("unchecked")
214         BiPredicate<Moo, String> mh1 = (BiPredicate<Moo, String> & Serializable) Moo::startsWithB;
215         Consumer<BiPredicate<Moo, String>> b = p -> {
216             assertTrue(p instanceof Serializable);
217             assertTrue(p.test(new Moo(), "barf"));
218             assertFalse(p.test(new Moo(), "arf"));
219         };
220         assertSerial(mh1, b);
221     }
222 
223     // Negative test: test bound MR of nonserializable class
testBoundMRNotSerReceiver()224     public void testBoundMRNotSerReceiver() throws IOException, ClassNotFoundException {
225         class Moo {
226             public boolean startsWithB(String s) {
227                 return s.startsWith("b");
228             }
229         }
230         Moo moo = new Moo();
231         @SuppressWarnings("unchecked")
232         Predicate<String> mh1 = (Predicate<String> & Serializable) moo::startsWithB;
233         @SuppressWarnings("unchecked")
234         Predicate<String> mh2 = (SerPredicate<String>) moo::startsWithB;
235         Consumer<Predicate<String>> b = p -> {
236             assertTrue(p instanceof Serializable);
237             assertTrue(p.test("barf"));
238             assertFalse(p.test("arf"));
239         };
240         assertNotSerial(mh1, b);
241         assertNotSerial(mh2, b);
242     }
243 
244     // Test bound MR of serializable class
245     @SuppressWarnings("serial")
246     static class ForBoundMRef implements Serializable {
startsWithB(String s)247         public boolean startsWithB(String s) {
248             return s.startsWith("b");
249         }
250     }
251 
testBoundMR()252     public void testBoundMR() throws IOException, ClassNotFoundException {
253         ForBoundMRef moo = new ForBoundMRef();
254         @SuppressWarnings("unchecked")
255         Predicate<String> mh1 = (Predicate<String> & Serializable) moo::startsWithB;
256         @SuppressWarnings("unchecked")
257         Predicate<String> mh2 = (SerPredicate<String>) moo::startsWithB;
258         Consumer<Predicate<String>> b = p -> {
259             assertTrue(p instanceof Serializable);
260             assertTrue(p.test("barf"));
261             assertFalse(p.test("arf"));
262         };
263         assertSerial(mh1, b);
264         assertSerial(mh2, b);
265     }
266 
267     static class ForCtorRef {
startsWithB(String s)268         public boolean startsWithB(String s) {
269             return s.startsWith("b");
270         }
271     }
272     // Test ctor ref of nonserializable class
testCtorRef()273     public void testCtorRef() throws IOException, ClassNotFoundException {
274         @SuppressWarnings("unchecked")
275         Supplier<ForCtorRef> ctor = (Supplier<ForCtorRef> & Serializable) ForCtorRef::new;
276         Consumer<Supplier<ForCtorRef>> b = s -> {
277             assertTrue(s instanceof Serializable);
278             ForCtorRef m = s.get();
279             assertTrue(m.startsWithB("barf"));
280             assertFalse(m.startsWithB("arf"));
281         };
282         assertSerial(ctor, b);
283     }
284 
285     //Test throwing away return type
testDiscardReturnBound()286     public void testDiscardReturnBound() throws IOException, ClassNotFoundException {
287         List<String> list = new ArrayList<>();
288         Consumer<String> c = (Consumer<String> & Serializable) list::add;
289         assertSerial(c, cc -> { assertTrue(cc instanceof Consumer); });
290 
291         AtomicLong a = new AtomicLong();
292         LongConsumer lc = (LongConsumer & Serializable) a::addAndGet;
293         assertSerial(lc, plc -> { plc.accept(3); });
294     }
295 
296     // Tests of direct use of metafactories
297 
foo(Object s)298     private static boolean foo(Object s) { return s != null && ((String) s).length() > 0; }
299     private static final MethodType predicateMT = MethodType.methodType(boolean.class, Object.class);
300     private static final MethodType stringPredicateMT = MethodType.methodType(boolean.class, String.class);
301     private static final Consumer<Predicate<String>> fooAsserter = x -> {
302         assertTrue(x.test("foo"));
303         assertFalse(x.test(""));
304         assertFalse(x.test(null));
305     };
306 
307     // standard MF: nonserializable supertype
testDirectStdNonser()308     public void testDirectStdNonser() throws Throwable {
309         MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
310 
311         // Standard metafactory, non-serializable target: not serializable
312         CallSite cs = LambdaMetafactory.metafactory(MethodHandles.lookup(),
313                                                     "test", MethodType.methodType(Predicate.class),
314                                                     predicateMT, fooMH, stringPredicateMT);
315         Predicate<String> p = (Predicate<String>) cs.getTarget().invokeExact();
316         assertNotSerial(p, fooAsserter);
317     }
318 
319     // standard MF: serializable supertype
testDirectStdSer()320     public void testDirectStdSer() throws Throwable {
321         MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
322 
323         // Standard metafactory, serializable target: not serializable
324         CallSite cs = LambdaMetafactory.metafactory(MethodHandles.lookup(),
325                                                     "test", MethodType.methodType(SerPredicate.class),
326                                                     predicateMT, fooMH, stringPredicateMT);
327         assertNotSerial((SerPredicate<String>) cs.getTarget().invokeExact(), fooAsserter);
328     }
329 
330     // alt MF: nonserializable supertype
testAltStdNonser()331     public void testAltStdNonser() throws Throwable {
332         MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
333 
334         // Alt metafactory, non-serializable target: not serializable
335         CallSite cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
336                                                        "test", MethodType.methodType(Predicate.class),
337                                                        predicateMT, fooMH, stringPredicateMT, 0);
338         assertNotSerial((Predicate<String>) cs.getTarget().invokeExact(), fooAsserter);
339     }
340 
341     // alt MF: serializable supertype
testAltStdSer()342     public void testAltStdSer() throws Throwable {
343         MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
344 
345         // Alt metafactory, serializable target, no FLAG_SERIALIZABLE: not serializable
346         CallSite cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
347                                                        "test", MethodType.methodType(SerPredicate.class),
348                                                        predicateMT, fooMH, stringPredicateMT, 0);
349         assertNotSerial((SerPredicate<String>) cs.getTarget().invokeExact(), fooAsserter);
350 
351         // Alt metafactory, serializable marker, no FLAG_SERIALIZABLE: not serializable
352         cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
353                                               "test", MethodType.methodType(Predicate.class),
354                                               predicateMT, fooMH, stringPredicateMT, LambdaMetafactory.FLAG_MARKERS, 1, Serializable.class);
355         assertNotSerial((Predicate<String>) cs.getTarget().invokeExact(), fooAsserter);
356     }
357 }
358