1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_TRANSFER_VALUE_H
18 #define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_TRANSFER_VALUE_H
19 
#include <android-base/logging.h>
#include <android-base/thread_annotations.h>

#include <condition_variable>
#include <functional>
#include <mutex>
#include <optional>
#include <type_traits>
#include <utility>
28 
29 namespace android::hardware::neuralnetworks::utils {
30 
// Adapts a function pointer `fn` of type `ReturnType(ArgTypes...)` so that its result can be
// captured through a callback-style API:
// 1) Converting a `CallbackValue` to `CallbackType` (implicitly) yields a callback that invokes
//    `fn` with the callback's arguments and stores the result inside this `CallbackValue`.
// 2) `take` then retrieves that stored result.
//
// The generated callback refers back to this object, so this object must outlive every callback
// produced from it.
//
// This class is thread compatible: concurrent use of a single instance requires external
// synchronization.
template <typename ReturnType, typename... ArgTypes>
class CallbackValue final {
  public:
    // Pointer to a free function taking `ArgTypes...` and returning `ReturnType`.
    using FunctionType = ReturnType (*)(ArgTypes...);
    // Type-erased callback accepting the same arguments, with the return value discarded.
    using CallbackType = std::function<void(ArgTypes...)>;

    explicit CallbackValue(FunctionType fn);

    // Produces a callback that calls `mFunction` with its arguments and records the result in
    // `mReturnValue`.
    /*implicit*/ operator CallbackType();  // NOLINT(google-explicit-constructor)

    // Moves the recorded result out of this object.
    // Precondition: mReturnValue.has_value()
    // Postcondition: !mReturnValue.has_value()
    [[nodiscard]] ReturnType take();

  private:
    // Result of the most recent callback invocation. It is empty until the callback runs and
    // again after `take`.
    std::optional<ReturnType> mReturnValue;
    // Function invoked by the generated callback.
    FunctionType mFunction;
};
58 
59 // Deduction guidelines for CallbackValue when constructed with a function pointer.
60 template <typename ReturnType, typename... ArgTypes>
61 CallbackValue(ReturnType (*)(ArgTypes...))->CallbackValue<ReturnType, ArgTypes...>;
62 
63 // Thread-safe container to pass a value between threads.
64 template <typename Type>
65 class TransferValue final {
66   public:
67     // Put the value in `TransferValue`. If `TransferValue` already has a value, this function is a
68     // no-op.
69     void put(Type object) const;
70 
71     // Take the value stored in `TransferValue`. If no value is available, this function will block
72     // until the value becomes available.
73     // Postcondition: !mObject.has_value()
74     [[nodiscard]] Type take() const;
75 
76   private:
77     mutable std::mutex mMutex;
78     mutable std::condition_variable mCondition;
79     mutable std::optional<Type> mObject GUARDED_BY(mMutex);
80 };
81 
82 // template implementations
83 
84 template <typename ReturnType, typename... ArgTypes>
CallbackValue(FunctionType fn)85 CallbackValue<ReturnType, ArgTypes...>::CallbackValue(FunctionType fn) : mFunction(fn) {}
86 
87 template <typename ReturnType, typename... ArgTypes>
CallbackType()88 CallbackValue<ReturnType, ArgTypes...>::operator CallbackType() {
89     return [this](ArgTypes... args) { mReturnValue = mFunction(args...); };
90 }
91 
92 template <typename ReturnType, typename... ArgTypes>
take()93 ReturnType CallbackValue<ReturnType, ArgTypes...>::take() {
94     CHECK(mReturnValue.has_value());
95     std::optional<ReturnType> object;
96     std::swap(object, mReturnValue);
97     return std::move(object).value();
98 }
99 
100 template <typename Type>
put(Type object)101 void TransferValue<Type>::put(Type object) const {
102     {
103         std::lock_guard guard(mMutex);
104         // Immediately return if value already exists.
105         if (mObject.has_value()) return;
106         mObject.emplace(std::move(object));
107     }
108     mCondition.notify_all();
109 }
110 
111 template <typename Type>
take()112 Type TransferValue<Type>::take() const {
113     std::unique_lock lock(mMutex);
114     base::ScopedLockAssertion lockAssertion(mMutex);
115     mCondition.wait(lock, [this]() REQUIRES(mMutex) { return mObject.has_value(); });
116     CHECK(mObject.has_value());
117     std::optional<Type> object;
118     std::swap(object, mObject);
119     return std::move(object).value();
120 }
121 
122 }  // namespace android::hardware::neuralnetworks::utils
123 
124 #endif  // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_TRANSFER_VALUE_H
125