1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H
18 #define __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H
19 
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
27 
28 using namespace ::std;
29 using namespace ::std::chrono;
30 
// Name under which a callback is tracked when the caller does not supply one.
constexpr char kVtsHalHidlTargetCallbackDefaultName[] =
    "VtsHalHidlTargetCallbackDefaultName";
// Initial value of the per-object default wait timeout (one minute).
constexpr std::chrono::milliseconds DEFAULT_CALLBACK_WAIT_TIMEOUT_INITIAL =
    std::chrono::minutes(1);

namespace testing {

/*
 * VTS target side test template for callbacks.
 *
 * Provides wait and notify functionality for callbacks.
 *
 * A typical usage looks like this:
 *
 * struct CallbackArgs {
 *   ArgType1 arg1;
 *   ArgType2 arg2;
 * };
 *
 * class MyCallback
 *     : public ::testing::VtsHalHidlTargetCallbackBase<CallbackArgs>,
 *       public CallbackInterface {
 *  public:
 *   CallbackApi1(ArgType1 arg1) {
 *     CallbackArgs data;
 *     data.arg1 = arg1;
 *     NotifyFromCallback("CallbackApi1", data);
 *   }
 *
 *   CallbackApi2(ArgType2 arg2) {
 *     CallbackArgs data;
 *     data.arg2 = arg2;
 *     NotifyFromCallback("CallbackApi2", data);
 *   }
 * };
 *
 * Test(MyTest) {
 *   CallApi1();
 *   CallApi2();
 *   auto result = cb_.WaitForCallback("CallbackApi1");
 *   // cb_ is an instance of MyCallback, result is an instance of
 *   // ::testing::VtsHalHidlTargetCallbackBase<CallbackArgs>::
 *   //     WaitForCallbackResult
 *   EXPECT_TRUE(result.no_timeout); // Check wait did not time out
 *   EXPECT_TRUE(result.args);       // Check CallbackArgs was received (not
 *                                   // nullptr). This check is optional.
 *   // Here check value of args using the pointer result.args;
 *   result = cb_.WaitForCallback("CallbackApi2");
 *   EXPECT_TRUE(result.no_timeout);
 *   // Here check value of args using the pointer result.args;
 *
 *   // Additionally, a test can wait for one of multiple callbacks.
 *   // In this case, wait will return when any of the callbacks in the provided
 *   // name list is called.
 *   result = cb_.WaitForCallbackAny(<vector_of_string>)
 *   // When vector_of_string is not provided, all callback functions will
 *   // be monitored. The name of the callback function that was invoked
 *   // is stored in result.name
 * }
 *
 * Note that the type CallbackArgsTemplateClass is the same across the whole
 * class, which means every WaitForCallback* method returns the same data type.
 *
 * All timeouts are measured against std::chrono::steady_clock, so adjusting
 * the wall clock while a test is waiting does not change how long it waits.
 */
template <class CallbackArgsTemplateClass>
class VtsHalHidlTargetCallbackBase {
 public:
  struct WaitForCallbackResult {
    WaitForCallbackResult() : no_timeout(false), args(nullptr), name("") {}

    // True when the callback was received before the wait timed out;
    // false when the wait timed out (or nothing was pending for a peek).
    bool no_timeout;
    // Arguments data from callback functions. Defaults to nullptr.
    std::shared_ptr<CallbackArgsTemplateClass> args;
    // Name of the callback. Defaults to empty string.
    std::string name;
  };

  VtsHalHidlTargetCallbackBase()
      : cb_default_wait_timeout_(DEFAULT_CALLBACK_WAIT_TIMEOUT_INITIAL) {}

  // The per-callback state is reference counted, so nothing has to be
  // released by hand here.
  virtual ~VtsHalHidlTargetCallbackBase() = default;

  /*
   * Wait for a callback function in a test.
   * Returns a WaitForCallbackResult object containing wait results.
   * If callback_function_name is not provided, a default name will be used.
   * Timeout defaults to -1 milliseconds. A negative timeout means: use the
   * timeout set for the callback, or else the default callback wait timeout.
   */
  WaitForCallbackResult WaitForCallback(
      const std::string& callback_function_name =
          kVtsHalHidlTargetCallbackDefaultName,
      std::chrono::milliseconds timeout = std::chrono::milliseconds(-1)) {
    // The temporary shared_ptr keeps the lock object alive for the whole
    // wait, even if ClearForCallback() replaces the map entry meanwhile.
    return GetCallbackLock(callback_function_name)->WaitForCallback(timeout);
  }

