1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.net.module.util
18 
19 import java.util.concurrent.TimeUnit
20 import java.util.concurrent.locks.Condition
21 import java.util.concurrent.locks.ReentrantLock
22 import java.util.concurrent.locks.StampedLock
23 import kotlin.concurrent.withLock
24 
/**
 * A [List] that can additionally grow by appending elements with [add], and that can retrieve
 * an element by position with [poll], optionally waiting for a matching element to become
 * available.
 */
interface TrackRecord<E> : List<E> {
    /**
     * Appends [e] to the end of this record, waking up threads waiting for an element.
     *
     * @return true, as per the contract for collections that accept the element.
     */
    fun add(e: E): Boolean

    /**
     * Returns the first element at index [pos] or later that matches [predicate]. If no such
     * element is present yet, blocks until one is available, or returns null if none can be
     * found within the timeout.
     *
     * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
     * @param pos the index at which to start polling; the element at [pos] itself is eligible.
     * @param predicate an optional predicate to filter elements to be returned. When omitted,
     *   every element matches.
     * @return an element matching the predicate, or null on timeout.
     */
    fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean = { true }): E?
}
48 
/**
 * A thread-safe implementation of [TrackRecord] that is backed by an [ArrayList].
 *
 * The [List] methods and [add] are guarded by a single [ReentrantLock]. [poll] waits on a
 * condition of that lock, which every [add] signals.
 *
 * This class also supports the creation of a read-head for easier single-thread access.
 * Refer to the documentation of [ArrayTrackRecord.ReadHead].
 */
class ArrayTrackRecord<E> : TrackRecord<E> {
    private val lock = ReentrantLock()
    private val condition = lock.newCondition()
    // Backing store. This stores the elements in this ArrayTrackRecord.
    private val elements = ArrayList<E>()

    /**
     * A [ListIterator] over [list] whose forward end is fixed at [end] when it is created.
     *
     * [listIterator] creates it with `end` set to the size of the record at that time, so it
     * iterates over a snapshot of the collection. Because a TrackRecord is only ever mutated by
     * appending, the elements in the range [0, end) never change. The iterator reads the backing
     * list directly, without taking the record's lock.
     *
     * Like any ListIterator created at a non-zero position, it can move backwards down to index 0.
     */
    class ArrayTrackRecordIterator<E>(
        private val list: ArrayList<E>,
        start: Int,
        private val end: Int
    ) : ListIterator<E> {
        var index = start
        override fun hasNext() = index < end
        override fun next(): E {
            // Elements appended after creation lie past the snapshot end and must not be returned.
            if (!hasNext()) throw NoSuchElementException("No element at index $index (end=$end)")
            return list[index++]
        }
        override fun hasPrevious() = index > 0
        // Per the ListIterator contract, this is the index of the element next() would return.
        override fun nextIndex() = index
        override fun previous(): E {
            if (!hasPrevious()) throw NoSuchElementException("No element before index 0")
            return list[--index]
        }
        override fun previousIndex() = index - 1
    }

    // List<E> implementation
    override val size get() = lock.withLock { elements.size }
    override fun contains(element: E) = lock.withLock { elements.contains(element) }
    override fun containsAll(elements: Collection<E>) = lock.withLock {
        this.elements.containsAll(elements)
    }
    override operator fun get(index: Int) = lock.withLock { elements[index] }
    override fun indexOf(element: E): Int = lock.withLock { elements.indexOf(element) }
    override fun lastIndexOf(element: E): Int = lock.withLock { elements.lastIndexOf(element) }
    override fun isEmpty() = lock.withLock { elements.isEmpty() }
    override fun listIterator(index: Int) = ArrayTrackRecordIterator(elements, index, size)
    override fun listIterator() = listIterator(0)
    override fun iterator() = listIterator()

    /**
     * Returns a copy of the elements in the range [fromIndex, toIndex).
     *
     * This does not return an ArrayList.subList view. Such a view would throw
     * ConcurrentModificationException on access after any later add(), and it would be read
     * without holding the lock. Because the record is append-only, the copy holds exactly the
     * elements a view would.
     */
    override fun subList(fromIndex: Int, toIndex: Int): List<E> = lock.withLock {
        ArrayList(elements.subList(fromIndex, toIndex))
    }

    // TrackRecord<E> implementation
    override fun add(e: E): Boolean {
        lock.withLock {
            elements.add(e)
            condition.signalAll()
        }
        return true
    }

