1 /*
<lambda>null2  * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines
6 
7 import java.lang.ref.*
8 import java.lang.reflect.*
9 import java.text.*
10 import java.util.*
11 import java.util.Collections.*
12 import java.util.concurrent.atomic.*
13 import java.util.concurrent.locks.*
14 import kotlin.collections.ArrayList
15 import kotlin.test.*
16 
object FieldWalker {
    /** Records how an object was first reached during a walk; used by [showPath] to rebuild a path from the root. */
    sealed class Ref {
        /** The object the walk started from. */
        object RootRef : Ref()
        /** Reached through the field (or pseudo-field such as "keys"/"value"/"cause") [name] of [parent]. */
        class FieldRef(val parent: Any, val name: String) : Ref()
        /** Reached as element [index] of [parent], an array, platform collection or AtomicReferenceArray. */
        class ArrayRef(val parent: Any, val index: Int) : Ref()
    }

    /**
     * Excluded/terminal classes: fields declared by these classes are never read reflectively,
     * so the walk does not descend into them (or past them when they are a superclass).
     */
    private val terminalClasses: Set<Class<*>> = listOf(
        Any::class, String::class, Thread::class, Throwable::class, StackTraceElement::class,
        WeakReference::class, ReferenceQueue::class, AbstractMap::class,
        ReentrantReadWriteLock::class, SimpleDateFormat::class
    ).map { it.java }.toHashSet()

    /**
     * Per-class cache of reference-typed *instance* fields (own and inherited).
     * Terminal classes are pre-seeded with an empty list, which also stops superclass traversal.
     * Lists that include static fields are never stored here.
     */
    private val fieldsCache = HashMap<Class<*>, List<Field>>()

    init {
        fieldsCache += terminalClasses.associateWith { emptyList<Field>() }
    }

    /**
     * Reflectively walks the object graph starting at [root] and returns an identity set of all reachable objects
     * (empty when [root] is `null`). Use [assertReachableCount] if you need paths from the root for debugging.
     */
    public fun walk(root: Any?): Set<Any> = walkRefs(root, false).keys

    /**
     * Asserts that exactly [expected] objects reachable from [root] satisfy [predicate].
     * When [rootStatics] is true, static fields declared by the root's own class are walked as well.
     * On mismatch the failure message lists the path from the root to every matching object.
     */
    public fun assertReachableCount(expected: Int, root: Any?, rootStatics: Boolean = false, predicate: (Any) -> Boolean) {
        val visited = walkRefs(root, rootStatics)
        val actual = visited.keys.filter(predicate)
        if (actual.size != expected) {
            // the path dump is only built on failure
            val textDump = actual.joinToString("") { "\n\t" + showPath(it, visited) }
            assertEquals(
                expected, actual.size,
                "Unexpected number objects. Expected $expected, found ${actual.size}$textDump"
            )
        }
    }

    /**
     * Reflectively walks the object graph from [root] and maps every reached object (by identity) to the [Ref]
     * through which it was first reached. Use [showPath] to render a path from the root.
     *
     * @throws IllegalStateException if an element cannot be visited; the original failure is attached as the cause.
     */
    private fun walkRefs(root: Any?, rootStatics: Boolean): Map<Any, Ref> {
        val visited = IdentityHashMap<Any, Ref>()
        if (root == null) return visited
        visited[root] = Ref.RootRef
        val stack = ArrayDeque<Any>()
        stack.addLast(root)
        var statics = rootStatics
        while (stack.isNotEmpty()) {
            val element = stack.removeLast()
            try {
                visit(element, visited, stack, statics)
                statics = false // only scan root static when asked
            } catch (e: Exception) {
                // keep the underlying failure as the cause so its stack trace is not lost
                throw IllegalStateException("Failed to visit element ${showPath(element, visited)}: $e", e)
            }
        }
        return visited
    }

    /** Renders the path from the walk root to [element] by following the recorded [Ref] parents back to the root. */
    private fun showPath(element: Any, visited: Map<Any, Ref>): String {
        val path = ArrayList<String>()
        var cur = element
        while (true) {
            val ref = visited.getValue(cur)
            if (ref is Ref.RootRef) break
            when (ref) {
                is Ref.FieldRef -> {
                    cur = ref.parent
                    path += "|${ref.parent.javaClass.simpleName}::${ref.name}"
                }
                is Ref.ArrayRef -> {
                    cur = ref.parent
                    path += "[${ref.index}]"
                }
            }
        }
        path.reverse()
        return path.joinToString("")
    }

    /**
     * Pushes every object directly referenced by [element] onto [stack] (unless already [visited]).
     * When [statics] is true, static fields declared by the element's own class are followed too.
     */
    private fun visit(element: Any, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, statics: Boolean) {
        val type = element.javaClass
        when {
            // Special code for arrays
            type.isArray && !type.componentType.isPrimitive -> {
                @Suppress("UNCHECKED_CAST")
                val array = element as Array<Any?>
                array.forEachIndexed { index, value ->
                    push(value, visited, stack) { Ref.ArrayRef(element, index) }
                }
            }
            // Special code for platform types that cannot be reflectively accessed on modern JDKs
            type.name.startsWith("java.") && element is Collection<*> -> {
                element.forEachIndexed { index, value ->
                    push(value, visited, stack) { Ref.ArrayRef(element, index) }
                }
            }
            type.name.startsWith("java.") && element is Map<*, *> -> {
                push(element.keys, visited, stack) { Ref.FieldRef(element, "keys") }
                push(element.values, visited, stack) { Ref.FieldRef(element, "values") }
            }
            element is AtomicReference<*> -> {
                push(element.get(), visited, stack) { Ref.FieldRef(element, "value") }
            }
            element is AtomicReferenceArray<*> -> {
                for (index in 0 until element.length()) {
                    push(element[index], visited, stack) { Ref.ArrayRef(element, index) }
                }
            }
            element is AtomicLongFieldUpdater<*> -> {
                /* filter it out here to suppress its subclasses too */
            }
            // All the other classes are reflectively scanned
            else -> {
                fields(type, statics).forEach { field ->
                    push(field.get(element), visited, stack) { Ref.FieldRef(element, field.name) }
                }
                // Throwable is a terminal class, so its own fields are never read reflectively; reach the cause
                // through the public accessor instead. This must not depend on the concrete exception type
                // declaring any reference fields of its own, hence it lives outside the per-field loop.
                if (element is Throwable) {
                    push(element.cause, visited, stack) { Ref.FieldRef(element, "cause") }
                }
            }
        }
    }

    /** Records [value] as reached via [ref] and schedules it for visiting, unless it is `null` or already visited. */
    private inline fun push(value: Any?, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, ref: () -> Ref) {
        if (value != null && !visited.containsKey(value)) {
            visited[value] = ref()
            stack.addLast(value)
        }
    }

    /**
     * Returns the reference-typed fields to follow for an instance of [type0]: its own and inherited instance fields,
     * plus, when [rootStatics] is true, the static fields declared by [type0] itself (never by its superclasses).
     * Terminal classes always yield an empty list.
     *
     * Only instance-field lists are cached, so a statics scan and an instance-only scan of the same class
     * can never observe each other's result.
     */
    private fun fields(type0: Class<*>, rootStatics: Boolean): List<Field> {
        if (rootStatics && type0 !in terminalClasses) return collectFields(type0, includeOwnStatics = true)
        return fieldsCache.getOrPut(type0) { collectFields(type0, includeOwnStatics = false) }
    }

    /**
     * Collects the fields of [type0] and its superclasses up to the first cached ancestor (every chain ends at a
     * pre-seeded terminal class such as `Any`), skipping primitive and primitive-array fields since they cannot
     * reference other objects. Collected fields are made accessible. Does not modify the cache.
     */
    private fun collectFields(type0: Class<*>, includeOwnStatics: Boolean): List<Field> {
        val result = ArrayList<Field>()
        var type = type0
        var statics = includeOwnStatics
        while (true) {
            val fields = type.declaredFields.filter {
                !it.type.isPrimitive
                        && (statics || !Modifier.isStatic(it.modifiers))
                        && !(it.type.isArray && it.type.componentType.isPrimitive)
            }
            fields.forEach { it.isAccessible = true } // make them all accessible
            result.addAll(fields)
            statics = false // statics are only ever taken from type0 itself
            type = type.superclass
            val superFields = fieldsCache[type] // will stop at Any anyway
            if (superFields != null) {
                result.addAll(superFields)
                return result
            }
        }
    }
}
174