  /*
   * Wait for any of the callback functions specified.
   * Returns a WaitForCallbackResult object containing wait results.
   * If callback_function_names is not provided, all callback functions will
   * be monitored, and the list of callback functions will be updated
   * dynamically during run time.
   * If timeout_any is not provided (negative), the shortest timeout from the
   * function list will be used; when no callback function is known yet, the
   * default wait timeout is used.
   */
  WaitForCallbackResult WaitForCallbackAny(
      const std::vector<std::string>& callback_function_names =
          std::vector<std::string>(),
      std::chrono::milliseconds timeout_any = std::chrono::milliseconds(-1)) {
    std::unique_lock<std::mutex> lock(cb_wait_any_mtx_);

    const auto start_time = std::chrono::steady_clock::now();

    WaitForCallbackResult res = PeekCallbackLocks(callback_function_names);
    while (!res.no_timeout) {
      // Re-evaluated on every iteration because the monitored set (and thus
      // the shortest timeout) may change while waiting.
      const auto expiration =
          GetWaitAnyTimeout(callback_function_names, start_time, timeout_any);
      const auto status = cb_wait_any_cv_.wait_until(lock, expiration);
      // Peek even after a timeout: a callback may have been delivered while
      // this thread was re-acquiring the mutex, and must not be dropped.
      res = PeekCallbackLocks(callback_function_names);
      if (status == std::cv_status::timeout && !res.no_timeout) {
        std::cerr << "Timed out waiting for callback functions." << std::endl;
        break;
      }
    }
    return res;
  }

  /*
   * Notify a waiting test when a callback is invoked.
   * If callback_function_name is not provided, a default name will be used.
   */
  void NotifyFromCallback(const std::string& callback_function_name =
                              kVtsHalHidlTargetCallbackDefaultName) {
    std::unique_lock<std::mutex> lock(cb_wait_any_mtx_);
    GetCallbackLock(callback_function_name)->NotifyFromCallback();
    // Several WaitForCallbackAny() callers may be monitoring different
    // names; wake them all so the interested one is never left sleeping.
    cb_wait_any_cv_.notify_all();
  }

  /*
   * Notify a waiting test with data when a callback is invoked.
   * The default callback name is used.
   */
  void NotifyFromCallback(const CallbackArgsTemplateClass& data) {
    NotifyFromCallback(kVtsHalHidlTargetCallbackDefaultName, data);
  }

  /*
   * Notify a waiting test with data when the named callback is invoked.
   * A copy of data is queued and handed to the next successful wait.
   */
  void NotifyFromCallback(const std::string& callback_function_name,
                          const CallbackArgsTemplateClass& data) {
    std::unique_lock<std::mutex> lock(cb_wait_any_mtx_);
    GetCallbackLock(callback_function_name)->NotifyFromCallback(data);
    // See the no-data overload: all wait-any callers must re-check.
    cb_wait_any_cv_.notify_all();
  }

  /*
   * Clear lock and data for a callback function.
   * Pending notifications, queued data and the per-callback timeout are
   * discarded. This function is optional.
   */
  void ClearForCallback(const std::string& callback_function_name =
                            kVtsHalHidlTargetCallbackDefaultName) {
    GetCallbackLock(callback_function_name, true);
  }

  /*
   * Get wait timeout for a specific callback function.
   * If callback_function_name is not provided, a default name will be used.
   */
  std::chrono::milliseconds GetWaitTimeout(
      const std::string& callback_function_name =
          kVtsHalHidlTargetCallbackDefaultName) {
    return GetCallbackLock(callback_function_name)->GetWaitTimeout();
  }

  /*
   * Set wait timeout for a specific callback function.
   * To set a default timeout (not for the default function name),
   * use SetWaitTimeoutDefault. The default function name callback timeout
   * will also follow SetWaitTimeoutDefault unless set explicitly here.
   */
  void SetWaitTimeout(const std::string& callback_function_name,
                      std::chrono::milliseconds timeout) {
    GetCallbackLock(callback_function_name)->SetWaitTimeout(timeout);
  }

  /*
   * Get default wait timeout for a callback function.
   * The default timeout is valid for all callback function names that
   * have not been given a timeout value, including the default function name.
   */
  std::chrono::milliseconds GetWaitTimeoutDefault() {
    return cb_default_wait_timeout_;
  }

