/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <atomic>   // std::atomic (LockWrap recursion counter)
#include <cstddef>  // size_t
#include <memory>
#include <mutex>
#include <utility>  // std::forward

namespace android::mediametrics {

/**
 * Wraps a shared-ptr for which member access through operator->() behaves
 * as if the shared-ptr is atomically copied and then (without a lock) -> called.
 *
 * See related C++ 20:
 * https://en.cppreference.com/w/cpp/memory/shared_ptr/atomic2
 *
 * EXAMPLE:
 *
 * SharedPtrWrap<T> t{};
 *
 * thread1() {
 *   t->func();  // safely executes either the original t or the one created by thread2.
 * }
 *
 * thread2() {
 *  t.set(std::make_shared<T>()); // overwrites the original t.
 * }
 */
template <typename T>
class SharedPtrWrap {
    mutable std::mutex mLock;   // guards mPtr only; never held while T's methods run.
    std::shared_ptr<T> mPtr;

public:
    /**
     * Constructs the wrapped T in place (through std::make_shared) from the
     * perfectly forwarded constructor arguments.
     */
    template <typename... Args>
    explicit SharedPtrWrap(Args&&... args)
        : mPtr(std::make_shared<T>(std::forward<Args>(args)...))
    {}

    /**
     * Gets the current shared pointer.  This must return a value, not a reference.
     *
     * The copy is taken while mLock is held, so it is consistent with respect to
     * concurrent set() calls.
     *
     * For compatibility with existing shared_ptr, we do not pass back a
     * shared_ptr<const T> for the const getter.
     */
    std::shared_ptr<T> get() const {
        std::lock_guard lock(mLock);
        return mPtr;
    }

    /**
     * Sets the current shared pointer, returning the previous shared pointer.
     *
     * @param ptr the new shared pointer to install (may be null).
     * @return the shared pointer that was held before this call.
     */
    std::shared_ptr<T> set(std::shared_ptr<T> ptr) { // pass by value as we use swap.
        std::lock_guard lock(mLock);
        std::swap(ptr, mPtr);
        return ptr;  // now holds the previous value.
    }

    /**
     * Returns a shared pointer value representing T at the instant of time when
     * the call executes. The lifetime of the shared pointer will
     * be extended as we are returning an instance of the shared_ptr
     * not a reference to it.  The destructor to the returned shared_ptr
     * will be called sometime after the expression including the member function or
     * the member variable is evaluated. Do not change to a reference!
     */

    // For compatibility with existing shared_ptr, we do not pass back a
    // shared_ptr<const T> for the const operator pointer access.
    std::shared_ptr<T> operator->() const {
        return get();
    }

    /**
     * We do not overload operator*() as the reference is not stable if the
     * lock is not held.
     */
};

/**
 * Wraps member access to the class T by a lock.
 *
 * The object T is constructed within the LockWrap to guarantee
 * locked access at all times.  When T's methods are accessed through ->,
 * a monitor style lock is obtained to prevent multiple threads from executing
 * methods in the object T at the same time.
 * Suggested by Kevin R.
 *
 * EXAMPLE:
 *
 * // Accumulator class which is very slow, requires locking for multiple threads.
 *
 * class Accumulator {
 *     int32_t value_ = 0;
 * public:
 *     void add(int32_t incr) {
 *         const int32_t temp = value_;
 *         sleep(0);  // yield
 *         value_ = temp + incr;
 *     }
 *     int32_t get() { return value_; }
 * };
 *
 * // We use LockWrap on Accumulator to have safe multithread access.
 * android::mediametrics::LockWrap<Accumulator> a{}; // locked accumulator succeeds
 *
 * // Conversely, the following line fails:
 * // auto a = std::make_shared<Accumulator>(); // this fails, only 50% adds atomic.
 *
 * constexpr size_t THREADS = 100;
 * constexpr size_t ITERATIONS = 10;
 * constexpr int32_t INCREMENT = 1;
 *
 * // Test by generating multiple threads, all adding simultaneously.
 * std::vector<std::future<void>> threads;
 * threads.reserve(THREADS);
 * for (size_t i = 0; i < THREADS; ++i) {
 *     threads.push_back(std::async(std::launch::async, [&] {
 *         for (size_t j = 0; j < ITERATIONS; ++j) {
 *             a->add(INCREMENT);  // add needs locked access here.
 *         }
 *     }));
 * }
 * threads.clear();  // destroying the futures joins the async tasks.
 *
 * // If the add operations are not atomic, value will be smaller than expected.
 * ASSERT_EQ(INCREMENT * THREADS * ITERATIONS, (size_t)a->get());
 *
 */
template <typename T>
class LockWrap {
    /**
      * Holding class that keeps the pointer and the lock.
      *
      * We return this holding class from operator->() to keep the lock until the
      * method function or method variable access is completed.
      */
    class LockedPointer {
        friend LockWrap;

        // Acquires *lock (through the lock_guard member) and then bumps the depth
        // counter.  Only LockWrap may create instances.
        LockedPointer(T *t, std::recursive_mutex *lock, std::atomic<size_t> *recursionDepth)
            : mT(t), mLock(*lock), mRecursionDepth(recursionDepth) { ++*mRecursionDepth; }

        T* const mT;
        std::lock_guard<std::recursive_mutex> mLock;
        std::atomic<size_t>* mRecursionDepth;

    public:
        // The destructor body runs before the members are destroyed, so the counter
        // is decremented while the lock is still held.
        ~LockedPointer() {
            --*mRecursionDepth; // Used for testing, we do not check underflow.
        }

        const T* operator->() const {
            return mT;
        }
        T* operator->() {
            return mT;
        }
    };

    // We must use a recursive mutex because the end of the full expression may
    // involve another reference to T->.
    //
    // A recursive mutex allows the same thread to recursively acquire,
    // but different thread would block.
    //
    // Example which fails with a normal mutex:
    //
    // android::mediametrics::LockWrap<std::vector<int>> v{std::initializer_list<int>{1, 2}};
    // const int sum = v->operator[](0) + v->operator[](1);
    //
    mutable std::recursive_mutex mLock;
    mutable T mT;
    mutable std::atomic<size_t> mRecursionDepth{};  // Used for testing.

public:
    /**
     * Constructs the wrapped T in place from the perfectly forwarded
     * constructor arguments.
     */
    template <typename... Args>
    explicit LockWrap(Args&&... args) : mT(std::forward<Args>(args)...) {}

    /**
     * Returns a temporary that holds the lock and forwards -> to the wrapped T.
     *
     * The temporary lives until the end of the enclosing full expression, so the
     * lock is released only after the member access has completed.
     * The const overload yields access through const T*.
     */
    const LockedPointer operator->() const {
        return LockedPointer(&mT, &mLock, &mRecursionDepth);
    }
    LockedPointer operator->() {
        return LockedPointer(&mT, &mLock, &mRecursionDepth);
    }

    // Returns the lock depth of the recursive mutex.
    // @TestApi
    size_t getRecursionDepth() const {
        return mRecursionDepth;
    }
};

} // namespace android::mediametrics