    override fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean): E? = lock.withLock {
        elements.getOrNull(pollForIndexReadLocked(timeoutMs, pos, predicate))
    }

    /**
     * Returns the first element at index [pos] or later that matches [predicate]. Returns null
     * if there is no such element or if [pos] is out of range. Never blocks.
     */
    fun getOrNull(pos: Int, predicate: (E) -> Boolean) = lock.withLock {
        if (pos < 0 || pos > size) null else elements.subList(pos, size).find(predicate)
    }

    // Returns the index of the first element at a position >= pos that matches the predicate.
    // If necessary, waits until such an element is added, for at most timeoutMs. Returns -1 if
    // no such element is found within the timeout.
    // The caller must hold `lock`: this method reads `elements` and waits on `condition`.
    private fun pollForIndexReadLocked(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean): Int {
        // awaitNanos measures time with the monotonic System.nanoTime() clock, so wall-clock
        // adjustments cannot lengthen or shorten the wait.
        var remainingNs = TimeUnit.MILLISECONDS.toNanos(timeoutMs)
        var index = pos
        while (true) {
            // Only scan elements not examined yet; earlier ones cannot change (append-only).
            while (index < elements.size) {
                if (predicate(elements[index])) return index
                ++index
            }
            // A last scan always happens after the timeout, so an element added at the
            // deadline is not missed.
            if (remainingNs <= 0) return -1
            remainingNs = condition.awaitNanos(remainingNs)
        }
    }

    /**
     * Returns a new [ReadHead] over this ArrayTrackRecord, with its mark at index 0.
     *
     * The ReadHead does not record which thread created it. See [ReadHead] for the rules on
     * using it from several threads.
     */
    fun newReadHead() = ReadHead()

    /**
     * ReadHead is an object that helps users of ArrayTrackRecord keep track of how far they
     * have read in the ArrayTrackRecord. A ReadHead is always associated with a single instance
     * of ArrayTrackRecord.
     *
     * Multiple ReadHeads can be created and used on the same instance of ArrayTrackRecord
     * concurrently, and the ArrayTrackRecord instance can also be used concurrently. A ReadHead
     * maintains the index of the next element to be read, and calls this the "mark".
     *
     * In a ReadHead, [poll] (the overload without a position) works similarly to a
     * LinkedBlockingQueue. It can be called repeatedly and will return the elements as they arrive.
     *
     * Intended usage looks something like this:
     * ```
     * val record = ArrayTrackRecord<MyObject>().newReadHead()
     * Thread {
     *   // do stuff
     *   record.add(something)
     *   // do stuff
     * }.start()
     *
     * val obj1 = record.poll(timeout)
     * // do something with obj1
     * val obj2 = record.poll(timeout)
     * // do something with obj2
     * ```
     *
     * The point is that the caller does not have to track the mark, as it would have to if it
     * were using ArrayTrackRecord directly.
     *
     * Thread safety:
     * A ReadHead delegates all TrackRecord methods to its associated ArrayTrackRecord, and
     * inherits its thread-safe properties for all the TrackRecord methods.
     *
     * poll() operates under its own set of rules. It may be used from multiple threads only
     * within constrained boundaries, and never concurrently or pseudo-concurrently. This is
     * because concurrent calls to poll() fundamentally do not make sense:
     *
     * - poll() moves the mark according to which events remain to be read by this read head.
     * - If multiple threads called poll() concurrently on the same ReadHead, neither the mark nor
     *   the return values would be useful. There would be no way to guarantee that no objects
     *   are skipped, nor to guarantee the mark position when poll() returns.
     * - This is even more true with a predicate to filter returned elements, because one thread
     *   might filter out the events another thread is interested in.
     *
     * For this reason, this class fails fast with ConcurrentModificationException when it
     * detects concurrent access.
     *
     * It is possible to use poll() on different threads if the following can be guaranteed: one
     * thread calls poll() for the last time, then executes a write barrier; then the other
     * thread executes a read barrier before calling poll() for the first time. In particular,
     * this allows calling poll in @Before and @After methods in JUnit unit tests. JUnit enforces
     * those barriers by creating the testing thread after executing @Before and joining the
     * thread before executing @After.
     *
     * peek() can be used by multiple threads concurrently, but only if no thread is calling
     * poll() outside of the boundaries above. For simplicity, consider peek() safe to call only
     * when poll() is safe to call.
     *
     * To poll the same ArrayTrackRecord concurrently, create multiple ReadHeads on the same
     * instance of ArrayTrackRecord, or use ArrayTrackRecord directly. Each ReadHead is then
     * guaranteed to see all events, and guarantees are made on the value of the mark upon
     * return; see [poll] for details. Create each ReadHead on the thread it is meant to be used
     * on, or have a clear synchronization point between creation and use.
     *
     * Users of a ReadHead can ask for the current position of the mark at any time, on a thread
     * where it is safe to call peek(). This mark can be used later to replay the history of
     * events, either on this ReadHead, on the associated ArrayTrackRecord, or on another ReadHead
     * associated with the same ArrayTrackRecord. It might look like this in the reader thread:
     * ```
     * val trackRecord = ArrayTrackRecord<MyObject>()
     * val record = trackRecord.newReadHead()
     * val markAtStart = record.mark
     * // Start processing interesting events
     * while (true) {
     *   val element = record.poll(timeout) { it.isInteresting() } ?: break
     *   // Do something with element
     * }
     * // Look for stuff that happened while searching for interesting events
     * val firstElementReceived = record.getOrNull(markAtStart)
     * val firstSpecialElement = trackRecord.getOrNull(markAtStart) { it.isSpecial() }
     * // Get the first special element since markAtStart, possibly blocking until one is available
     * val specialElement = record.poll(timeout, markAtStart) { it.isSpecial() }
     * ```
     */
    inner class ReadHead : TrackRecord<E> by this@ArrayTrackRecord {
        // This lock only controls access to the readHead member below. The ArrayTrackRecord
        // object has its own synchronization following different (and more usual) semantics.
        // See the comment on the ReadHead class for details.
        private val slock = StampedLock()
        private var readHead = 0

        // A special mark used to track the start of the last poll() operation.
        private var pollMark = 0

        /**
         * The current value of the mark, i.e. the index of the next element this ReadHead
         * reads. Setting it is equivalent to calling [rewind].
         */
        var mark: Int
            get() = checkThread { readHead }
            set(v) = rewind(v)

        /**
         * Moves the mark to [v], and also resets the start of the range returned by
         * [backtrace] to [v].
         *
         * Throws ConcurrentModificationException if poll(), rewind() or backtrace() is running
         * on this ReadHead at the same time.
         */
        fun rewind(v: Int) {
            val stamp = slock.tryWriteLock()
            if (0L == stamp) concurrentAccessDetected()
            readHead = v
            pollMark = v
            slock.unlockWrite(stamp)
        }

        private fun <T> checkThread(r: (Long) -> T): T {
            // tryOptimisticRead is a read barrier: writes from other threads are visible after it.
            val stamp = slock.tryOptimisticRead()
            val result = r(stamp)
            // validate also performs a read barrier. If validate returns true, any change either
            // happened before tryOptimisticRead or happens after validate.
            if (!slock.validate(stamp)) concurrentAccessDetected()
            return result
        }

        private fun concurrentAccessDetected(): Nothing {
            throw ConcurrentModificationException(
                    "ReadHeads can't be used concurrently. Check your threading model.")
        }

        /**
         * Returns the first element at the mark or later that matches [predicate]. If none is
         * available, blocks until one arrives, or returns null if none arrives within the
         * timeout.
         *
         * Upon return, the mark is set to immediately after the returned element. If null is
         * returned, the mark is set after the last element in the record. This means this
         * method always skips elements that do not match the predicate, even when it returns
         * null.
         *
         * This method is subject to the threading rules described on [ReadHead]. If it detects
         * a concurrent poll(), rewind() or backtrace() on this ReadHead, it throws
         * ConcurrentModificationException.
         *
         * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
         * @param predicate an optional predicate to filter elements to be returned.
         * @return an element matching the predicate, or null on timeout.
         */
        fun poll(timeoutMs: Long, predicate: (E) -> Boolean = { true }): E? {
            val stamp = slock.tryWriteLock()
            if (0L == stamp) concurrentAccessDetected()
            try {
                pollMark = readHead
                return lock.withLock {
                    val index = pollForIndexReadLocked(timeoutMs, readHead, predicate)
                    // On timeout, skip everything present so far: move the mark past the end.
                    readHead = if (index < 0) elements.size else index + 1
                    elements.getOrNull(index)
                }
            } finally {
                slock.unlockWrite(stamp)
            }
        }

        /**
         * Returns a copy of the elements that were skipped or returned by the last poll() on
         * this ReadHead. The range runs from the mark at the start of that poll() (or the
         * position given to the last rewind()) up to the current mark.
         *
         * Throws ConcurrentModificationException if poll() or rewind() is running on this
         * ReadHead at the same time.
         *
         * @return list of events since poll() was called.
         */
        fun backtrace(): List<E> {
            val stamp = slock.tryReadLock()
            if (0L == stamp) concurrentAccessDetected()
            try {
                // The read lock keeps readHead and pollMark stable while they are read.
                return lock.withLock { ArrayList(elements.subList(pollMark, readHead)) }
            } finally {
                slock.unlockRead(stamp)
            }
        }

        /**
         * Returns the element at the mark, or null if there is none yet. This never blocks and
         * does not move the mark.
         *
         * This method is subject to threading restrictions. It can be used concurrently on
         * multiple threads, but not if any other thread might be executing poll() at the same
         * time. See the class comment for details.
         */
        fun peek(): E? = checkThread { lock.withLock { elements.getOrNull(readHead) } }
    }
}
309 
/**
 * Private helper: waits on this condition for at most [timeoutMs] milliseconds.
 *
 * Shorthand for `Condition.await(timeoutMs, TimeUnit.MILLISECONDS)`, and returns its result.
 */
private fun Condition.await(timeoutMs: Long): Boolean {
    val unit = TimeUnit.MILLISECONDS
    return await(timeoutMs, unit)
}
312