1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.internal.util.function.pooled;
18 
19 import android.annotation.Nullable;
20 import android.os.Message;
21 import android.text.TextUtils;
22 import android.util.Log;
23 import android.util.Pools;
24 
25 import com.android.internal.util.ArrayUtils;
26 import com.android.internal.util.BitUtils;
27 import com.android.internal.util.function.HexConsumer;
28 import com.android.internal.util.function.HexFunction;
29 import com.android.internal.util.function.HexPredicate;
30 import com.android.internal.util.function.QuadConsumer;
31 import com.android.internal.util.function.QuadFunction;
32 import com.android.internal.util.function.QuadPredicate;
33 import com.android.internal.util.function.QuintConsumer;
34 import com.android.internal.util.function.QuintFunction;
35 import com.android.internal.util.function.QuintPredicate;
36 import com.android.internal.util.function.TriConsumer;
37 import com.android.internal.util.function.TriFunction;
38 import com.android.internal.util.function.TriPredicate;
39 
40 import java.util.Arrays;
41 import java.util.function.BiConsumer;
42 import java.util.function.BiFunction;
43 import java.util.function.BiPredicate;
44 import java.util.function.Consumer;
45 import java.util.function.Function;
46 import java.util.function.Predicate;
47 import java.util.function.Supplier;
48 
49 /**
50  * @see PooledLambda
51  * @hide
52  */
53 final class PooledLambdaImpl<R> extends OmniFunction<Object,
54         Object, Object, Object, Object, Object, R> {
55 
56     private static final boolean DEBUG = false;
57     private static final String LOG_TAG = "PooledLambdaImpl";
58 
59     private static final int MAX_ARGS = 5;
60 
61     private static final int MAX_POOL_SIZE = 50;
62 
63     static class Pool extends Pools.SynchronizedPool<PooledLambdaImpl> {
64 
Pool(Object lock)65         public Pool(Object lock) {
66             super(MAX_POOL_SIZE, lock);
67         }
68     }
69 
70     static final Pool sPool = new Pool(new Object());
71     static final Pool sMessageCallbacksPool = new Pool(Message.sPoolSync);
72 
PooledLambdaImpl()73     private PooledLambdaImpl() {}
74 
75     /**
76      * The function reference to be invoked
77      *
78      * May be the return value itself in case when an immediate result constant is provided instead
79      */
80     Object mFunc;
81 
82     /**
83      * A primitive result value to be immediately returned on invocation instead of calling
84      * {@link #mFunc}
85      */
86     long mConstValue;
87 
88     /**
89      * Arguments for {@link #mFunc}
90      */
91     @Nullable Object[] mArgs = null;
92 
93     /**
94      * Flag for {@link #mFlags}
95      *
96      * Indicates whether this instance is recycled
97      */
98     private static final int FLAG_RECYCLED = 1 << MAX_ARGS;
99 
100     /**
101      * Flag for {@link #mFlags}
102      *
103      * Indicates whether this instance should be immediately recycled on invocation
104      * (as requested via {@link PooledLambda#recycleOnUse()}) or not(default)
105      */
106     private static final int FLAG_RECYCLE_ON_USE = 1 << (MAX_ARGS + 1);
107 
108     /**
109      * Flag for {@link #mFlags}
110      *
111      * Indicates that this instance was acquired from {@link #sMessageCallbacksPool} as opposed to
112      * {@link #sPool}
113      */
114     private static final int FLAG_ACQUIRED_FROM_MESSAGE_CALLBACKS_POOL = 1 << (MAX_ARGS + 2);
115 
116     /** @see #mFlags */
117     static final int MASK_EXPOSED_AS = LambdaType.MASK << (MAX_ARGS + 3);
118 
119     /** @see #mFlags */
120     static final int MASK_FUNC_TYPE = LambdaType.MASK <<
121             (MAX_ARGS + 3 + LambdaType.MASK_BIT_COUNT);
122 
123     /**
124      * Bit schema:
125      * AAAABCDEEEEEEFFFFFF
126      *
127      * Where:
128      * A - whether {@link #mArgs arg} at corresponding index was specified at
129      * {@link #acquire creation time} (0) or {@link #invoke invocation time} (1)
130      * B - {@link #FLAG_RECYCLED}
131      * C - {@link #FLAG_RECYCLE_ON_USE}
132      * D - {@link #FLAG_ACQUIRED_FROM_MESSAGE_CALLBACKS_POOL}
133      * E - {@link LambdaType} representing the type of the lambda returned to the caller from a
134      * factory method
135      * F - {@link LambdaType} of {@link #mFunc} as resolved when calling a factory method
136      */
137     int mFlags = 0;
138 
139 
140     @Override
recycle()141     public void recycle() {
142         if (DEBUG) Log.i(LOG_TAG, this + ".recycle()");
143         if (!isRecycled()) doRecycle();
144     }
145 
doRecycle()146     private void doRecycle() {
147         if (DEBUG) Log.i(LOG_TAG, this + ".doRecycle()");
148         Pool pool = (mFlags & FLAG_ACQUIRED_FROM_MESSAGE_CALLBACKS_POOL) != 0
149                 ? PooledLambdaImpl.sMessageCallbacksPool
150                 : PooledLambdaImpl.sPool;
151 
152         mFunc = null;
153         if (mArgs != null) Arrays.fill(mArgs, null);
154         mFlags = FLAG_RECYCLED;
155         mConstValue = 0L;
156 
157         pool.release(this);
158     }
159 
160     @Override
invoke(Object a1, Object a2, Object a3, Object a4, Object a5, Object a6)161     R invoke(Object a1, Object a2, Object a3, Object a4, Object a5, Object a6) {
162         checkNotRecycled();
163         if (DEBUG) {
164             Log.i(LOG_TAG, this + ".invoke("
165                     + commaSeparateFirstN(
166                             new Object[] { a1, a2, a3, a4, a5, a6 },
167                             LambdaType.decodeArgCount(getFlags(MASK_EXPOSED_AS)))
168                     + ")");
169         }
170         final boolean notUsed = fillInArg(a1) && fillInArg(a2) && fillInArg(a3)
171                 && fillInArg(a4) && fillInArg(a5) && fillInArg(a6);
172         int argCount = LambdaType.decodeArgCount(getFlags(MASK_FUNC_TYPE));
173         if (argCount != LambdaType.MASK_ARG_COUNT) {
174             for (int i = 0; i < argCount; i++) {
175                 if (mArgs[i] == ArgumentPlaceholder.INSTANCE) {
176                     throw new IllegalStateException("Missing argument #" + i + " among "
177                             + Arrays.toString(mArgs));
178                 }
179             }
180         }
181         try {
182             return doInvoke();
183         } finally {
184             if (isRecycleOnUse()) doRecycle();
185             if (!isRecycled()) {
186                 int argsSize = ArrayUtils.size(mArgs);
187                 for (int i = 0; i < argsSize; i++) {
188                     popArg(i);
189                 }
190             }
191         }
192     }
193 
fillInArg(Object invocationArg)194     private boolean fillInArg(Object invocationArg) {
195         int argsSize = ArrayUtils.size(mArgs);
196         for (int i = 0; i < argsSize; i++) {
197             if (mArgs[i] == ArgumentPlaceholder.INSTANCE) {
198                 mArgs[i] = invocationArg;
199                 mFlags |= BitUtils.bitAt(i);
200                 return true;
201             }
202         }
203         if (invocationArg != null && invocationArg != ArgumentPlaceholder.INSTANCE) {
204             throw new IllegalStateException("No more arguments expected for provided arg "
205                     + invocationArg + " among " + Arrays.toString(mArgs));
206         }
207         return false;
208     }
209 
checkNotRecycled()210     private void checkNotRecycled() {
211         if (isRecycled()) throw new IllegalStateException("Instance is recycled: " + this);
212     }
213 
214     @SuppressWarnings("unchecked")
doInvoke()215     private R doInvoke() {
216         final int funcType = getFlags(MASK_FUNC_TYPE);
217         final int argCount = LambdaType.decodeArgCount(funcType);
218         final int returnType = LambdaType.decodeReturnType(funcType);
219 
220         switch (argCount) {
221             case LambdaType.MASK_ARG_COUNT: {
222                 switch (returnType) {
223                     case LambdaType.ReturnType.INT: return (R) (Integer) getAsInt();
224                     case LambdaType.ReturnType.LONG: return (R) (Long) getAsLong();
225                     case LambdaType.ReturnType.DOUBLE: return (R) (Double) getAsDouble();
226                     default: return (R) mFunc;
227                 }
228             }
229             case 0: {
230                 switch (returnType) {
231                     case LambdaType.ReturnType.VOID: {
232                         ((Runnable) mFunc).run();
233                         return null;
234                     }
235                     case LambdaType.ReturnType.BOOLEAN:
236                     case LambdaType.ReturnType.OBJECT: {
237                         return (R) ((Supplier) mFunc).get();
238                     }
239                 }
240             } break;
241             case 1: {
242                 switch (returnType) {
243                     case LambdaType.ReturnType.VOID: {
244                         ((Consumer) mFunc).accept(popArg(0));
245                         return null;
246                     }
247                     case LambdaType.ReturnType.BOOLEAN: {
248                         return (R) (Object) ((Predicate) mFunc).test(popArg(0));
249                     }
250                     case LambdaType.ReturnType.OBJECT: {
251                         return (R) ((Function) mFunc).apply(popArg(0));
252                     }
253                 }
254             } break;
255             case 2: {
256                 switch (returnType) {
257                     case LambdaType.ReturnType.VOID: {
258                         ((BiConsumer) mFunc).accept(popArg(0), popArg(1));
259                         return null;
260                     }
261                     case LambdaType.ReturnType.BOOLEAN: {
262                         return (R) (Object) ((BiPredicate) mFunc).test(popArg(0), popArg(1));
263                     }
264                     case LambdaType.ReturnType.OBJECT: {
265                         return (R) ((BiFunction) mFunc).apply(popArg(0), popArg(1));
266                     }
267                 }
268             } break;
269             case 3: {
270                 switch (returnType) {
271                     case LambdaType.ReturnType.VOID: {
272                         ((TriConsumer) mFunc).accept(popArg(0), popArg(1), popArg(2));
273                         return null;
274                     }
275                     case LambdaType.ReturnType.BOOLEAN: {
276                         return (R) (Object) ((TriPredicate) mFunc).test(
277                                 popArg(0), popArg(1), popArg(2));
278                     }
279                     case LambdaType.ReturnType.OBJECT: {
280                         return (R) ((TriFunction) mFunc).apply(popArg(0), popArg(1), popArg(2));
281                     }
282                 }
283             } break;
284             case 4: {
285                 switch (returnType) {
286                     case LambdaType.ReturnType.VOID: {
287                         ((QuadConsumer) mFunc).accept(popArg(0), popArg(1), popArg(2), popArg(3));
288                         return null;
289                     }
290                     case LambdaType.ReturnType.BOOLEAN: {
291                         return (R) (Object) ((QuadPredicate) mFunc).test(
292                                 popArg(0), popArg(1), popArg(2), popArg(3));
293                     }
294                     case LambdaType.ReturnType.OBJECT: {
295                         return (R) ((QuadFunction) mFunc).apply(
296                                 popArg(0), popArg(1), popArg(2), popArg(3));
297                     }
298                 }
299             } break;
300 
301             case 5: {
302                 switch (returnType) {
303                     case LambdaType.ReturnType.VOID: {
304                         ((QuintConsumer) mFunc).accept(popArg(0), popArg(1),
305                                 popArg(2), popArg(3), popArg(4));
306                         return null;
307                     }
308                     case LambdaType.ReturnType.BOOLEAN: {
309                         return (R) (Object) ((QuintPredicate) mFunc).test(
310                                 popArg(0), popArg(1), popArg(2), popArg(3), popArg(4));
311                     }
312                     case LambdaType.ReturnType.OBJECT: {
313                         return (R) ((QuintFunction) mFunc).apply(
314                                 popArg(0), popArg(1), popArg(2), popArg(3),  popArg(4));
315                     }
316                 }
317             } break;
318 
319             case 6: {
320                 switch (returnType) {
321                     case LambdaType.ReturnType.VOID: {
322                         ((HexConsumer) mFunc).accept(popArg(0), popArg(1),
323                                 popArg(2), popArg(3), popArg(4), popArg(5));
324                         return null;
325                     }
326                     case LambdaType.ReturnType.BOOLEAN: {
327                         return (R) (Object) ((HexPredicate) mFunc).test(popArg(0),
328                                 popArg(1), popArg(2), popArg(3), popArg(4), popArg(5));
329                     }
330                     case LambdaType.ReturnType.OBJECT: {
331                         return (R) ((HexFunction) mFunc).apply(popArg(0), popArg(1),
332                                 popArg(2), popArg(3), popArg(4), popArg(5));
333                     }
334                 }
335             }
336         }
337         throw new IllegalStateException("Unknown function type: " + LambdaType.toString(funcType));
338     }
339 
isConstSupplier()340     private boolean isConstSupplier() {
341         return LambdaType.decodeArgCount(getFlags(MASK_FUNC_TYPE)) == LambdaType.MASK_ARG_COUNT;
342     }
343 
popArg(int index)344     private Object popArg(int index) {
345         Object result = mArgs[index];
346         if (isInvocationArgAtIndex(index)) {
347             mArgs[index] = ArgumentPlaceholder.INSTANCE;
348             mFlags &= ~BitUtils.bitAt(index);
349         }
350         return result;
351     }
352 
353     @Override
toString()354     public String toString() {
355         if (isRecycled()) return "<recycled PooledLambda@" + hashCodeHex(this) + ">";
356 
357         StringBuilder sb = new StringBuilder();
358         if (isConstSupplier()) {
359             sb.append(getFuncTypeAsString()).append("(").append(doInvoke()).append(")");
360         } else {
361             if (mFunc instanceof PooledLambdaImpl) {
362                 sb.append(mFunc);
363             } else {
364                 sb.append(getFuncTypeAsString()).append("@").append(hashCodeHex(mFunc));
365             }
366             sb.append("(");
367             sb.append(commaSeparateFirstN(mArgs, LambdaType.decodeArgCount(getFlags(MASK_FUNC_TYPE))));
368             sb.append(")");
369         }
370         return sb.toString();
371     }
372 
commaSeparateFirstN(@ullable Object[] arr, int n)373     private String commaSeparateFirstN(@Nullable Object[] arr, int n) {
374         if (arr == null) return "";
375         return TextUtils.join(",", Arrays.copyOf(arr, n));
376     }
377 
hashCodeHex(Object o)378     private static String hashCodeHex(Object o) {
379         return Integer.toHexString(o.hashCode());
380     }
381 
getFuncTypeAsString()382     private String getFuncTypeAsString() {
383         if (isRecycled()) throw new IllegalStateException();
384         if (isConstSupplier()) return "supplier";
385         String name = LambdaType.toString(getFlags(MASK_EXPOSED_AS));
386         if (name.endsWith("Consumer")) return "consumer";
387         if (name.endsWith("Function")) return "function";
388         if (name.endsWith("Predicate")) return "predicate";
389         if (name.endsWith("Supplier")) return "supplier";
390         if (name.endsWith("Runnable")) return "runnable";
391         throw new IllegalStateException("Don't know the string representation of " + name);
392     }
393 
394     /**
395      * Internal non-typesafe factory method for {@link PooledLambdaImpl}
396      */
acquire(Pool pool, Object func, int fNumArgs, int numPlaceholders, int fReturnType, Object a, Object b, Object c, Object d, Object e, Object f)397     static <E extends PooledLambda> E acquire(Pool pool, Object func,
398             int fNumArgs, int numPlaceholders, int fReturnType,
399             Object a, Object b, Object c, Object d, Object e, Object f) {
400         PooledLambdaImpl r = acquire(pool);
401         if (DEBUG) {
402             Log.i(LOG_TAG,
403                     "acquire(this = @" + hashCodeHex(r)
404                             + ", func = " + func
405                             + ", fNumArgs = " + fNumArgs
406                             + ", numPlaceholders = " + numPlaceholders
407                             + ", fReturnType = " + LambdaType.ReturnType.toString(fReturnType)
408                             + ", a = " + a
409                             + ", b = " + b
410                             + ", c = " + c
411                             + ", d = " + d
412                             + ", e = " + e
413                             + ", f = " + f
414                             + ")");
415         }
416         r.mFunc = func;
417         r.setFlags(MASK_FUNC_TYPE, LambdaType.encode(fNumArgs, fReturnType));
418         r.setFlags(MASK_EXPOSED_AS, LambdaType.encode(numPlaceholders, fReturnType));
419         if (ArrayUtils.size(r.mArgs) < fNumArgs) r.mArgs = new Object[fNumArgs];
420         setIfInBounds(r.mArgs, 0, a);
421         setIfInBounds(r.mArgs, 1, b);
422         setIfInBounds(r.mArgs, 2, c);
423         setIfInBounds(r.mArgs, 3, d);
424         setIfInBounds(r.mArgs, 4, e);
425         setIfInBounds(r.mArgs, 5, f);
426         return (E) r;
427     }
428 
acquireConstSupplier(int type)429     static PooledLambdaImpl acquireConstSupplier(int type) {
430         PooledLambdaImpl r = acquire(PooledLambdaImpl.sPool);
431         int lambdaType = LambdaType.encode(LambdaType.MASK_ARG_COUNT, type);
432         r.setFlags(PooledLambdaImpl.MASK_FUNC_TYPE, lambdaType);
433         r.setFlags(PooledLambdaImpl.MASK_EXPOSED_AS, lambdaType);
434         return r;
435     }
436 
acquire(Pool pool)437     static PooledLambdaImpl acquire(Pool pool) {
438         PooledLambdaImpl r = pool.acquire();
439         if (r == null) r = new PooledLambdaImpl();
440         r.mFlags &= ~FLAG_RECYCLED;
441         r.setFlags(FLAG_ACQUIRED_FROM_MESSAGE_CALLBACKS_POOL,
442                 pool == sMessageCallbacksPool ? 1 : 0);
443         return r;
444     }
445 
setIfInBounds(Object[] array, int i, Object a)446     private static void setIfInBounds(Object[] array, int i, Object a) {
447         if (i < ArrayUtils.size(array)) array[i] = a;
448     }
449 
450     @Override
negate()451     public OmniFunction<Object, Object, Object, Object, Object, Object, R> negate() {
452         throw new UnsupportedOperationException();
453     }
454 
455     @Override
andThen( Function<? super R, ? extends V> after)456     public <V> OmniFunction<Object, Object, Object, Object, Object, Object, V> andThen(
457             Function<? super R, ? extends V> after) {
458         throw new UnsupportedOperationException();
459     }
460 
461     @Override
getAsDouble()462     public double getAsDouble() {
463         return Double.longBitsToDouble(mConstValue);
464     }
465 
466     @Override
getAsInt()467     public int getAsInt() {
468         return (int) mConstValue;
469     }
470 
471     @Override
getAsLong()472     public long getAsLong() {
473         return mConstValue;
474     }
475 
476     @Override
recycleOnUse()477     public OmniFunction<Object, Object, Object, Object, Object, Object, R> recycleOnUse() {
478         if (DEBUG) Log.i(LOG_TAG, this + ".recycleOnUse()");
479         mFlags |= FLAG_RECYCLE_ON_USE;
480         return this;
481     }
482 
isRecycled()483     private boolean isRecycled() {
484         return (mFlags & FLAG_RECYCLED) != 0;
485     }
486 
isRecycleOnUse()487     private boolean isRecycleOnUse() {
488         return (mFlags & FLAG_RECYCLE_ON_USE) != 0;
489     }
490 
isInvocationArgAtIndex(int argIndex)491     private boolean isInvocationArgAtIndex(int argIndex) {
492         return (mFlags & (1 << argIndex)) != 0;
493     }
494 
getFlags(int mask)495     int getFlags(int mask) {
496         return unmask(mask, mFlags);
497     }
498 
setFlags(int mask, int value)499     void setFlags(int mask, int value) {
500         mFlags &= ~mask;
501         mFlags |= mask(mask, value);
502     }
503 
504     /**
505      * 0xFF000, 0xAB -> 0xAB000
506      */
mask(int mask, int value)507     private static int mask(int mask, int value) {
508         return (value << Integer.numberOfTrailingZeros(mask)) & mask;
509     }
510 
511     /**
512      * 0xFF000, 0xAB123 -> 0xAB
513      */
unmask(int mask, int bits)514     private static int unmask(int mask, int bits) {
515         return (bits & mask) / (1 << Integer.numberOfTrailingZeros(mask));
516     }
517 
518     /**
519      * Contract for encoding a supported lambda type in {@link #MASK_BIT_COUNT} bits
520      */
521     static class LambdaType {
522         public static final int MASK_ARG_COUNT = 0b111;
523         public static final int MASK_RETURN_TYPE = 0b111000;
524         public static final int MASK = MASK_ARG_COUNT | MASK_RETURN_TYPE;
525         public static final int MASK_BIT_COUNT = 6;
526 
encode(int argCount, int returnType)527         static int encode(int argCount, int returnType) {
528             return mask(MASK_ARG_COUNT, argCount) | mask(MASK_RETURN_TYPE, returnType);
529         }
530 
decodeArgCount(int type)531         static int decodeArgCount(int type) {
532             return type & MASK_ARG_COUNT;
533         }
534 
decodeReturnType(int type)535         static int decodeReturnType(int type) {
536             return unmask(MASK_RETURN_TYPE, type);
537         }
538 
toString(int type)539         static String toString(int type) {
540             int argCount = decodeArgCount(type);
541             int returnType = decodeReturnType(type);
542             if (argCount == 0) {
543                 if (returnType == ReturnType.VOID) return "Runnable";
544                 if (returnType == ReturnType.OBJECT || returnType == ReturnType.BOOLEAN) {
545                     return "Supplier";
546                 }
547             }
548             return argCountPrefix(argCount) + ReturnType.lambdaSuffix(returnType);
549         }
550 
argCountPrefix(int argCount)551         private static String argCountPrefix(int argCount) {
552             switch (argCount) {
553                 case MASK_ARG_COUNT: return "";
554                 case 1: return "";
555                 case 2: return "Bi";
556                 case 3: return "Tri";
557                 case 4: return "Quad";
558                 case 5: return "Quint";
559                 case 6: return "Hex";
560                 default: throw new IllegalArgumentException("" + argCount);
561             }
562         }
563 
564         static class ReturnType {
565             public static final int VOID = 1;
566             public static final int BOOLEAN = 2;
567             public static final int OBJECT = 3;
568             public static final int INT = 4;
569             public static final int LONG = 5;
570             public static final int DOUBLE = 6;
571 
toString(int returnType)572             static String toString(int returnType) {
573                 switch (returnType) {
574                     case VOID: return "VOID";
575                     case BOOLEAN: return "BOOLEAN";
576                     case OBJECT: return "OBJECT";
577                     case INT: return "INT";
578                     case LONG: return "LONG";
579                     case DOUBLE: return "DOUBLE";
580                     default: return "" + returnType;
581                 }
582             }
583 
lambdaSuffix(int type)584             static String lambdaSuffix(int type) {
585                 return prefix(type) + suffix(type);
586             }
587 
prefix(int type)588             private static String prefix(int type) {
589                 switch (type) {
590                     case INT: return "Int";
591                     case LONG: return "Long";
592                     case DOUBLE: return "Double";
593                     default: return "";
594                 }
595             }
596 
suffix(int type)597             private static String suffix(int type) {
598                 switch (type) {
599                     case VOID: return "Consumer";
600                     case BOOLEAN: return "Predicate";
601                     case OBJECT: return "Function";
602                     default: return "Supplier";
603                 }
604             }
605         }
606     }
607 }
608