1 //
2 // Copyright 2018 The Abseil Authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      https://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_
17 #define ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_
18 
19 #include <string>
20 #include <tuple>
21 #include <typeinfo>
22 
23 #include "absl/meta/type_traits.h"
24 #include "absl/random/bernoulli_distribution.h"
25 #include "absl/random/beta_distribution.h"
26 #include "absl/random/exponential_distribution.h"
27 #include "absl/random/gaussian_distribution.h"
28 #include "absl/random/log_uniform_int_distribution.h"
29 #include "absl/random/poisson_distribution.h"
30 #include "absl/random/uniform_int_distribution.h"
31 #include "absl/random/uniform_real_distribution.h"
32 #include "absl/random/zipf_distribution.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/string_view.h"
36 #include "absl/types/span.h"
37 
38 namespace absl {
39 ABSL_NAMESPACE_BEGIN
40 
// Forward declarations of absl's interval tag types (presumably defined next
// to absl::Uniform -- confirm). Nothing in this header needs their complete
// definitions; the uniform_int_distribution traits only spell the
// closed-closed tag's name inside a formatted string.
struct IntervalClosedClosedTag;
struct IntervalClosedOpenTag;
struct IntervalOpenClosedTag;
struct IntervalOpenOpenTag;
45 
46 namespace random_internal {
47 
// ScalarTypeName maps an arithmetic type T to a preferred, fixed-width-style
// name ("float", "double", "long double", "bool", "int8_t" ... "uint64_t").
// It is evaluated at compile time for each instantiation. Top-level
// cv-qualifiers are ignored, so `const float` is reported as "float".
// Integral types whose size is not 1, 2, 4 or 8 bytes map to "undefined".
//
// NOTE: It would be nice to use typeid(T).name(), but that's an
// implementation-defined attribute which does not necessarily correspond to a
// name. We could potentially demangle it using, e.g. abi::__cxa_demangle.
template <typename T>
constexpr const char* ScalarTypeName() {
  static_assert(std::is_integral<T>() || std::is_floating_point<T>(),
                "ScalarTypeName requires an integral or floating-point type");
  // std::is_same is cv-sensitive while std::is_signed and sizeof are not.
  // Without stripping qualifiers, `const double` would miss the
  // floating-point names below and be misreported as "int64_t".
  using U = typename std::remove_cv<T>::type;
  // clang-format off
  return
      std::is_same<U, float>::value ? "float" :
      std::is_same<U, double>::value ? "double" :
      std::is_same<U, long double>::value ? "long double" :
      std::is_same<U, bool>::value ? "bool" :
      std::is_signed<U>::value && sizeof(U) == 1 ? "int8_t" :
      std::is_signed<U>::value && sizeof(U) == 2 ? "int16_t" :
      std::is_signed<U>::value && sizeof(U) == 4 ? "int32_t" :
      std::is_signed<U>::value && sizeof(U) == 8 ? "int64_t" :
      std::is_unsigned<U>::value && sizeof(U) == 1 ? "uint8_t" :
      std::is_unsigned<U>::value && sizeof(U) == 2 ? "uint16_t" :
      std::is_unsigned<U>::value && sizeof(U) == 4 ? "uint32_t" :
      std::is_unsigned<U>::value && sizeof(U) == 8 ? "uint64_t" :
          "undefined";
  // clang-format on
}
76 
// Distribution traits used by DistributionCaller and internal implementation
// details of the mocking framework.
//
// The primary template is only declared, never defined: every supported
// distribution provides a specialization below, and naming the traits of any
// other type is a compile-time error (incomplete type).
//
// Each specialization provides the following interface:
/*
struct DistributionFormatTraits<DistrT> {
   // The distribution type, and the type of a single sampled value.
   using distribution_t = DistrT;
   using result_t = ...;

   // Returns the base name of the distribution function, e.g. "Uniform".
   static constexpr const char* Name();
   // Returns the parameterized name of the distribution function, e.g.
   // "Uniform<int32_t>". Usually a std::string; the bernoulli specialization,
   // which has no scalar type parameter, returns a constexpr const char*.
   static std::string FunctionName();
   // Formats the parameters of `dist` as a ", "-separated argument list.
   static std::string FormatArgs(const distribution_t& dist);
   // Formats sampled values as a ", "-separated list.
   static std::string FormatResults(absl::Span<const result_t> results);
};
*/
template <typename DistrT>
struct DistributionFormatTraits;
91 
92 template <typename R>
93 struct DistributionFormatTraits<absl::uniform_int_distribution<R>> {
94   using distribution_t = absl::uniform_int_distribution<R>;
95   using result_t = typename distribution_t::result_type;
96 
97   static constexpr const char* Name() { return "Uniform"; }
98 
99   static std::string FunctionName() {
100     return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
101   }
102   static std::string FormatArgs(const distribution_t& d) {
103     return absl::StrCat("absl::IntervalClosedClosed, ", (d.min)(), ", ",
104                         (d.max)());
105   }
106   static std::string FormatResults(absl::Span<const result_t> results) {
107     return absl::StrJoin(results, ", ");
108   }
109 };
110 
111 template <typename R>
112 struct DistributionFormatTraits<absl::uniform_real_distribution<R>> {
113   using distribution_t = absl::uniform_real_distribution<R>;
114   using result_t = typename distribution_t::result_type;
115 
116   static constexpr const char* Name() { return "Uniform"; }
117 
118   static std::string FunctionName() {
119     return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
120   }
121   static std::string FormatArgs(const distribution_t& d) {
122     return absl::StrCat((d.min)(), ", ", (d.max)());
123   }
124   static std::string FormatResults(absl::Span<const result_t> results) {
125     return absl::StrJoin(results, ", ");
126   }
127 };
128 
129 template <typename R>
130 struct DistributionFormatTraits<absl::exponential_distribution<R>> {
131   using distribution_t = absl::exponential_distribution<R>;
132   using result_t = typename distribution_t::result_type;
133 
134   static constexpr const char* Name() { return "Exponential"; }
135 
136   static std::string FunctionName() {
137     return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
138   }
139   static std::string FormatArgs(const distribution_t& d) {
140     return absl::StrCat(d.lambda());
141   }
142   static std::string FormatResults(absl::Span<const result_t> results) {
143     return absl::StrJoin(results, ", ");
144   }
145 };
146 
147 template <typename R>
148 struct DistributionFormatTraits<absl::poisson_distribution<R>> {
149   using distribution_t = absl::poisson_distribution<R>;
150   using result_t = typename distribution_t::result_type;
151 
152   static constexpr const char* Name() { return "Poisson"; }
153 
154   static std::string FunctionName() {
155     return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
156   }
157   static std::string FormatArgs(const distribution_t& d) {
158     return absl::StrCat(d.mean());
159   }
160   static std::string FormatResults(absl::Span<const result_t> results) {
161     return absl::StrJoin(results, ", ");
162   }
163 };
164 
165 template <>
166 struct DistributionFormatTraits<absl::bernoulli_distribution> {
167   using distribution_t = absl::bernoulli_distribution;
168   using result_t = typename distribution_t::result_type;
169 
170   static constexpr const char* Name() { return "Bernoulli"; }
171 
172   static constexpr const char* FunctionName() { return Name(); }
173   static std::string FormatArgs(const distribution_t& d) {
174     return absl::StrCat(d.p());
175   }
176   static std::string FormatResults(absl::Span<const result_t> results) {
177     return absl::StrJoin(results, ", ");
178   }
179 };
180 
181 template <typename R>
182 struct DistributionFormatTraits<absl::beta_distribution<R>> {
183   using distribution_t = absl::beta_distribution<R>;
184   using result_t = typename distribution_t::result_type;
185 
186   static constexpr const char* Name() { return "Beta"; }
187 
188   static std::string FunctionName() {
189     return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
190   }
191   static std::string FormatArgs(const distribution_t& d) {
192     return absl::StrCat(d.alpha(), ", ", d.beta());
193   }
194   static std::string FormatResults(absl::Span<const result_t> results) {
195     return absl::StrJoin(results, ", ");
196   }
197 };
198 
199 template <typename R>
200 struct DistributionFormatTraits<absl::zipf_distribution<R>> {
201   using distribution_t = absl::zipf_distribution<R>;
202   using result_t = typename distribution_t::result_type;
203 
204   static constexpr const char* Name() { return "Zipf"; }
205 
206   static std::string FunctionName() {
207     return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
208   }
209   static std::string FormatArgs(const distribution_t& d) {
210     return absl::StrCat(d.k(), ", ", d.v(), ", ", d.q());
211   }
212   static std::string FormatResults(absl::Span<const result_t> results) {
213     return absl::StrJoin(results, ", ");
214   }
215 };
216 
217 template <typename R>
218 struct DistributionFormatTraits<absl::gaussian_distribution<R>> {
219   using distribution_t = absl::gaussian_distribution<R>;
220   using result_t = typename distribution_t::result_type;
221 
222   static constexpr const char* Name() { return "Gaussian"; }
223 
224   static std::string FunctionName() {
225     return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
226   }
227   static std::string FormatArgs(const distribution_t& d) {
228     return absl::StrJoin(std::make_tuple(d.mean(), d.stddev()), ", ");
229   }
230   static std::string FormatResults(absl::Span<const result_t> results) {
231     return absl::StrJoin(results, ", ");
232   }
233 };
234 
235 template <typename R>
236 struct DistributionFormatTraits<absl::log_uniform_int_distribution<R>> {
237   using distribution_t = absl::log_uniform_int_distribution<R>;
238   using result_t = typename distribution_t::result_type;
239 
240   static constexpr const char* Name() { return "LogUniform"; }
241 
242   static std::string FunctionName() {
243     return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
244   }
245   static std::string FormatArgs(const distribution_t& d) {
246     return absl::StrJoin(std::make_tuple((d.min)(), (d.max)(), d.base()), ", ");
247   }
248   static std::string FormatResults(absl::Span<const result_t> results) {
249     return absl::StrJoin(results, ", ");
250   }
251 };
252 
253 template <typename NumType>
254 struct UniformDistributionWrapper;
255 
256 template <typename NumType>
257 struct DistributionFormatTraits<UniformDistributionWrapper<NumType>> {
258   using distribution_t = UniformDistributionWrapper<NumType>;
259   using result_t = NumType;
260 
261   static constexpr const char* Name() { return "Uniform"; }
262 
263   static std::string FunctionName() {
264     return absl::StrCat(Name(), "<", ScalarTypeName<NumType>(), ">");
265   }
266   static std::string FormatArgs(const distribution_t& d) {
267     return absl::StrCat((d.min)(), ", ", (d.max)());
268   }
269   static std::string FormatResults(absl::Span<const result_t> results) {
270     return absl::StrJoin(results, ", ");
271   }
272 };
273 
274 }  // namespace random_internal
275 ABSL_NAMESPACE_END
276 }  // namespace absl
277 
278 #endif  // ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_
279