1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <atomic>
#include <memory>
#include <mutex>
#include <utility>
21 
22 namespace android::mediametrics {
23 
/**
 * Holds a std::shared_ptr<T> behind a mutex so that it can be read and
 * replaced concurrently.
 *
 * Member access through operator->() behaves as if the shared pointer were
 * first copied under the lock and the member then accessed on that copy with
 * no lock held.  A concurrent set() therefore never frees an object a caller
 * is still using: the caller's copy keeps it alive.
 *
 * Compare with the C++20 atomic shared pointer specialization:
 * https://en.cppreference.com/w/cpp/memory/shared_ptr/atomic2
 *
 * EXAMPLE:
 *
 * SharedPtrWrap<T> t{};
 *
 * thread1() {
 *   t->func();  // runs on either the original T or the one installed by thread2.
 * }
 *
 * thread2() {
 *   t.set(std::make_shared<T>()); // replaces the original T.
 * }
 */
template <typename T>
class SharedPtrWrap {
public:
    /**
     * Creates the initial T with std::make_shared, forwarding every argument
     * to T's constructor.
     */
    template <typename... Args>
    explicit SharedPtrWrap(Args&&... args)
        : mPtr{std::make_shared<T>(std::forward<Args>(args)...)} {}

    /**
     * Returns a copy of the currently held shared pointer, taken under the lock.
     *
     * The result must stay a value, never a reference: the copy is what keeps
     * the object alive once the lock is dropped.
     *
     * For compatibility with plain std::shared_ptr, the const getter still
     * yields shared_ptr<T> rather than shared_ptr<const T>.
     */
    std::shared_ptr<T> get() const {
        std::scoped_lock guard{mLock};
        return mPtr;
    }

    /**
     * Installs ptr as the held shared pointer and hands back the one it replaced.
     *
     * ptr is taken by value so it can be swapped straight into place.  Because
     * the previous pointer is returned to the caller, dropping it happens after
     * this lock has been released.
     */
    std::shared_ptr<T> set(std::shared_ptr<T> ptr) {
        std::scoped_lock guard{mLock};
        mPtr.swap(ptr);
        return ptr;
    }

    /**
     * Member access on a snapshot of the held pointer.
     *
     * A shared_ptr is returned by value.  The language then applies operator->
     * to that temporary, which lives until the end of the full expression.  The
     * object therefore stays alive for the whole member call or member read,
     * even if another thread calls set() meanwhile.  Do not change this to
     * return a reference!
     *
     * Like get(), the const version yields shared_ptr<T> (not
     * shared_ptr<const T>) for compatibility with plain std::shared_ptr.
     */
    std::shared_ptr<T> operator->() const {
        return get();
    }

    // operator*() is intentionally not provided: a T& obtained that way would
    // not hold a reference count.  It could dangle once the lock is released
    // and another thread calls set().

private:
    mutable std::mutex mLock;
    std::shared_ptr<T> mPtr;
};
93 
/**
 * Wraps member access to the class T by a lock.
 *
 * The object T is constructed within the LockWrap to guarantee
 * locked access at all times.  When T's members are accessed through ->,
 * a monitor-style lock is obtained to prevent multiple threads from executing
 * methods in the object T at the same time.  The lock is held until the end of
 * the full expression that contains the -> access.
 * Suggested by Kevin R.
 *
 * EXAMPLE:
 *
 * // Accumulator class which is very slow, requires locking for multiple threads.
 *
 * class Accumulator {
 *   int32_t value_ = 0;
 * public:
 *   void add(int32_t incr) {
 *     const int32_t temp = value_;
 *     sleep(0);  // yield
 *     value_ = temp + incr;
 *   }
 *   int32_t get() { return value_; }
 * };
 *
 * // We use LockWrap on Accumulator to have safe multithread access.
 * android::mediametrics::LockWrap<Accumulator> a{}; // locked accumulator succeeds
 *
 * // Conversely, the following line fails:
 * // auto a = std::make_shared<Accumulator>();
 * // add() reads, yields, then writes, so concurrent adds race and updates are lost.
 *
 * constexpr size_t THREADS = 100;
 * constexpr size_t ITERATIONS = 10;
 * constexpr int32_t INCREMENT = 1;
 *
 * // Test by generating multiple threads, all adding simultaneously.
 * std::vector<std::future<void>> threads;
 * threads.reserve(THREADS);
 * for (size_t i = 0; i < THREADS; ++i) {
 *     threads.push_back(std::async(std::launch::async, [&] {
 *         for (size_t j = 0; j < ITERATIONS; ++j) {
 *             a->add(INCREMENT);  // add needs locked access here.
 *         }
 *     }));
 * }
 * threads.clear();  // destroying a std::async future waits for its thread.
 *
 * // If the add operations are not atomic, value will be smaller than expected.
 * ASSERT_EQ(INCREMENT * THREADS * ITERATIONS, (size_t)a->get());
 *
 */
template <typename T>
class LockWrap {
    /**
     * Holding class that keeps the pointer and the lock.
     *
     * We return this holding class from operator->() to keep the lock until the
     * member function call or member variable access is completed.  It is
     * neither copyable nor movable, because std::lock_guard is not.
     * operator->() therefore relies on C++17 guaranteed copy elision to return
     * it by value.
     */
    class LockedPointer {
        friend LockWrap;

        // Members are initialized in declaration order.  The lock (mLock) is
        // acquired before the depth counter is incremented in the body.
        LockedPointer(T *t, std::recursive_mutex *lock, std::atomic<size_t> *recursionDepth)
            : mT(t), mLock(*lock), mRecursionDepth(recursionDepth) { ++*mRecursionDepth; }

        T* const mT;
        std::lock_guard<std::recursive_mutex> mLock;
        std::atomic<size_t>* const mRecursionDepth;
    public:
        // The destructor body runs before members are destroyed.  The depth
        // is therefore decremented while mLock is still held.
        ~LockedPointer() {
            --*mRecursionDepth; // Used for testing, we do not check underflow.
        }

        const T* operator->() const {
            return mT;
        }
        T* operator->() {
            return mT;
        }
    };

    // We must use a recursive mutex because the end of the full expression may
    // involve another reference to T->.
    //
    // A recursive mutex allows the same thread to recursively acquire,
    // but a different thread would block.
    //
    // Example which fails with a normal mutex:
    //
    // android::mediametrics::LockWrap<std::vector<int>> v{std::initializer_list<int>{1, 2}};
    // const int sum = v->operator[](0) + v->operator[](1);
    //
    mutable std::recursive_mutex mLock;
    mutable T mT;
    mutable std::atomic<size_t> mRecursionDepth{};  // Used for testing.

public:
    /** Constructs the wrapped T in place, forwarding every argument to T's constructor. */
    template <typename... Args>
    explicit LockWrap(Args&&... args) : mT(std::forward<Args>(args)...) {}

    // The const overload returns a const LockedPointer.  The chained -> then
    // resolves to LockedPointer's const operator->, which yields a const T*,
    // so a const LockWrap cannot be used to mutate T.
    const LockedPointer operator->() const {
        return LockedPointer(&mT, &mLock, &mRecursionDepth);
    }
    LockedPointer operator->() {
        return LockedPointer(&mT, &mLock, &mRecursionDepth);
    }

    // Returns the lock depth of the recursive mutex.  This is the number of
    // LockedPointer objects currently alive.  The counter only changes while
    // the lock is held, so they all belong to the thread owning the lock.
    // It is read without taking the lock, so it is only a snapshot.
    // @TestApi
    size_t getRecursionDepth() const {
        return mRecursionDepth.load();
    }
};
204 
205 } // namespace android::mediametrics
206