1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
18 
19 #include <initializer_list>
20 #include <memory>
21 #include <random>
22 #include <string>
23 
24 #include "absl/types/optional.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/array2d.h"
27 #include "tensorflow/compiler/xla/array3d.h"
28 #include "tensorflow/compiler/xla/array4d.h"
29 #include "tensorflow/compiler/xla/error_spec.h"
30 #include "tensorflow/compiler/xla/literal.h"
31 #include "tensorflow/compiler/xla/literal_util.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40 
41 namespace xla {
42 
43 // Utility class for making expectations/assertions related to XLA literals.
44 class LiteralTestUtil {
45  public:
46   // Asserts that the given shapes have the same rank, dimension sizes, and
47   // primitive types.
48   static ::testing::AssertionResult EqualShapes(
49       const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT;
50 
51   // Asserts that the provided shapes are equal as defined in AssertEqualShapes
52   // and that they have the same layout.
53   static ::testing::AssertionResult EqualShapesAndLayouts(
54       const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT;
55 
56   static ::testing::AssertionResult Equal(const LiteralSlice& expected,
57                                           const LiteralSlice& actual)
58       TF_MUST_USE_RESULT;
59 
60   // Asserts the given literal are (bitwise) equal to given expected values.
61   template <typename NativeT>
62   static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
63 
64   template <typename NativeT>
65   static void ExpectR1Equal(absl::Span<const NativeT> expected,
66                             const LiteralSlice& actual);
67   template <typename NativeT>
68   static void ExpectR2Equal(
69       std::initializer_list<std::initializer_list<NativeT>> expected,
70       const LiteralSlice& actual);
71 
72   template <typename NativeT>
73   static void ExpectR3Equal(
74       std::initializer_list<
75           std::initializer_list<std::initializer_list<NativeT>>>
76           expected,
77       const LiteralSlice& actual);
78 
79   // Asserts the given literal are (bitwise) equal to given array.
80   template <typename NativeT>
81   static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
82                                    const LiteralSlice& actual);
83   template <typename NativeT>
84   static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
85                                    const LiteralSlice& actual);
86   template <typename NativeT>
87   static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
88                                    const LiteralSlice& actual);
89 
90   // Decorates literal_comparison::Near() with an AssertionResult return type.
91   //
92   // See comment on literal_comparison::Near().
93   static ::testing::AssertionResult Near(
94       const LiteralSlice& expected, const LiteralSlice& actual,
95       const ErrorSpec& error_spec,
96       absl::optional<bool> detailed_message = absl::nullopt) TF_MUST_USE_RESULT;
97 
98   // Asserts the given literal are within the given error bound of the given
99   // expected values. Only supported for floating point values.
100   template <typename NativeT>
101   static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
102                            const ErrorSpec& error);
103 
104   template <typename NativeT>
105   static void ExpectR1Near(absl::Span<const NativeT> expected,
106                            const LiteralSlice& actual, const ErrorSpec& error);
107 
108   template <typename NativeT>
109   static void ExpectR2Near(
110       std::initializer_list<std::initializer_list<NativeT>> expected,
111       const LiteralSlice& actual, const ErrorSpec& error);
112 
113   template <typename NativeT>
114   static void ExpectR3Near(
115       std::initializer_list<
116           std::initializer_list<std::initializer_list<NativeT>>>
117           expected,
118       const LiteralSlice& actual, const ErrorSpec& error);
119 
120   template <typename NativeT>
121   static void ExpectR4Near(
122       std::initializer_list<std::initializer_list<
123           std::initializer_list<std::initializer_list<NativeT>>>>
124           expected,
125       const LiteralSlice& actual, const ErrorSpec& error);
126 
127   // Asserts the given literal are within the given error bound to the given
128   // array. Only supported for floating point values.
129   template <typename NativeT>
130   static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
131                                   const LiteralSlice& actual,
132                                   const ErrorSpec& error);
133 
134   template <typename NativeT>
135   static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
136                                   const LiteralSlice& actual,
137                                   const ErrorSpec& error);
138 
139   template <typename NativeT>
140   static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
141                                   const LiteralSlice& actual,
142                                   const ErrorSpec& error);
143 
144   // If the error spec is given, returns whether the expected and the actual are
145   // within the error bound; otherwise, returns whether they are equal. Tuples
146   // will be compared recursively.
147   static ::testing::AssertionResult NearOrEqual(
148       const LiteralSlice& expected, const LiteralSlice& actual,
149       const absl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
150 
151  private:
152   TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
153 };
154 
155 template <typename NativeT>
ExpectR0Equal(NativeT expected,const LiteralSlice & actual)156 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
157                                                  const LiteralSlice& actual) {
158   EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
159 }
160 
161 template <typename NativeT>
ExpectR1Equal(absl::Span<const NativeT> expected,const LiteralSlice & actual)162 /* static */ void LiteralTestUtil::ExpectR1Equal(
163     absl::Span<const NativeT> expected, const LiteralSlice& actual) {
164   EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
165 }
166 
167 template <typename NativeT>
ExpectR2Equal(std::initializer_list<std::initializer_list<NativeT>> expected,const LiteralSlice & actual)168 /* static */ void LiteralTestUtil::ExpectR2Equal(
169     std::initializer_list<std::initializer_list<NativeT>> expected,
170     const LiteralSlice& actual) {
171   EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
172 }
173 
174 template <typename NativeT>
ExpectR3Equal(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> expected,const LiteralSlice & actual)175 /* static */ void LiteralTestUtil::ExpectR3Equal(
176     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
177         expected,
178     const LiteralSlice& actual) {
179   EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
180 }
181 
182 template <typename NativeT>
ExpectR2EqualArray2D(const Array2D<NativeT> & expected,const LiteralSlice & actual)183 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
184     const Array2D<NativeT>& expected, const LiteralSlice& actual) {
185   EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
186 }
187 
188 template <typename NativeT>
ExpectR3EqualArray3D(const Array3D<NativeT> & expected,const LiteralSlice & actual)189 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
190     const Array3D<NativeT>& expected, const LiteralSlice& actual) {
191   EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
192 }
193 
194 template <typename NativeT>
ExpectR4EqualArray4D(const Array4D<NativeT> & expected,const LiteralSlice & actual)195 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
196     const Array4D<NativeT>& expected, const LiteralSlice& actual) {
197   EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
198 }
199 
200 template <typename NativeT>
ExpectR0Near(NativeT expected,const LiteralSlice & actual,const ErrorSpec & error)201 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
202                                                 const LiteralSlice& actual,
203                                                 const ErrorSpec& error) {
204   EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
205 }
206 
207 template <typename NativeT>
ExpectR1Near(absl::Span<const NativeT> expected,const LiteralSlice & actual,const ErrorSpec & error)208 /* static */ void LiteralTestUtil::ExpectR1Near(
209     absl::Span<const NativeT> expected, const LiteralSlice& actual,
210     const ErrorSpec& error) {
211   EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
212 }
213 
214 template <typename NativeT>
ExpectR2Near(std::initializer_list<std::initializer_list<NativeT>> expected,const LiteralSlice & actual,const ErrorSpec & error)215 /* static */ void LiteralTestUtil::ExpectR2Near(
216     std::initializer_list<std::initializer_list<NativeT>> expected,
217     const LiteralSlice& actual, const ErrorSpec& error) {
218   EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
219 }
220 
221 template <typename NativeT>
ExpectR3Near(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> expected,const LiteralSlice & actual,const ErrorSpec & error)222 /* static */ void LiteralTestUtil::ExpectR3Near(
223     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
224         expected,
225     const LiteralSlice& actual, const ErrorSpec& error) {
226   EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
227 }
228 
229 template <typename NativeT>
ExpectR4Near(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> expected,const LiteralSlice & actual,const ErrorSpec & error)230 /* static */ void LiteralTestUtil::ExpectR4Near(
231     std::initializer_list<std::initializer_list<
232         std::initializer_list<std::initializer_list<NativeT>>>>
233         expected,
234     const LiteralSlice& actual, const ErrorSpec& error) {
235   EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
236 }
237 
238 template <typename NativeT>
ExpectR2NearArray2D(const Array2D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)239 /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
240     const Array2D<NativeT>& expected, const LiteralSlice& actual,
241     const ErrorSpec& error) {
242   EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
243 }
244 
245 template <typename NativeT>
ExpectR3NearArray3D(const Array3D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)246 /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
247     const Array3D<NativeT>& expected, const LiteralSlice& actual,
248     const ErrorSpec& error) {
249   EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
250 }
251 
252 template <typename NativeT>
ExpectR4NearArray4D(const Array4D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)253 /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
254     const Array4D<NativeT>& expected, const LiteralSlice& actual,
255     const ErrorSpec& error) {
256   EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
257 }
258 
259 }  // namespace xla
260 
261 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
262