1 // Copyright 2018 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // -----------------------------------------------------------------------------
16 // mocking_bit_gen.h
17 // -----------------------------------------------------------------------------
18 //
// This file defines the `absl::MockingBitGen` class, for use as a mock within
// the Googletest testing framework. Such a mock is useful for making
// (otherwise random) Abseil distribution functions return deterministic
// values. That determinism lets tests exercise otherwise nondeterministic
// APIs.
24 //
25 // More information about the Googletest testing framework is available at
26 // https://github.com/google/googletest
27 
28 #ifndef ABSL_RANDOM_MOCKING_BIT_GEN_H_
29 #define ABSL_RANDOM_MOCKING_BIT_GEN_H_
30 
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
#include "absl/meta/type_traits.h"
#include "absl/random/distributions.h"
#include "absl/random/internal/distribution_caller.h"
#include "absl/random/internal/mocking_bit_gen_base.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "absl/utility/utility.h"
52 
53 namespace absl {
54 ABSL_NAMESPACE_BEGIN
55 
namespace random_internal {

// Forward declaration only. `MockingBitGen` (below) names this two-parameter
// template as a friend, which gives it access to MockingBitGen's private
// members such as Register().
template <class, class>
struct MockSingleOverload;

}  // namespace random_internal
62 
63 // MockingBitGen
64 //
65 // `absl::MockingBitGen` is a mock Uniform Random Bit Generator (URBG) class
66 // which can act in place of an `absl::BitGen` URBG within tests using the
67 // Googletest testing framework.
68 //
69 // Usage:
70 //
71 // Use an `absl::MockingBitGen` along with a mock distribution object (within
72 // mock_distributions.h) inside Googletest constructs such as ON_CALL(),
73 // EXPECT_TRUE(), etc. to produce deterministic results conforming to the
74 // distribution's API contract.
75 //
76 // Example:
77 //
78 //  // Mock a call to an `absl::Bernoulli` distribution using Googletest
79 //   absl::MockingBitGen bitgen;
80 //
81 //   ON_CALL(absl::MockBernoulli(), Call(bitgen, 0.5))
82 //       .WillByDefault(testing::Return(true));
83 //   EXPECT_TRUE(absl::Bernoulli(bitgen, 0.5));
84 //
85 //  // Mock a call to an `absl::Uniform` distribution within Googletest
86 //  absl::MockingBitGen bitgen;
87 //
88 //   ON_CALL(absl::MockUniform<int>(), Call(bitgen, testing::_, testing::_))
89 //       .WillByDefault([] (int low, int high) {
90 //           return (low + high) / 2;
91 //       });
92 //
93 //   EXPECT_EQ(absl::Uniform<int>(gen, 0, 10), 5);
94 //   EXPECT_EQ(absl::Uniform<int>(gen, 30, 40), 35);
95 //
96 // At this time, only mock distributions supplied within the Abseil random
97 // library are officially supported.
98 //
99 class MockingBitGen : public absl::random_internal::MockingBitGenBase {
100  public:
MockingBitGen()101   MockingBitGen() {}
102 
103   ~MockingBitGen() override;
104 
105  private:
106   template <typename DistrT, typename... Args>
107   using MockFnType =
108       ::testing::MockFunction<typename DistrT::result_type(Args...)>;
109 
110   // MockingBitGen::Register
111   //
112   // Register<DistrT, FormatT, ArgTupleT> is the main extension point for
113   // extending the MockingBitGen framework. It provides a mechanism to install a
114   // mock expectation for the distribution `distr_t` onto the MockingBitGen
115   // context.
116   //
117   // The returned MockFunction<...> type can be used to setup additional
118   // distribution parameters of the expectation.
119   template <typename DistrT, typename... Args, typename... Ms>
120   decltype(std::declval<MockFnType<DistrT, Args...>>().gmock_Call(
121       std::declval<Ms>()...))
Register(Ms &&...matchers)122   Register(Ms&&... matchers) {
123     auto& mock =
124         mocks_[std::type_index(GetTypeId<DistrT, std::tuple<Args...>>())];
125 
126     if (!mock.mock_fn) {
127       auto* mock_fn = new MockFnType<DistrT, Args...>;
128       mock.mock_fn = mock_fn;
129       mock.match_impl = &MatchImpl<DistrT, Args...>;
130       deleters_.emplace_back([mock_fn] { delete mock_fn; });
131     }
132 
133     return static_cast<MockFnType<DistrT, Args...>*>(mock.mock_fn)
134         ->gmock_Call(std::forward<Ms>(matchers)...);
135   }
136 
137   mutable std::vector<std::function<void()>> deleters_;
138 
139   using match_impl_fn = void (*)(void* mock_fn, void* t_erased_dist_args,
140                                  void* t_erased_result);
141   struct MockData {
142     void* mock_fn = nullptr;
143     match_impl_fn match_impl = nullptr;
144   };
145 
146   mutable absl::flat_hash_map<std::type_index, MockData> mocks_;
147 
148   template <typename DistrT, typename... Args>
MatchImpl(void * mock_fn,void * dist_args,void * result)149   static void MatchImpl(void* mock_fn, void* dist_args, void* result) {
150     using result_type = typename DistrT::result_type;
151     *static_cast<result_type*>(result) = absl::apply(
152         [mock_fn](Args... args) -> result_type {
153           return (*static_cast<MockFnType<DistrT, Args...>*>(mock_fn))
154               .Call(std::move(args)...);
155         },
156         *static_cast<std::tuple<Args...>*>(dist_args));
157   }
158 
159   // Looks for an appropriate mock - Returns the mocked result if one is found.
160   // Otherwise, returns a random value generated by the underlying URBG.
CallImpl(const std::type_info & key_type,void * dist_args,void * result)161   bool CallImpl(const std::type_info& key_type, void* dist_args,
162                 void* result) override {
163     // Trigger a mock, if there exists one that matches `param`.
164     auto it = mocks_.find(std::type_index(key_type));
165     if (it == mocks_.end()) return false;
166     auto* mock_data = static_cast<MockData*>(&it->second);
167     mock_data->match_impl(mock_data->mock_fn, dist_args, result);
168     return true;
169   }
170 
171   template <typename, typename>
172   friend struct ::absl::random_internal::MockSingleOverload;
173   friend struct ::absl::random_internal::DistributionCaller<
174       absl::MockingBitGen>;
175 };
176 
177 // -----------------------------------------------------------------------------
178 // Implementation Details Only Below
179 // -----------------------------------------------------------------------------
180 
181 namespace random_internal {
182 
183 template <>
184 struct DistributionCaller<absl::MockingBitGen> {
185   template <typename DistrT, typename FormatT, typename... Args>
186   static typename DistrT::result_type Call(absl::MockingBitGen* gen,
187                                            Args&&... args) {
188     return gen->template Call<DistrT, FormatT>(std::forward<Args>(args)...);
189   }
190 };
191 
192 }  // namespace random_internal
193 ABSL_NAMESPACE_END
194 }  // namespace absl
195 
196 #endif  // ABSL_RANDOM_MOCKING_BIT_GEN_H_
197