  /*
   * Set default wait timeout for a callback function.
   * The default timeout is valid for all callback function names that
   * have not been given a timeout value, including the default function name.
   */
  void SetWaitTimeoutDefault(std::chrono::milliseconds timeout) {
    cb_default_wait_timeout_ = timeout;
  }

 private:
  /*
   * A utility class to store the notification count and data for one
   * callback name.
   */
  class CallbackLock {
   public:
    CallbackLock(VtsHalHidlTargetCallbackBase& parent, const std::string& name)
        : wait_count_(0),
          parent_(parent),
          timeout_(std::chrono::milliseconds(-1)),
          name_(name) {}

    /*
     * Wait for represented callback function.
     * Timeout defaults to -1 milliseconds. A negative timeout means: use the
     * timeout set for the callback, or else the default wait timeout.
     * When no_wait_blocking is true the call only peeks and never blocks.
     */
    WaitForCallbackResult WaitForCallback(
        std::chrono::milliseconds timeout = std::chrono::milliseconds(-1),
        bool no_wait_blocking = false) {
      return Wait(timeout, no_wait_blocking);
    }

    /*
     * Wait for represented callback function using the configured timeout.
     * When no_wait_blocking is true the call only peeks and never blocks.
     */
    WaitForCallbackResult WaitForCallback(bool no_wait_blocking) {
      return Wait(std::chrono::milliseconds(-1), no_wait_blocking);
    }

    /* Notify from represented callback function. */
    void NotifyFromCallback() {
      std::lock_guard<std::mutex> guard(wait_mtx_);
      Notify();
    }

    /* Notify from represented callback function with data. */
    void NotifyFromCallback(const CallbackArgsTemplateClass& data) {
      std::lock_guard<std::mutex> guard(wait_mtx_);
      arg_data_.push(std::make_shared<CallbackArgsTemplateClass>(data));
      Notify();
    }

    /* Set wait timeout for represented callback function. */
    void SetWaitTimeout(std::chrono::milliseconds timeout) {
      timeout_ = timeout;
    }

    /*
     * Get wait timeout for represented callback function.
     * Falls back to the parent's default while no non-negative timeout has
     * been set for this callback.
     */
    std::chrono::milliseconds GetWaitTimeout() {
      if (timeout_ < std::chrono::milliseconds(0)) {
        return parent_.GetWaitTimeoutDefault();
      }
      return timeout_;
    }

   private:
    /*
     * Wait for represented callback function in a test.
     * Returns a WaitForCallbackResult object containing wait results; its
     * name is always set, no_timeout/args only when a notification was
     * consumed.
     * A negative timeout means: use the timeout set for the callback, or
     * else the default wait timeout.
     */
    WaitForCallbackResult Wait(std::chrono::milliseconds timeout,
                               bool no_wait_blocking) {
      std::unique_lock<std::mutex> lock(wait_mtx_);
      WaitForCallbackResult res;
      res.name = name_;

      if (no_wait_blocking) {
        // Peek only: report "nothing yet" without blocking.
        if (wait_count_ == 0) {
          return res;
        }
      } else {
        if (timeout < std::chrono::milliseconds(0)) {
          timeout = GetWaitTimeout();
        }
        const auto expiration = std::chrono::steady_clock::now() + timeout;
        // The predicate form re-checks the count after a timeout, so a
        // notification racing with the deadline is still consumed.
        const bool notified = wait_cv_.wait_until(
            lock, expiration, [this] { return wait_count_ != 0; });
        if (!notified) {
          std::cerr << "Timed out waiting for callback" << std::endl;
          return res;
        }
      }

      --wait_count_;
      res.no_timeout = true;
      // Notifications sent without data leave the queue untouched.
      if (!arg_data_.empty()) {
        res.args = arg_data_.front();
        arg_data_.pop();
      }
      return res;
    }

    /* Record one notification and wake a waiter. Caller holds wait_mtx_. */
    void Notify() {
      ++wait_count_;
      wait_cv_.notify_one();
    }

    // Mutex for protecting operations on wait count and conditional variable
    std::mutex wait_mtx_;
    // Conditional variable for callback wait and notify
    std::condition_variable wait_cv_;
    // Number of notifications not yet consumed by a wait
    unsigned int wait_count_;
    // A queue of callback arg data
    std::queue<std::shared_ptr<CallbackArgsTemplateClass>> arg_data_;
    // Reference to parent object (used for the default timeout)
    VtsHalHidlTargetCallbackBase& parent_;
    // Wait time out; negative means "use the parent's default"
    std::chrono::milliseconds timeout_;
    // Name of the represented callback function
    std::string name_;
  };

