1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal_comparison.h"
17
18 #include <unistd.h>
19 #include <cmath>
20 #include <vector>
21
22 #include "absl/base/casts.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/core/platform/env.h"
28
29 using absl::StrAppend;
30 using absl::StrAppendFormat;
31 using absl::StrCat;
32
33 namespace xla {
34 namespace literal_comparison {
35 namespace {
36
// Eigen::half does not satisfy the absl::bit_cast contract, so callers that
// want the underlying bits go through GetRawValue() to obtain something that
// can be bit-cast. This generic overload is the identity: the value itself is
// already suitable and is handed back unchanged.
template <typename T>
T GetRawValue(T val) { return val; }
GetRawValue(Eigen::half val)43 uint16 GetRawValue(Eigen::half val) { return val.x; }
44
45 // Helper function for comparing a floating point type, FloatT, bitwise equal
46 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
47 // -- on miscompare, a nice error message is given in the AssertionFailure.
48 template <typename FloatT, typename UnsignedT>
CompareFloatsBitwiseEqual(FloatT lhs,FloatT rhs,absl::Span<const int64> multi_index)49 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
50 absl::Span<const int64> multi_index) {
51 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
52 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
53 return ulhs == urhs;
54 }
55
56 // Templated comparator that specializes for float equality comparison with the
57 // bitwise helper above (this is the un-specialized fallback, to just use the
58 // default gunit implementation).
59 template <typename NativeT>
CompareEqual(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)60 bool CompareEqual(NativeT lhs, NativeT rhs,
61 absl::Span<const int64> multi_index) {
62 return lhs == rhs;
63 }
64
65 // Specializations for floating types that do bitwise comparisons when equality
66 // comparison is requested.
67 template <>
CompareEqual(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)68 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
69 absl::Span<const int64> multi_index) {
70 return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
71 }
72 template <>
CompareEqual(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)73 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
74 absl::Span<const int64> multi_index) {
75 return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
76 }
77 template <>
CompareEqual(float lhs,float rhs,absl::Span<const int64> multi_index)78 bool CompareEqual<float>(float lhs, float rhs,
79 absl::Span<const int64> multi_index) {
80 return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
81 }
82 template <>
CompareEqual(double lhs,double rhs,absl::Span<const int64> multi_index)83 bool CompareEqual<double>(double lhs, double rhs,
84 absl::Span<const int64> multi_index) {
85 return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
86 }
87 template <>
CompareEqual(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)88 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
89 absl::Span<const int64> multi_index) {
90 return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
91 CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
92 }
93 template <>
CompareEqual(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)94 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
95 absl::Span<const int64> multi_index) {
96 return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
97 CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
98 }
99
100 template <typename NativeT, typename UnsignedT>
MakeBitwiseErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)101 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
102 absl::Span<const int64> multi_index) {
103 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
104 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
105 auto lhs_double = static_cast<double>(lhs);
106 auto rhs_double = static_cast<double>(rhs);
107 return InvalidArgument(
108 "floating values are not bitwise-equal; and equality testing "
109 "was requested: %s=%g=%a vs %s=%g=%a at array index %s",
110 StrCat(absl::Hex(ulhs)), lhs_double, lhs_double,
111 StrCat(absl::Hex(urhs)), rhs_double, rhs_double,
112 LiteralUtil::MultiIndexAsString(multi_index));
113 }
114
115 template <typename NativeT>
MakeErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)116 Status MakeErrorStatus(NativeT lhs, NativeT rhs,
117 absl::Span<const int64> multi_index) {
118 return InvalidArgument(
119 "first mismatch at array index %s:\n expected value: %s\n actual "
120 "value: %s",
121 LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
122 }
123
124 template <>
MakeErrorStatus(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)125 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
126 absl::Span<const int64> multi_index) {
127 return MakeBitwiseErrorStatus<bfloat16, uint16>(lhs, rhs, multi_index);
128 }
129 template <>
MakeErrorStatus(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)130 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs,
131 absl::Span<const int64> multi_index) {
132 return MakeBitwiseErrorStatus<Eigen::half, uint16>(lhs, rhs, multi_index);
133 }
134 template <>
MakeErrorStatus(float lhs,float rhs,absl::Span<const int64> multi_index)135 Status MakeErrorStatus(float lhs, float rhs,
136 absl::Span<const int64> multi_index) {
137 return MakeBitwiseErrorStatus<float, uint32>(lhs, rhs, multi_index);
138 }
139 template <>
MakeErrorStatus(double lhs,double rhs,absl::Span<const int64> multi_index)140 Status MakeErrorStatus(double lhs, double rhs,
141 absl::Span<const int64> multi_index) {
142 return MakeBitwiseErrorStatus<double, uint64>(lhs, rhs, multi_index);
143 }
144 template <>
MakeErrorStatus(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)145 Status MakeErrorStatus(complex64 lhs, complex64 rhs,
146 absl::Span<const int64> multi_index) {
147 if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) {
148 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
149 }
150 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
151 }
152 template <>
MakeErrorStatus(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)153 Status MakeErrorStatus(complex128 lhs, complex128 rhs,
154 absl::Span<const int64> multi_index) {
155 if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
156 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
157 }
158 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
159 }
160
161 // A recursive function which iterates through every index of expected and
162 // actual literal and compares their values elementwise. Returns true if all
163 // elements are equal.
164 template <typename NativeT>
Equal(LiteralSlice expected,LiteralSlice actual,absl::Span<int64> multi_index,int64 dimension)165 Status Equal(LiteralSlice expected, LiteralSlice actual,
166 absl::Span<int64> multi_index, int64 dimension) {
167 if (dimension == expected.shape().dimensions_size()) {
168 NativeT expected_value = expected.Get<NativeT>(multi_index);
169 NativeT actual_value = actual.Get<NativeT>(multi_index);
170 bool result =
171 CompareEqual<NativeT>(expected_value, actual_value, multi_index);
172 return result ? Status::OK()
173 : MakeErrorStatus<NativeT>(expected_value, actual_value,
174 multi_index);
175 }
176
177 Status result;
178 for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
179 multi_index[dimension] = i;
180 TF_RETURN_IF_ERROR(
181 Equal<NativeT>(expected, actual, multi_index, dimension + 1));
182 }
183 return result;
184 }
185
186 // Gets the total element count. For tuples, this is not the count of tuple
187 // elements, but the sum of elements of each tuple element.
RecursiveElementCount(const Shape & shape)188 int64 RecursiveElementCount(const Shape& shape) {
189 if (shape.IsTuple()) {
190 const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
191 int64 total = 0;
192 for (int64 i = 0; i < tuple_elements; ++i) {
193 total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
194 }
195 return total;
196 } else if (shape.IsArray()) {
197 return ShapeUtil::ElementsIn(shape);
198 } else {
199 return 0;
200 }
201 }
202
// Reports whether `val` is an infinity (of either sign), as determined by
// std::isinf.
template <typename NativeT>
bool IsInf(NativeT val) {
  const bool is_infinite = std::isinf(val);
  return is_infinite;
}
208
209 template <>
IsInf(half val)210 bool IsInf<half>(half val) {
211 return std::isinf(static_cast<float>(val));
212 }
213
214 // Returns whether the given value is nan.
215 template <typename NativeT>
IsNan(NativeT value)216 float IsNan(NativeT value) {
217 return std::isnan(value);
218 }
219
220 template <>
IsNan(half value)221 float IsNan(half value) {
222 return IsNan<float>(static_cast<float>(value));
223 }
224
225 // Converts the given floating-point value to a string.
226 template <typename NativeT>
FpValueToString(NativeT value)227 string FpValueToString(NativeT value) {
228 return absl::StrFormat("%8.4g", static_cast<double>(value));
229 }
230
231 template <>
FpValueToString(complex64 value)232 string FpValueToString<complex64>(complex64 value) {
233 return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag());
234 }
235
236 template <>
FpValueToString(complex128 value)237 string FpValueToString<complex128>(complex128 value) {
238 return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag());
239 }
240
// Absolute value of a floating-point value, returned as float. NearComparator
// calls this wrapper rather than std::abs directly so that individual types
// can supply their own implementation through specialization. The generic
// version defers to std::abs and converts the result to float.
template <typename NativeT>
float FpAbsoluteValue(NativeT value) {
  const float magnitude = std::abs(value);
  return magnitude;
}
248
249 template <>
FpAbsoluteValue(bfloat16 value)250 float FpAbsoluteValue(bfloat16 value) {
251 return FpAbsoluteValue<float>(static_cast<float>(value));
252 }
253
254 template <>
FpAbsoluteValue(half value)255 float FpAbsoluteValue(half value) {
256 return FpAbsoluteValue<float>(static_cast<float>(value));
257 }
258
259 // Helper class for comparing floating-point literals within an error bound.
260 template <typename NativeT>
261 class NearComparator {
262 public:
263 // Compares the two array literals elementwise and returns a comparison
264 // result. The comparison is ok() if all actual and expected elements are
265 // within the given error bound. In case of error, the status contains a
266 // detailed message about the discrepancy.
Compare(const LiteralSlice & expected,const LiteralSlice & actual,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)267 static Status Compare(const LiteralSlice& expected,
268 const LiteralSlice& actual, ErrorSpec error,
269 bool detailed_message,
270 const MiscompareCallback& miscompare_callback) {
271 NearComparator<NativeT> comparator(expected, actual, error,
272 detailed_message, miscompare_callback);
273 return comparator.Run();
274 }
275
276 private:
277 // Data structure encapsulating metadata about a single element mismatch.
278 struct Mismatch {
279 NativeT actual;
280 NativeT expected;
281 float rel_error;
282 float abs_error;
283
284 // The linear index of the failure within the shape. This linear index is
285 // from the 'actual' literal.
286 int64 linear_index;
287
operator <xla::literal_comparison::__anon1e4119ae0111::NearComparator::Mismatch288 bool operator<(const Mismatch& other) const {
289 return rel_error < other.rel_error;
290 }
291
ToStringxla::literal_comparison::__anon1e4119ae0111::NearComparator::Mismatch292 string ToString(const Shape& shape) const {
293 return absl::StrFormat(
294 "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
295 FpValueToString(actual), FpValueToString(expected),
296 LiteralUtil::MultiIndexAsString(
297 IndexUtil::LinearIndexToMultidimensionalIndex(shape,
298 linear_index)),
299 rel_error, abs_error);
300 }
301 };
302
NearComparator(const LiteralSlice & expected,const LiteralSlice & actual,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)303 NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
304 ErrorSpec error, bool detailed_message,
305 const MiscompareCallback& miscompare_callback)
306 : expected_(expected),
307 actual_(actual),
308 error_(error),
309 detailed_message_(detailed_message),
310 miscompare_callback_(miscompare_callback),
311 abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
312 abs_error_buckets_(kErrorBucketBounds.size(), 0),
313 rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
314
315 // Runs the comparison between expected and actual literals.
Run()316 Status Run() {
317 // If the shapes mismatch, we simply fail the expectation instead of
318 // printing out data, as it's a type error rather than a value error.
319 TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
320 if (!expected_.shape().IsArray()) {
321 return InvalidArgument("Expected array shape; got %s.",
322 ShapeUtil::HumanString(expected_.shape()));
323 }
324
325 mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
326 mismatches_.PopulateWithValue(false);
327
328 CompareLiterals();
329
330 if (num_mismatches_ == 0) {
331 return Status::OK();
332 } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
333 miscompare_callback_(expected_, actual_, mismatches_);
334 }
335 return InvalidArgument("%s", ErrorMessage());
336 }
337
338 // Insert the given absolute value into the absolute value bucket vector. The
339 // bounds of the buckets are given by kAbsValueBucketBounds.
UpdateAbsValueBucket(NativeT value,bool is_mismatch)340 void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
341 // Adjust the bucket containing the absolute values of the 'actual'
342 // elements.
343 const float abs_value = FpAbsoluteValue(value);
344 for (int i = 0; i < abs_value_buckets_.size(); ++i) {
345 if (i == abs_value_buckets_.size() - 1 ||
346 (abs_value >= kAbsValueBucketBounds[i] &&
347 abs_value < kAbsValueBucketBounds[i + 1])) {
348 // The first value of the pair is the count of elements in the bucket,
349 // the second is the count of mismatches in the bucket.
350 abs_value_buckets_[i].first++;
351 if (is_mismatch) {
352 abs_value_buckets_[i].second++;
353 }
354 return;
355 }
356 }
357 }
358
359 // Insert the given error into the given error bucket vector.
UpdateErrorBucket(float error,absl::Span<int64> error_buckets)360 void UpdateErrorBucket(float error, absl::Span<int64> error_buckets) {
361 CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
362 for (int i = 0; i < error_buckets.size(); ++i) {
363 if (error >= kErrorBucketBounds[i]) {
364 error_buckets[i]++;
365 }
366 }
367 }
368
369 // Compares the two given elements from the expected and actual literals at
370 // the given literal_index and keeps track of various mismatch statistics.
371 template <typename T>
CompareValues(T expected,T actual,int64 linear_index)372 void CompareValues(T expected, T actual, int64 linear_index) {
373 float abs_error;
374 float rel_error;
375 if (CompareEqual<T>(expected, actual, {linear_index})) {
376 abs_error = 0;
377 rel_error = 0;
378 } else if (IsNan(expected) || IsNan(actual)) {
379 if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) ||
380 (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) {
381 num_nan_mismatches_++;
382 // A nan mismatch is considered to have infinite error. rel_error is
383 // used for sorting a std::set of the top mismatchs, and a nan value
384 // here will result in undefined behavior because nan's do not satisfy
385 // the strict weak ordering requirement of std containers.
386 abs_error = std::numeric_limits<float>::infinity();
387 rel_error = std::numeric_limits<float>::infinity();
388 } else {
389 abs_error = 0;
390 rel_error = 0;
391 }
392 } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) {
393 // `fewer_infs_ok` gives us the option of comparing as though `actual`
394 // were float_max/min rather than inf.
395 T actual_finite = actual > T{0} ? std::numeric_limits<T>::max()
396 : std::numeric_limits<T>::lowest();
397 abs_error = FpAbsoluteValue(actual_finite - expected);
398
399 // Avoid division by 0 even though it's well-defined because ubsan can be
400 // configured to treat this as a fatal error.
401 if (expected != T{0}) {
402 rel_error = abs_error / FpAbsoluteValue(expected);
403 } else {
404 rel_error = std::numeric_limits<float>::infinity();
405 }
406 } else if (IsInf(expected) || IsInf(actual)) {
407 // If either the expected or actual value is infinity but not both,
408 // then both absolute and relative error are regarded as inifity.
409 CHECK(!CompareEqual(expected, actual, {linear_index}));
410 abs_error = std::numeric_limits<float>::infinity();
411 rel_error = std::numeric_limits<float>::infinity();
412 } else {
413 abs_error = FpAbsoluteValue(actual - expected);
414
415 // Avoid division by 0 even though it's well-defined because ubsan can be
416 // configured to treat this as a fatal error.
417 if (expected != T{0}) {
418 rel_error = abs_error / FpAbsoluteValue(expected);
419 } else {
420 rel_error = std::numeric_limits<float>::infinity();
421 }
422 }
423 const bool is_abs_mismatch = abs_error > error_.abs;
424 const bool is_rel_mismatch = rel_error > error_.rel;
425 const bool is_mismatch = is_abs_mismatch && is_rel_mismatch;
426
427 // Update the error of the relative bucket only if the *absolute* error
428 // bound is exceeded and vice versa.
429 if (is_abs_mismatch) {
430 num_abs_mismatches_++;
431 UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
432 }
433 if (is_rel_mismatch) {
434 num_rel_mismatches_++;
435 UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
436 }
437
438 UpdateAbsValueBucket(actual, is_mismatch);
439
440 if (!is_mismatch) {
441 return;
442 }
443
444 num_mismatches_++;
445
446 // Keep track of the kTopRelativeErrorCount relative error mismatches.
447 if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
448 rel_error > top_rel_mismatches_.begin()->rel_error) {
449 Mismatch mismatch = {actual, expected, rel_error, abs_error,
450 linear_index};
451 top_rel_mismatches_.insert(mismatch);
452 if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
453 top_rel_mismatches_.erase(top_rel_mismatches_.begin());
454 }
455 }
456
457 mismatches_.data<bool>()[linear_index] = true;
458 }
459
460 // For complex types, we compare real and imaginary parts individually.
CompareValues(complex64 expected,complex64 actual,int64 linear_index)461 void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
462 bool mismatch = false;
463 CompareValues<float>(expected.real(), actual.real(), linear_index);
464 if (mismatches_.data<bool>()[linear_index] == true) {
465 mismatch = true;
466 // Delay the mismatch count increase for real part, instead increase
467 // mismatch by 1 for the entire complex number.
468 num_mismatches_--;
469 }
470 CompareValues<float>(expected.imag(), actual.imag(), linear_index);
471 if (mismatches_.data<bool>()[linear_index] == true) {
472 mismatch = true;
473 // Delay the mismatch count increase for imag part, instead increase
474 // mismatch by 1 for the entire complex number.
475 num_mismatches_--;
476 }
477 if (mismatch == true) {
478 num_mismatches_++;
479 }
480 mismatches_.data<bool>()[linear_index] = mismatch;
481 }
482
CompareValues(complex128 expected,complex128 actual,int64 linear_index)483 void CompareValues(complex128 expected, complex128 actual,
484 int64 linear_index) {
485 bool mismatch = false;
486 CompareValues<double>(expected.real(), actual.real(), linear_index);
487 if (mismatches_.data<bool>()[linear_index] == true) {
488 mismatch = true;
489 // Delay the mismatch count increase for real part, instead increase
490 // mismatch by 1 for the entire complex number.
491 num_mismatches_--;
492 }
493 CompareValues<double>(expected.imag(), actual.imag(), linear_index);
494 if (mismatches_.data<bool>()[linear_index] == true) {
495 mismatch = true;
496 // Delay the mismatch count increase for imag part, instead increase
497 // mismatch by 1 for the entire complex number.
498 num_mismatches_--;
499 }
500 if (mismatch == true) {
501 num_mismatches_++;
502 }
503 mismatches_.data<bool>()[linear_index] = mismatch;
504 }
505
506 // Compares the two literals elementwise.
CompareLiterals()507 void CompareLiterals() {
508 // Fast path optimization for the case were layouts match.
509 if (LayoutUtil::Equal(actual_.shape().layout(),
510 expected_.shape().layout())) {
511 absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
512 absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
513 const int64 len = expected_data.size();
514 for (int64 i = 0; i < len; ++i) {
515 CompareValues(expected_data[i], actual_data[i], i);
516 }
517 return;
518 }
519 std::vector<int64> multi_index(actual_.shape().rank(), 0);
520 CompareLiteralsSlow(0, &multi_index);
521 }
522
523 // Slow path for CompareLiterals when 'actual' and 'expected' literals have
524 // different layouts. In this case, multidimensional indices are constructed
525 // and indexed for each element.
CompareLiteralsSlow(int64 dimension,std::vector<int64> * multi_index)526 void CompareLiteralsSlow(int64 dimension, std::vector<int64>* multi_index) {
527 if (dimension == multi_index->size()) {
528 CompareValues(expected_.Get<NativeT>(*multi_index),
529 actual_.Get<NativeT>(*multi_index),
530 IndexUtil::MultidimensionalIndexToLinearIndex(
531 actual_.shape(), *multi_index));
532 } else {
533 for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) {
534 (*multi_index)[dimension] = i;
535 CompareLiteralsSlow(dimension + 1, multi_index);
536 }
537 }
538 }
539
540 // Returns an error message string with a detailed breakdown of the
541 // mismatches. Called after calling Run().
ErrorMessage()542 string ErrorMessage() {
543 string out;
544 int64 element_count = ShapeUtil::ElementsIn(actual_.shape());
545
546 auto percent_string = [](float a, float b) {
547 float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
548 return absl::StrFormat("%0.4f%%", pct);
549 };
550
551 StrAppendFormat(
552 &out,
553 "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
554 "%g, rel bound %g\n",
555 num_mismatches_, percent_string(num_mismatches_, element_count),
556 ShapeUtil::HumanString(actual_.shape()),
557 ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
558 if (num_nan_mismatches_ > 0) {
559 StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
560 }
561 StrAppendFormat(&out, "Top relative error mismatches:\n");
562 for (auto it = top_rel_mismatches_.rbegin();
563 it != top_rel_mismatches_.rend(); ++it) {
564 StrAppend(&out, " ", it->ToString(actual_.shape()), "\n");
565 }
566
567 if (!detailed_message_) {
568 return out;
569 }
570
571 StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
572 CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
573 for (int i = 0; i < abs_value_buckets_.size(); ++i) {
574 const int64 bucket_size = abs_value_buckets_[i].first;
575 const int64 bucket_mismatches = abs_value_buckets_[i].second;
576 string mismatch_str =
577 bucket_mismatches > 0
578 ? absl::StrFormat(", mismatches %d", bucket_mismatches)
579 : "";
580 StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n",
581 kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
582 bucket_size, percent_string(bucket_size, element_count),
583 mismatch_str);
584 }
585
586 auto print_accum_buckets = [&](const string& header, int64 total,
587 absl::Span<const int64> buckets) {
588 StrAppend(&out, header, ":\n");
589 StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0],
590 total - buckets[0],
591 percent_string(total - buckets[0], total));
592 CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
593 for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
594 StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
595 buckets[i], percent_string(buckets[i], total));
596 }
597 };
598 StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
599 error_.abs, num_abs_mismatches_,
600 percent_string(num_abs_mismatches_, element_count));
601 print_accum_buckets(
602 "Relative error breakdown of elements exceeding abs error bound",
603 num_abs_mismatches_, rel_error_buckets_);
604 StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
605 error_.rel, num_rel_mismatches_,
606 percent_string(num_rel_mismatches_, element_count));
607 print_accum_buckets(
608 "Absolute error breakdown of elements exceeding rel error bound",
609 num_rel_mismatches_, abs_error_buckets_);
610 return out;
611 }
612
613 // 'actual' and 'expected' literals being compared.
614 LiteralSlice expected_;
615 LiteralSlice actual_;
616
617 // The error bounds of the comparison.
618 ErrorSpec error_;
619
620 // Whether to include detailed breakdown of mismatches in the error message.
621 bool detailed_message_;
622
623 // Callback to invoke on miscompare.
624 MiscompareCallback miscompare_callback_;
625
626 // Number of element element mismatches encountered so far.
627 int64 num_mismatches_ = 0;
628
629 // Number of elements with a nan mismatch.
630 int64 num_nan_mismatches_ = 0;
631
632 // Number of elements which exceed the absolute/relative error bound.
633 int64 num_abs_mismatches_ = 0;
634 int64 num_rel_mismatches_ = 0;
635
636 // A Literal containing which elements did not match in the expected and
637 // actual literals. mismatches_ contains PREDs and is of the same sizes as
638 // the comparison literals.
639 Literal mismatches_;
640
641 // The number of mismatches to report in the output, sorted by relative error
642 // magnitude.
643 static constexpr int64 kTopRelativeErrorCount = 5;
644
645 // The set of mismatches with the largest relative error. The size of this set
646 // is bounded by kTopRelativeErrorCount.
647 std::multiset<Mismatch> top_rel_mismatches_;
648
649 // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
650 // bounds of these buckets. abs_value_buckets_ contains a pair for each
651 // bucket: the element count and failure count.
652 static constexpr std::array<float, 7> kAbsValueBucketBounds = {
653 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
654 std::vector<std::pair<int64, int64>> abs_value_buckets_;
655
656 // Buckets for relative and absolute errors. The relative error buckets only
657 // contains those elements which exceed the *absolute* error bound, and vice
658 // versa. This makes it easy to see the effect of adjusting the relative (or
659 // absolute) error bound on the success of the comparison. kErrorBucketBounds
660 // are the lower bounds of the buckets in both vectors. The error buckets are
661 // a cumulative distribution so an error value may appear in more than one
662 // bucket. For example an error value of 0.003 may appear in the buckets
663 // bounded by 0.01, 0.1, and 1.0.
664 static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
665 0.01, 0.1, 1};
666 std::vector<int64> abs_error_buckets_;
667 std::vector<int64> rel_error_buckets_;
668 };
669
670 template <typename NativeT>
671 constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
672 template <typename NativeT>
673 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
674
EqualHelper(const LiteralSlice & expected,const LiteralSlice & actual)675 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
676 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
677 std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
678 auto index = absl::MakeSpan(multi_index);
679 Status result;
680 switch (expected.shape().element_type()) {
681 case PRED:
682 result = Equal<bool>(expected, actual, index, 0);
683 break;
684 case U8:
685 result = Equal<uint8>(expected, actual, index, 0);
686 break;
687 case S32:
688 result = Equal<int32>(expected, actual, index, 0);
689 break;
690 case S64:
691 result = Equal<int64>(expected, actual, index, 0);
692 break;
693 case U32:
694 result = Equal<uint32>(expected, actual, index, 0);
695 break;
696 case U64:
697 result = Equal<uint64>(expected, actual, index, 0);
698 break;
699 case BF16:
700 result = Equal<bfloat16>(expected, actual, index, 0);
701 break;
702 case F16:
703 result = Equal<half>(expected, actual, index, 0);
704 break;
705 case F32:
706 result = Equal<float>(expected, actual, index, 0);
707 break;
708 case F64:
709 result = Equal<double>(expected, actual, index, 0);
710 break;
711 case C64:
712 result = Equal<complex64>(expected, actual, index, 0);
713 break;
714 case C128:
715 result = Equal<complex128>(expected, actual, index, 0);
716 break;
717 case TUPLE: {
718 for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
719 result.Update(EqualHelper(LiteralSlice(expected, {i}),
720 LiteralSlice(actual, {i})));
721 }
722 break;
723 }
724 case TOKEN:
725 // Tokens have no on-device representation and are trivially equal.
726 return Status::OK();
727 default:
728 LOG(FATAL) << "Unsupported primitive type: "
729 << PrimitiveType_Name(expected.shape().element_type());
730 }
731
732 return result;
733 }
734
735 // Helper function for comparing two literals for nearness. Handles tuple-shapes
736 // via recursion. shape_index is the ShapeIndex of expected (or actual)
737 // currently being compared.
NearHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback,const ShapeIndex & shape_index)738 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
739 const ErrorSpec& error, absl::optional<bool> detailed_message,
740 const MiscompareCallback& miscompare_callback,
741 const ShapeIndex& shape_index) {
742 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
743
744 if (expected.shape().IsTuple()) {
745 Status return_status;
746 for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
747 const auto expected_element = LiteralSlice(expected, {i});
748 const auto actual_element = LiteralSlice(actual, {i});
749 ShapeIndex element_index = shape_index;
750 element_index.push_back(i);
751 Status element_result =
752 NearHelper(expected_element, actual_element, error, detailed_message,
753 miscompare_callback, element_index);
754 if (!element_result.ok()) {
755 element_result = InvalidArgument("Array at shape index %s, %s",
756 element_index.ToString(),
757 element_result.error_message());
758 if (return_status.ok()) {
759 return_status = element_result;
760 } else {
761 return_status =
762 AppendStatus(return_status, element_result.error_message());
763 }
764 }
765 }
766 if (!return_status.ok() && shape_index.empty()) {
767 // Emit a top-level error message containing the top-level shape in case
768 // of mismatch.
769 int64 total_elements = RecursiveElementCount(actual.shape());
770 return_status =
771 InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
772 ShapeUtil::HumanString(actual.shape()),
773 total_elements, return_status.error_message());
774 }
775 return return_status;
776 }
777
778 if (ShapeUtil::ElementIsFloating(expected.shape()) ||
779 ShapeUtil::ElementIsComplex(expected.shape())) {
780 bool use_detailed_message = detailed_message.value_or(
781 ShapeUtil::ElementsIn(expected.shape()) >= 64);
782 switch (expected.shape().element_type()) {
783 case BF16:
784 return NearComparator<bfloat16>::Compare(
785 expected, actual, error, use_detailed_message, miscompare_callback);
786 break;
787 case F16:
788 return NearComparator<half>::Compare(
789 expected, actual, error, use_detailed_message, miscompare_callback);
790 break;
791 case F32:
792 return NearComparator<float>::Compare(
793 expected, actual, error, use_detailed_message, miscompare_callback);
794 break;
795 case F64:
796 return NearComparator<double>::Compare(
797 expected, actual, error, use_detailed_message, miscompare_callback);
798 break;
799 case C64:
800 return NearComparator<complex64>::Compare(
801 expected, actual, error, use_detailed_message, miscompare_callback);
802 break;
803 case C128:
804 return NearComparator<complex128>::Compare(
805 expected, actual, error, use_detailed_message, miscompare_callback);
806 break;
807 default:
808 LOG(FATAL) << "Unsupported primitive type in near comparator: "
809 << PrimitiveType_Name(expected.shape().element_type())
810 << ". Must be floating-point type.";
811 }
812 }
813
814 // Non-floating point, non-tuple literal.
815 return EqualHelper(expected, actual);
816 }
817
818 } // namespace
819
EqualShapes(const Shape & expected,const Shape & actual)820 Status EqualShapes(const Shape& expected, const Shape& actual) {
821 if (expected.element_type() != actual.element_type()) {
822 return InvalidArgument("element type mismatch, want: %s got %s",
823 ShapeUtil::HumanString(expected),
824 ShapeUtil::HumanString(actual));
825 }
826 if (expected.IsTuple()) {
827 if (ShapeUtil::TupleElementCount(expected) !=
828 ShapeUtil::TupleElementCount(actual)) {
829 return InvalidArgument(
830 "want tuple element count: %d got tuple element count: %d",
831 ShapeUtil::TupleElementCount(expected),
832 ShapeUtil::TupleElementCount(actual));
833 }
834 for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
835 Status result =
836 EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
837 if (!result.ok()) {
838 return AppendStatus(result, StrCat("mismatch in tuple index", i));
839 }
840 }
841 } else if (expected.IsArray()) {
842 if (expected.rank() != actual.rank()) {
843 return InvalidArgument("want rank of %s got rank of %s",
844 ShapeUtil::HumanString(expected),
845 ShapeUtil::HumanString(actual));
846 }
847 if (expected.element_type() != actual.element_type()) {
848 return InvalidArgument("mismatch in primitive type %s vs %s",
849 PrimitiveType_Name(expected.element_type()),
850 PrimitiveType_Name(actual.element_type()));
851 }
852 if (expected.dimensions_size() != actual.dimensions_size()) {
853 return InvalidArgument("want dimensions_size %d got dimensions_size %d",
854 expected.dimensions_size(),
855 actual.dimensions_size());
856 }
857 for (int i = 0; i < expected.dimensions_size(); ++i) {
858 if (expected.dimensions(i) != actual.dimensions(i)) {
859 return InvalidArgument(
860 "mismatch in dimension #%d expected: %s actual: %s", i,
861 ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
862 }
863 }
864 }
865 // Non-array, non-tuple shapes are trivially equivalent.
866 return Status::OK();
867 }
868
869 namespace {
870
871 // If result is an error, extend the error message with the expected and actual
872 // literals.
EmitLiteralsInErrorMessage(const Status & result,const LiteralSlice & expected,const LiteralSlice & actual)873 Status EmitLiteralsInErrorMessage(const Status& result,
874 const LiteralSlice& expected,
875 const LiteralSlice& actual) {
876 if (result.ok()) {
877 return result;
878 }
879 return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
880 result.error_message(), ToStringTruncated(expected),
881 ToStringTruncated(actual));
882 }
883
884 } // namespace
885
Equal(const LiteralSlice & expected,const LiteralSlice & actual)886 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
887 VLOG(1) << "expected:";
888 XLA_VLOG_LINES(1, expected.ToString());
889 VLOG(1) << "actual:";
890 XLA_VLOG_LINES(1, actual.ToString());
891 Status result = EqualHelper(expected, actual);
892 return EmitLiteralsInErrorMessage(result, expected, actual);
893 }
894
Near(const LiteralSlice & expected,const LiteralSlice & actual,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)895 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
896 const ErrorSpec& error, absl::optional<bool> detailed_message,
897 const MiscompareCallback& miscompare_callback) {
898 VLOG(1) << "Expected literal:";
899 XLA_VLOG_LINES(1, expected.ToString());
900 VLOG(1) << "Actual literal:";
901 XLA_VLOG_LINES(1, actual.ToString());
902 Status result =
903 NearHelper(expected, actual, error, detailed_message, miscompare_callback,
904 /*shape_index=*/{});
905 return EmitLiteralsInErrorMessage(result, expected, actual);
906 }
907
ToStringTruncated(const LiteralSlice & literal)908 string ToStringTruncated(const LiteralSlice& literal) {
909 return RecursiveElementCount(literal.shape()) < 1000
910 ? literal.ToString()
911 : "[TRUNCATED, Literal with more than 1000 values]";
912 }
913
914 } // namespace literal_comparison
915 } // namespace xla
916