  /*
   * Get CallbackLock object using callback function name.
   * If callback_function_name does not exist in the map yet, a new
   * CallbackLock object will be created.
   * If auto_clear is true, the old CallbackLock is replaced by a fresh one.
   * Threads still waiting on the old object keep it alive through their own
   * shared_ptr, so replacing it never frees memory that is in use.
   */
  std::shared_ptr<CallbackLock> GetCallbackLock(
      const std::string& callback_function_name, bool auto_clear = false) {
    std::lock_guard<std::mutex> guard(cb_lock_map_mtx_);
    std::shared_ptr<CallbackLock>& slot = cb_lock_map_[callback_function_name];
    if (!slot || auto_clear) {
      slot = std::make_shared<CallbackLock>(*this, callback_function_name);
    }
    return slot;
  }

  /*
   * Get the wait deadline for a list of function names.
   * If timeout_any is not negative, start_time + timeout_any is returned.
   * Otherwise the shortest timeout from the list is used; when the list
   * resolves to no callback at all, the default wait timeout is used.
   */
  std::chrono::steady_clock::time_point GetWaitAnyTimeout(
      const std::vector<std::string>& callback_function_names,
      std::chrono::steady_clock::time_point start_time,
      std::chrono::milliseconds timeout_any) {
    if (timeout_any >= std::chrono::milliseconds(0)) {
      return start_time + timeout_any;
    }

    const auto locks = GetWaitAnyCallbackLocks(callback_function_names);
    if (locks.empty()) {
      // Nothing registered yet: adding duration::max() to start_time would
      // overflow, so wait for the default timeout instead.
      return start_time + GetWaitTimeoutDefault();
    }

    std::chrono::milliseconds timeout_min = locks.front()->GetWaitTimeout();
    for (const auto& lock : locks) {
      const std::chrono::milliseconds candidate = lock->GetWaitTimeout();
      if (candidate < timeout_min) {
        timeout_min = candidate;
      }
    }
    return start_time + timeout_min;
  }

  /*
   * Get a list of CallbackLock pointers from provided function name list.
   * An empty name list selects every callback currently known.
   */
  std::vector<std::shared_ptr<CallbackLock>> GetWaitAnyCallbackLocks(
      const std::vector<std::string>& callback_function_names) {
    std::vector<std::shared_ptr<CallbackLock>> res;
    if (callback_function_names.empty()) {
      // Hold the map mutex while iterating; other threads may insert.
      std::lock_guard<std::mutex> guard(cb_lock_map_mtx_);
      res.reserve(cb_lock_map_.size());
      for (const auto& entry : cb_lock_map_) {
        res.push_back(entry.second);
      }
    } else {
      res.reserve(callback_function_names.size());
      // GetCallbackLock takes the map mutex itself.
      for (const auto& name : callback_function_names) {
        res.push_back(GetCallbackLock(name));
      }
    }
    return res;
  }

  /*
   * Peek into the list of callback locks to check whether any of the
   * callback functions has been called. The first pending notification
   * found is consumed and returned; otherwise a default-constructed
   * result (no_timeout == false) is returned.
   */
  WaitForCallbackResult PeekCallbackLocks(
      const std::vector<std::string>& callback_function_names) {
    const auto locks = GetWaitAnyCallbackLocks(callback_function_names);
    for (const auto& lock : locks) {
      WaitForCallbackResult peeked = lock->WaitForCallback(true);
      if (peeked.no_timeout) {
        return peeked;
      }
    }
    return WaitForCallbackResult();
  }

  // A map of function name and CallbackLock objects
  std::unordered_map<std::string, std::shared_ptr<CallbackLock>> cb_lock_map_;
  // Mutex for protecting operations on lock map
  std::mutex cb_lock_map_mtx_;
  // Mutex for protecting waiting any callback
  std::mutex cb_wait_any_mtx_;
  // Default wait timeout
  std::chrono::milliseconds cb_default_wait_timeout_;
  // Conditional variable for any callback notify
  std::condition_variable cb_wait_any_cv_;
};

}  // namespace testing
447 
448 #endif  // __VTS_HAL_HIDL_TARGET_CALLBACK_BASE_H
449