1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal_comparison.h"
17
18 #include <unistd.h>
19
20 #include <cmath>
21 #include <vector>
22
23 #include "absl/base/casts.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/platform/env.h"
29
30 using absl::StrAppend;
31 using absl::StrAppendFormat;
32 using absl::StrCat;
33
34 namespace xla {
35 namespace literal_comparison {
36 namespace {
37
// Returns a value whose bytes can be handed to absl::bit_cast. For most types
// that is simply the value itself; Eigen::half does not satisfy the
// absl::bit_cast contract, so it has its own overload that exposes the raw
// 16-bit storage instead.
template <typename T>
T GetRawValue(T val) {
  T passthrough = val;
  return passthrough;
}
GetRawValue(Eigen::half val)44 uint16 GetRawValue(Eigen::half val) { return val.x; }
45
46 // Helper function for comparing a floating point type, FloatT, bitwise equal
47 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
48 // -- on miscompare, a nice error message is given in the AssertionFailure.
49 template <typename FloatT, typename UnsignedT>
CompareFloatsBitwiseEqual(FloatT lhs,FloatT rhs,absl::Span<const int64> multi_index)50 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
51 absl::Span<const int64> multi_index) {
52 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
53 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
54 return ulhs == urhs;
55 }
56
57 // Templated comparator that specializes for float equality comparison with the
58 // bitwise helper above (this is the un-specialized fallback, to just use the
59 // default gunit implementation).
60 template <typename NativeT>
CompareEqual(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)61 bool CompareEqual(NativeT lhs, NativeT rhs,
62 absl::Span<const int64> multi_index) {
63 return lhs == rhs;
64 }
65
66 // Specializations for floating types that do bitwise comparisons when equality
67 // comparison is requested.
68 template <>
CompareEqual(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)69 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
70 absl::Span<const int64> multi_index) {
71 return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
72 }
73 template <>
CompareEqual(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)74 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
75 absl::Span<const int64> multi_index) {
76 return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
77 }
78 template <>
CompareEqual(float lhs,float rhs,absl::Span<const int64> multi_index)79 bool CompareEqual<float>(float lhs, float rhs,
80 absl::Span<const int64> multi_index) {
81 return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
82 }
83 template <>
CompareEqual(double lhs,double rhs,absl::Span<const int64> multi_index)84 bool CompareEqual<double>(double lhs, double rhs,
85 absl::Span<const int64> multi_index) {
86 return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
87 }
88 template <>
CompareEqual(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)89 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
90 absl::Span<const int64> multi_index) {
91 return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
92 CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
93 }
94 template <>
CompareEqual(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)95 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
96 absl::Span<const int64> multi_index) {
97 return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
98 CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
99 }
100
101 template <typename NativeT, typename UnsignedT>
MakeBitwiseErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)102 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
103 absl::Span<const int64> multi_index) {
104 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
105 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
106 auto lhs_double = static_cast<double>(lhs);
107 auto rhs_double = static_cast<double>(rhs);
108 return InvalidArgument(
109 "floating values are not bitwise-equal; and equality testing "
110 "was requested: %s=%s=%a vs %s=%s=%a at array index %s",
111 StrCat(absl::Hex(ulhs)), RoundTripFpToString(lhs), lhs_double,
112 StrCat(absl::Hex(urhs)), RoundTripFpToString(rhs), rhs_double,
113 LiteralUtil::MultiIndexAsString(multi_index));
114 }
115
116 template <typename NativeT>
MakeErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)117 Status MakeErrorStatus(NativeT lhs, NativeT rhs,
118 absl::Span<const int64> multi_index) {
119 return InvalidArgument(
120 "first mismatch at array index %s:\n expected value: %s\n actual "
121 "value: %s",
122 LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
123 }
124
125 template <>
MakeErrorStatus(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)126 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
127 absl::Span<const int64> multi_index) {
128 return MakeBitwiseErrorStatus<bfloat16, uint16>(lhs, rhs, multi_index);
129 }
130 template <>
MakeErrorStatus(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)131 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs,
132 absl::Span<const int64> multi_index) {
133 return MakeBitwiseErrorStatus<Eigen::half, uint16>(lhs, rhs, multi_index);
134 }
135 template <>
MakeErrorStatus(float lhs,float rhs,absl::Span<const int64> multi_index)136 Status MakeErrorStatus(float lhs, float rhs,
137 absl::Span<const int64> multi_index) {
138 return MakeBitwiseErrorStatus<float, uint32>(lhs, rhs, multi_index);
139 }
140 template <>
MakeErrorStatus(double lhs,double rhs,absl::Span<const int64> multi_index)141 Status MakeErrorStatus(double lhs, double rhs,
142 absl::Span<const int64> multi_index) {
143 return MakeBitwiseErrorStatus<double, uint64>(lhs, rhs, multi_index);
144 }
145 template <>
MakeErrorStatus(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)146 Status MakeErrorStatus(complex64 lhs, complex64 rhs,
147 absl::Span<const int64> multi_index) {
148 if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) {
149 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
150 }
151 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
152 }
153 template <>
MakeErrorStatus(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)154 Status MakeErrorStatus(complex128 lhs, complex128 rhs,
155 absl::Span<const int64> multi_index) {
156 if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
157 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
158 }
159 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
160 }
161
162 // A recursive function which iterates through every index of expected and
163 // actual literal and compares their values elementwise. Returns true if all
164 // elements are equal. Mismatched must either be:
165 // - a literal of booleans that has the same shape as expected and actual. In
166 // this case, each index in mismatched will be set to true if expected does
167 // not equal actual at that index and false if there are equal.
168 // - nullptr. In this case, the function will return once any mismatch is
169 // found between expected and actual.
170 template <typename NativeT>
Equal(LiteralSlice expected,LiteralSlice actual,absl::Span<int64> multi_index,int64 dimension,Literal * mismatched=nullptr)171 Status Equal(LiteralSlice expected, LiteralSlice actual,
172 absl::Span<int64> multi_index, int64 dimension,
173 Literal* mismatched = nullptr) {
174 if (dimension == expected.shape().dimensions_size()) {
175 NativeT expected_value = expected.Get<NativeT>(multi_index);
176 NativeT actual_value = actual.Get<NativeT>(multi_index);
177 bool result =
178 CompareEqual<NativeT>(expected_value, actual_value, multi_index);
179 if (mismatched) {
180 mismatched->Set<bool>(multi_index, !result);
181 }
182 return result ? Status::OK()
183 : MakeErrorStatus<NativeT>(expected_value, actual_value,
184 multi_index);
185 }
186
187 Status result;
188 for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
189 multi_index[dimension] = i;
190 if (mismatched != nullptr) {
191 result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
192 mismatched));
193 } else {
194 TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
195 dimension + 1, mismatched));
196 }
197 }
198 return result;
199 }
200
201 // Gets the total element count. For tuples, this is not the count of tuple
202 // elements, but the sum of elements of each tuple element.
RecursiveElementCount(const Shape & shape)203 int64 RecursiveElementCount(const Shape& shape) {
204 if (shape.IsTuple()) {
205 const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
206 int64 total = 0;
207 for (int64 i = 0; i < tuple_elements; ++i) {
208 total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
209 }
210 return total;
211 } else if (shape.IsArray()) {
212 return ShapeUtil::ElementsIn(shape);
213 } else {
214 return 0;
215 }
216 }
217
218 // Returns whether the given value is infinity.
219 template <typename NativeT>
IsInf(NativeT val)220 bool IsInf(NativeT val) {
221 return Eigen::numext::isinf(val);
222 }
223 // Returns whether the given value is nan.
224 template <typename NativeT>
IsNan(NativeT value)225 bool IsNan(NativeT value) {
226 return Eigen::numext::isnan(value);
227 }
228
229 // Converts the given floating-point value to a string.
FpValueToString(bfloat16 value)230 string FpValueToString(bfloat16 value) {
231 return absl::StrFormat("%10.4g", static_cast<double>(value));
232 }
233
FpValueToString(half value)234 string FpValueToString(half value) {
235 return absl::StrFormat("%11.5g", static_cast<double>(value));
236 }
237
FpValueToString(float value)238 string FpValueToString(float value) {
239 return absl::StrFormat("%15.9g", static_cast<double>(value));
240 }
241
FpValueToString(double value)242 string FpValueToString(double value) {
243 return absl::StrFormat("%24.17g", value);
244 }
245
FpValueToString(complex64 value)246 string FpValueToString(complex64 value) {
247 return absl::StrCat(FpValueToString(value.real()), " + ",
248 FpValueToString(value.imag()));
249 }
250
FpValueToString(complex128 value)251 string FpValueToString(complex128 value) {
252 return absl::StrCat(FpValueToString(value.real()), " + ",
253 FpValueToString(value.imag()));
254 }
255
256 // A wrapper of std::abs to include data types that are not supported by
257 // std::abs, in particular, bfloat16 and half.
258 template <typename NativeT>
FpAbsoluteValue(NativeT value)259 double FpAbsoluteValue(NativeT value) {
260 return std::abs(value);
261 }
262
263 template <>
FpAbsoluteValue(bfloat16 value)264 double FpAbsoluteValue(bfloat16 value) {
265 return FpAbsoluteValue<float>(static_cast<float>(value));
266 }
267
268 template <>
FpAbsoluteValue(half value)269 double FpAbsoluteValue(half value) {
270 return FpAbsoluteValue<float>(static_cast<float>(value));
271 }
272
// Helper class for comparing floating-point literals within an error bound.
//
// The only public entry point is the static Compare(). A single pass over all
// elements computes an absolute and a relative error per element and records:
//   - how many elements exceed the abs bound, the rel bound, and both (an
//     element only counts as a "mismatch" if it exceeds BOTH bounds), plus the
//     number of NaN mismatches;
//   - a PRED literal (mismatches_) marking the mismatched elements;
//   - the kTopRelativeErrorCount mismatches with the largest relative error;
//   - histograms of actual magnitudes and of errors, used by ErrorMessage().
template <typename NativeT>
class NearComparator {
 public:
  // Compares the two array literals elementwise and returns a comparison
  // result. The comparison is ok() if all actual and expected elements are
  // within the given error bound. In case of error, the status contains a
  // detailed message about the discrepancy.
  static Status Compare(const LiteralSlice& expected,
                        const LiteralSlice& actual,
                        const ShapeIndex& shape_index, ErrorSpec error,
                        bool detailed_message,
                        const MiscompareCallback& miscompare_callback) {
    NearComparator<NativeT> comparator(expected, actual, shape_index, error,
                                       detailed_message, miscompare_callback);
    return comparator.Run();
  }

 private:
  // Data structure encapsulating metadata about a single element mismatch.
  // NOTE(review): for complex NativeT, CompareValues<float/double> stores a
  // single component (real or imaginary) here, converted to NativeT, so the
  // component lands in the real slot of `actual`/`expected`.
  struct Mismatch {
    NativeT actual;
    NativeT expected;
    double rel_error;
    double abs_error;

    // The linear index of the failure within the shape. This linear index is
    // from the 'actual' literal.
    int64 linear_index;

    // Orders mismatches by relative error so that top_rel_mismatches_.begin()
    // is the smallest retained relative error.
    bool operator<(const Mismatch& other) const {
      return rel_error < other.rel_error;
    }

    // One-line description used in the "Top relative error mismatches"
    // section of ErrorMessage(). `shape` is used to turn linear_index back
    // into a multi-index.
    string ToString(const Shape& shape) const {
      return absl::StrFormat(
          "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
          FpValueToString(actual), FpValueToString(expected),
          LiteralUtil::MultiIndexAsString(
              IndexUtil::LinearIndexToMultidimensionalIndex(shape,
                                                            linear_index)),
          rel_error, abs_error);
    }
  };

  // Stores the comparison inputs and sizes the histogram vectors: one
  // abs-value bucket per adjacent pair of kAbsValueBucketBounds, and one error
  // bucket per entry of kErrorBucketBounds.
  NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
                 const ShapeIndex& shape_index, ErrorSpec error,
                 bool detailed_message,
                 const MiscompareCallback& miscompare_callback)
      : expected_(expected),
        actual_(actual),
        shape_index_(shape_index),
        error_(error),
        detailed_message_(detailed_message),
        miscompare_callback_(miscompare_callback),
        abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
        abs_error_buckets_(kErrorBucketBounds.size(), 0),
        rel_error_buckets_(kErrorBucketBounds.size(), 0) {}

  // Runs the comparison between expected and actual literals.
  Status Run() {
    // If the shapes mismatch, we simply fail the expectation instead of
    // printing out data, as it's a type error rather than a value error.
    TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
    if (!expected_.shape().IsArray()) {
      return InvalidArgument("Expected array shape; got %s.",
                             ShapeUtil::HumanString(expected_.shape()));
    }

    // Mask of mismatched elements, same dimensions/layout as 'actual',
    // initially all false.
    mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
    mismatches_.PopulateWithValue(false);

    CompareLiterals();

    if (num_mismatches_ == 0) {
      return Status::OK();
    } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
      // NOTE(review): the callback is skipped when VLOG level 1 is enabled;
      // the reason is not visible from this file.
      miscompare_callback_(expected_, actual_, mismatches_, shape_index_);
    }
    return InvalidArgument("%s", ErrorMessage());
  }

  // Insert the given absolute value into the absolute value bucket vector. The
  // bounds of the buckets are given by kAbsValueBucketBounds. Values that fall
  // in no half-open range [bound[i], bound[i+1]) are counted in the last
  // bucket; this includes +inf and NaN magnitudes.
  void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
    // Adjust the bucket containing the absolute values of the 'actual'
    // elements.
    const double abs_value = FpAbsoluteValue(value);
    for (int i = 0; i < abs_value_buckets_.size(); ++i) {
      if (i == abs_value_buckets_.size() - 1 ||
          (abs_value >= kAbsValueBucketBounds[i] &&
           abs_value < kAbsValueBucketBounds[i + 1])) {
        // The first value of the pair is the count of elements in the bucket,
        // the second is the count of mismatches in the bucket.
        abs_value_buckets_[i].first++;
        if (is_mismatch) {
          abs_value_buckets_[i].second++;
        }
        return;
      }
    }
  }

  // Insert the given error into the given error bucket vector. The buckets are
  // cumulative: every bucket whose lower bound in kErrorBucketBounds is <=
  // `error` is incremented.
  void UpdateErrorBucket(double error, absl::Span<int64> error_buckets) {
    CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
    for (int i = 0; i < error_buckets.size(); ++i) {
      if (error >= kErrorBucketBounds[i]) {
        error_buckets[i]++;
      }
    }
  }

  // Compares the two given elements from the expected and actual literals at
  // the given linear_index and keeps track of various mismatch statistics.
  //
  // Errors are assigned by the first matching case:
  //   1. Bitwise-equal values (CompareEqual): abs and rel error are 0.
  //   2. Either value is NaN: infinite error if exactly one side is NaN (with
  //      error_.relaxed_nans, only if `actual` is NaN and `expected` is not);
  //      otherwise 0.
  //   3. `actual` is infinite, `expected` finite, and error_.fewer_infs_ok:
  //      `actual` is replaced by the largest/lowest finite T of the same sign
  //      before computing the error.
  //   4. Either value is infinite (values already known to differ): infinite
  //      error.
  //   5. Otherwise |actual - expected| and |actual - expected| / |expected|.
  // In cases 3 and 5, a zero `expected` yields an infinite relative error.
  template <typename T>
  void CompareValues(T expected, T actual, int64 linear_index) {
    double abs_error;
    double rel_error;
    if (CompareEqual<T>(expected, actual, {linear_index})) {
      abs_error = 0;
      rel_error = 0;
    } else if (IsNan(expected) || IsNan(actual)) {
      if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) ||
          (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) {
        num_nan_mismatches_++;
        // A nan mismatch is considered to have infinite error. rel_error is
        // used for sorting a std::set of the top mismatches, and a nan value
        // here will result in undefined behavior because nan's do not satisfy
        // the strict weak ordering requirement of std containers.
        abs_error = std::numeric_limits<float>::infinity();
        rel_error = std::numeric_limits<float>::infinity();
      } else {
        abs_error = 0;
        rel_error = 0;
      }
    } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) {
      // `fewer_infs_ok` gives us the option of comparing as though `actual`
      // were float_max/min rather than inf.
      T actual_finite = actual > T{0} ? std::numeric_limits<T>::max()
                                      : std::numeric_limits<T>::lowest();
      abs_error = FpAbsoluteValue(actual_finite - expected);

      // Avoid division by 0 even though it's well-defined because ubsan can be
      // configured to treat this as a fatal error.
      if (expected != T{0}) {
        rel_error = abs_error / FpAbsoluteValue(expected);
      } else {
        rel_error = std::numeric_limits<float>::infinity();
      }
    } else if (IsInf(expected) || IsInf(actual)) {
      // If either the expected or actual value is infinity but not both,
      // then both absolute and relative error are regarded as infinity.
      CHECK(!CompareEqual(expected, actual, {linear_index}));
      abs_error = std::numeric_limits<float>::infinity();
      rel_error = std::numeric_limits<float>::infinity();
    } else {
      abs_error = FpAbsoluteValue(actual - expected);

      // Avoid division by 0 even though it's well-defined because ubsan can be
      // configured to treat this as a fatal error.
      if (expected != T{0}) {
        rel_error = abs_error / FpAbsoluteValue(expected);
      } else {
        rel_error = std::numeric_limits<float>::infinity();
      }
    }
    const bool is_abs_mismatch = abs_error > error_.abs;
    const bool is_rel_mismatch = rel_error > error_.rel;
    // An element only fails if it is outside BOTH the abs and the rel bound.
    const bool is_mismatch = is_abs_mismatch && is_rel_mismatch;

    // Update the error of the relative bucket only if the *absolute* error
    // bound is exceeded and vice versa.
    if (is_abs_mismatch) {
      num_abs_mismatches_++;
      UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
    }
    if (is_rel_mismatch) {
      num_rel_mismatches_++;
      UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
    }

    UpdateAbsValueBucket(actual, is_mismatch);

    if (!is_mismatch) {
      return;
    }

    num_mismatches_++;

    // Keep track of the kTopRelativeErrorCount relative error mismatches.
    // The multiset is ordered by rel_error, so begin() is the smallest retained
    // entry and is evicted once the set grows past the limit.
    if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
        rel_error > top_rel_mismatches_.begin()->rel_error) {
      Mismatch mismatch = {actual, expected, rel_error, abs_error,
                           linear_index};
      top_rel_mismatches_.insert(mismatch);
      if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
        top_rel_mismatches_.erase(top_rel_mismatches_.begin());
      }
    }

    mismatches_.data<bool>()[linear_index] = true;
  }

  // For complex types, we compare real and imaginary parts individually.
  // NOTE(review): each part also updates the abs/rel counters, error buckets
  // and abs-value buckets separately, so those statistics count a complex
  // element up to twice; only num_mismatches_ is corrected below.
  void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
    const auto both_parts_mismatch = num_mismatches_ + 2;
    CompareValues<float>(expected.real(), actual.real(), linear_index);
    CompareValues<float>(expected.imag(), actual.imag(), linear_index);
    if (num_mismatches_ == both_parts_mismatch) {
      // The mismatch counter had been incremented by each CompareValues() call,
      // which means that both real and imaginary parts of the passed-in complex
      // values are different. However, the counter should reflect a single
      // mismatch between these complex values.
      num_mismatches_--;
    }
  }

  // complex128 counterpart of the overload above; same counting caveats.
  void CompareValues(complex128 expected, complex128 actual,
                     int64 linear_index) {
    const auto both_parts_mismatch = num_mismatches_ + 2;
    CompareValues<double>(expected.real(), actual.real(), linear_index);
    CompareValues<double>(expected.imag(), actual.imag(), linear_index);
    if (num_mismatches_ == both_parts_mismatch) {
      // The mismatch counter had been incremented by each CompareValues() call,
      // which means that both real and imaginary parts of the passed-in complex
      // values are different. However, the counter should reflect a single
      // mismatch between these complex values.
      num_mismatches_--;
    }
  }

  // Compares the two literals elementwise.
  void CompareLiterals() {
    // Fast path optimization for the case where layouts match: the flat data
    // arrays are walked in lockstep and the position in the data is used as
    // the linear index.
    if (LayoutUtil::Equal(actual_.shape().layout(),
                          expected_.shape().layout())) {
      absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
      absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
      const int64 len = expected_data.size();
      for (int64 i = 0; i < len; ++i) {
        CompareValues(expected_data[i], actual_data[i], i);
      }
      return;
    }
    std::vector<int64> multi_index(actual_.shape().rank(), 0);
    CompareLiteralsSlow(0, &multi_index);
  }

  // Slow path for CompareLiterals when 'actual' and 'expected' literals have
  // different layouts. In this case, multidimensional indices are constructed
  // and indexed for each element. The linear index passed on is computed
  // against the 'actual' shape.
  void CompareLiteralsSlow(int64 dimension, std::vector<int64>* multi_index) {
    if (dimension == multi_index->size()) {
      CompareValues(expected_.Get<NativeT>(*multi_index),
                    actual_.Get<NativeT>(*multi_index),
                    IndexUtil::MultidimensionalIndexToLinearIndex(
                        actual_.shape(), *multi_index));
    } else {
      for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) {
        (*multi_index)[dimension] = i;
        CompareLiteralsSlow(dimension + 1, multi_index);
      }
    }
  }

  // Returns an error message string with a detailed breakdown of the
  // mismatches. Called from Run() after CompareLiterals(). The histogram
  // sections are only emitted when detailed_message_ is set.
  string ErrorMessage() {
    string out;
    int64 element_count = ShapeUtil::ElementsIn(actual_.shape());

    // Formats a/b as a percentage; 0% when b is 0.
    auto percent_string = [](float a, float b) {
      float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
      return absl::StrFormat("%0.4f%%", pct);
    };

    StrAppendFormat(
        &out,
        "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
        "%g, rel bound %g\n",
        num_mismatches_, percent_string(num_mismatches_, element_count),
        ShapeUtil::HumanString(actual_.shape()),
        ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
    if (num_nan_mismatches_ > 0) {
      StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
    }
    // Largest relative error first (reverse iteration of the ordered set).
    StrAppendFormat(&out, "Top relative error mismatches:\n");
    for (auto it = top_rel_mismatches_.rbegin();
         it != top_rel_mismatches_.rend(); ++it) {
      StrAppend(&out, " ", it->ToString(actual_.shape()), "\n");
    }

    if (!detailed_message_) {
      return out;
    }

    StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
    CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
    for (int i = 0; i < abs_value_buckets_.size(); ++i) {
      const int64 bucket_size = abs_value_buckets_[i].first;
      const int64 bucket_mismatches = abs_value_buckets_[i].second;
      string mismatch_str =
          bucket_mismatches > 0
              ? absl::StrFormat(", mismatches %d", bucket_mismatches)
              : "";
      StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n",
                      kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
                      bucket_size, percent_string(bucket_size, element_count),
                      mismatch_str);
    }

    // Prints a cumulative error histogram: first the count below the smallest
    // bound, then for each bound the count at or above it.
    auto print_accum_buckets = [&](const string& header, int64 total,
                                   absl::Span<const int64> buckets) {
      StrAppend(&out, header, ":\n");
      StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0],
                      total - buckets[0],
                      percent_string(total - buckets[0], total));
      CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
      for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
        StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
                        buckets[i], percent_string(buckets[i], total));
      }
    };
    StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
                    error_.abs, num_abs_mismatches_,
                    percent_string(num_abs_mismatches_, element_count));
    print_accum_buckets(
        "Relative error breakdown of elements exceeding abs error bound",
        num_abs_mismatches_, rel_error_buckets_);
    StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
                    error_.rel, num_rel_mismatches_,
                    percent_string(num_rel_mismatches_, element_count));
    print_accum_buckets(
        "Absolute error breakdown of elements exceeding rel error bound",
        num_rel_mismatches_, abs_error_buckets_);
    return out;
  }

  // 'actual' and 'expected' literals being compared.
  LiteralSlice expected_;
  LiteralSlice actual_;

  // The shape index of the LiteralSlice that is being compared.
  ShapeIndex shape_index_;

  // The error bounds of the comparison.
  ErrorSpec error_;

  // Whether to include detailed breakdown of mismatches in the error message.
  bool detailed_message_;

  // Callback to invoke on miscompare.
  MiscompareCallback miscompare_callback_;

  // Number of element mismatches encountered so far.
  int64 num_mismatches_ = 0;

  // Number of elements with a nan mismatch.
  int64 num_nan_mismatches_ = 0;

  // Number of elements which exceed the absolute/relative error bound.
  int64 num_abs_mismatches_ = 0;
  int64 num_rel_mismatches_ = 0;

  // A Literal containing which elements did not match in the expected and
  // actual literals. mismatches_ contains PREDs and is of the same sizes as
  // the comparison literals.
  Literal mismatches_;

  // The number of mismatches to report in the output, sorted by relative error
  // magnitude.
  static constexpr int64 kTopRelativeErrorCount = 5;

  // The set of mismatches with the largest relative error. The size of this set
  // is bounded by kTopRelativeErrorCount.
  std::multiset<Mismatch> top_rel_mismatches_;

  // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
  // bounds of these buckets. abs_value_buckets_ contains a pair for each
  // bucket: the element count and failure count.
  static constexpr std::array<float, 7> kAbsValueBucketBounds = {
      0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
  std::vector<std::pair<int64, int64>> abs_value_buckets_;

  // Buckets for relative and absolute errors. The relative error buckets only
  // contains those elements which exceed the *absolute* error bound, and vice
  // versa. This makes it easy to see the effect of adjusting the relative (or
  // absolute) error bound on the success of the comparison. kErrorBucketBounds
  // are the lower bounds of the buckets in both vectors. The error buckets are
  // a cumulative distribution so an error value may appear in more than one
  // bucket: an error is counted in every bucket whose lower bound it reaches.
  // For example an error value of 0.003 is counted in the buckets with lower
  // bounds 0.0001 and 0.001 (but not 0.01, 0.1 or 1).
  static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
                                                              0.01, 0.1, 1};
  std::vector<int64> abs_error_buckets_;
  std::vector<int64> rel_error_buckets_;
};

// Namespace-scope definitions of the static constexpr arrays declared above.
// Before C++17 these are required because the arrays are odr-used (e.g. bound
// to references by operator[] and size()).
template <typename NativeT>
constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
template <typename NativeT>
constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
676
EqualHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const MiscompareCallback & miscompare_callback)677 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
678 const ShapeIndex& shape_index,
679 const MiscompareCallback& miscompare_callback) {
680 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
681
682 Status result;
683 if (expected.shape().IsTuple()) {
684 ShapeIndex next_index = shape_index;
685 for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
686 next_index.push_back(i);
687 Status tuple_result =
688 EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
689 next_index, miscompare_callback);
690 if (miscompare_callback) {
691 result.Update(tuple_result);
692 } else {
693 TF_RETURN_IF_ERROR(tuple_result);
694 }
695 next_index.pop_back();
696 }
697 } else {
698 std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
699 auto index = absl::MakeSpan(multi_index);
700
701 Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
702 expected.shape().dimensions());
703 Literal miscompared(unequal_shape);
704 Literal* miscompared_ptr =
705 (miscompare_callback == nullptr ? nullptr : &miscompared);
706
707 switch (expected.shape().element_type()) {
708 case PRED:
709 result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
710 break;
711 case S8:
712 result = Equal<int8>(expected, actual, index, 0, miscompared_ptr);
713 break;
714 case S16:
715 result = Equal<int16>(expected, actual, index, 0, miscompared_ptr);
716 break;
717 case S32:
718 result = Equal<int32>(expected, actual, index, 0, miscompared_ptr);
719 break;
720 case S64:
721 result = Equal<int64>(expected, actual, index, 0, miscompared_ptr);
722 break;
723 case U8:
724 result = Equal<uint8>(expected, actual, index, 0, miscompared_ptr);
725 break;
726 case U16:
727 result = Equal<uint16>(expected, actual, index, 0, miscompared_ptr);
728 break;
729 case U32:
730 result = Equal<uint32>(expected, actual, index, 0, miscompared_ptr);
731 break;
732 case U64:
733 result = Equal<uint64>(expected, actual, index, 0, miscompared_ptr);
734 break;
735 case BF16:
736 result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
737 break;
738 case F16:
739 result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
740 break;
741 case F32:
742 result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
743 break;
744 case F64:
745 result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
746 break;
747 case C64:
748 result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
749 break;
750 case C128:
751 result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
752 break;
753 case TOKEN:
754 // Tokens have no on-device representation and are trivially equal.
755 return Status::OK();
756 default:
757 LOG(FATAL) << "Unsupported primitive type: "
758 << PrimitiveType_Name(expected.shape().element_type());
759 }
760
761 if (!result.ok() && miscompare_callback) {
762 miscompare_callback(expected, actual, LiteralSlice(miscompared),
763 shape_index);
764 }
765 }
766
767 return result;
768 }
769
770 // Helper function for comparing two literals for nearness. Handles tuple-shapes
771 // via recursion. shape_index is the ShapeIndex of expected (or actual)
772 // currently being compared.
NearHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)773 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
774 const ShapeIndex& shape_index, const ErrorSpec& error,
775 absl::optional<bool> detailed_message,
776 const MiscompareCallback& miscompare_callback) {
777 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
778
779 if (expected.shape().IsTuple()) {
780 Status return_status;
781 for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
782 const auto expected_element = LiteralSlice(expected, {i});
783 const auto actual_element = LiteralSlice(actual, {i});
784 ShapeIndex element_index = shape_index;
785 element_index.push_back(i);
786 Status element_result =
787 NearHelper(expected_element, actual_element, element_index, error,
788 detailed_message, miscompare_callback);
789 if (!element_result.ok()) {
790 element_result = InvalidArgument("Array at shape index %s, %s",
791 element_index.ToString(),
792 element_result.error_message());
793 if (return_status.ok()) {
794 return_status = element_result;
795 } else {
796 return_status =
797 AppendStatus(return_status, element_result.error_message());
798 }
799 }
800 }
801 if (!return_status.ok() && shape_index.empty()) {
802 // Emit a top-level error message containing the top-level shape in case
803 // of mismatch.
804 int64 total_elements = RecursiveElementCount(actual.shape());
805 return_status =
806 InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
807 ShapeUtil::HumanString(actual.shape()),
808 total_elements, return_status.error_message());
809 }
810 return return_status;
811 }
812
813 if (ShapeUtil::ElementIsFloating(expected.shape()) ||
814 ShapeUtil::ElementIsComplex(expected.shape())) {
815 bool use_detailed_message = detailed_message.value_or(
816 ShapeUtil::ElementsIn(expected.shape()) >= 64);
817 switch (expected.shape().element_type()) {
818 case BF16:
819 return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
820 error, use_detailed_message,
821 miscompare_callback);
822 break;
823 case F16:
824 return NearComparator<half>::Compare(expected, actual, shape_index,
825 error, use_detailed_message,
826 miscompare_callback);
827 break;
828 case F32:
829 return NearComparator<float>::Compare(expected, actual, shape_index,
830 error, use_detailed_message,
831 miscompare_callback);
832 break;
833 case F64:
834 return NearComparator<double>::Compare(expected, actual, shape_index,
835 error, use_detailed_message,
836 miscompare_callback);
837 break;
838 case C64:
839 return NearComparator<complex64>::Compare(expected, actual, shape_index,
840 error, use_detailed_message,
841 miscompare_callback);
842 break;
843 case C128:
844 return NearComparator<complex128>::Compare(
845 expected, actual, shape_index, error, use_detailed_message,
846 miscompare_callback);
847 break;
848 default:
849 LOG(FATAL) << "Unsupported primitive type in near comparator: "
850 << PrimitiveType_Name(expected.shape().element_type())
851 << ". Must be floating-point type.";
852 }
853 }
854
855 // Non-floating point, non-tuple literal.
856 return EqualHelper(expected, actual, shape_index, miscompare_callback);
857 }
858
859 } // namespace
860
EqualShapes(const Shape & expected,const Shape & actual)861 Status EqualShapes(const Shape& expected, const Shape& actual) {
862 if (expected.element_type() != actual.element_type()) {
863 return InvalidArgument("element type mismatch, want: %s got %s",
864 ShapeUtil::HumanString(expected),
865 ShapeUtil::HumanString(actual));
866 }
867 if (expected.IsTuple()) {
868 if (ShapeUtil::TupleElementCount(expected) !=
869 ShapeUtil::TupleElementCount(actual)) {
870 return InvalidArgument(
871 "want tuple element count: %d got tuple element count: %d",
872 ShapeUtil::TupleElementCount(expected),
873 ShapeUtil::TupleElementCount(actual));
874 }
875 for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
876 Status result =
877 EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
878 if (!result.ok()) {
879 return AppendStatus(result, StrCat("mismatch in tuple index", i));
880 }
881 }
882 } else if (expected.IsArray()) {
883 if (expected.rank() != actual.rank()) {
884 return InvalidArgument("want rank of %s got rank of %s",
885 ShapeUtil::HumanString(expected),
886 ShapeUtil::HumanString(actual));
887 }
888 if (expected.element_type() != actual.element_type()) {
889 return InvalidArgument("mismatch in primitive type %s vs %s",
890 PrimitiveType_Name(expected.element_type()),
891 PrimitiveType_Name(actual.element_type()));
892 }
893 if (expected.dimensions_size() != actual.dimensions_size()) {
894 return InvalidArgument("want dimensions_size %d got dimensions_size %d",
895 expected.dimensions_size(),
896 actual.dimensions_size());
897 }
898 for (int i = 0; i < expected.dimensions_size(); ++i) {
899 if (expected.dimensions(i) != actual.dimensions(i)) {
900 return InvalidArgument(
901 "mismatch in dimension #%d expected: %s actual: %s", i,
902 ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
903 }
904 }
905 }
906 // Non-array, non-tuple shapes are trivially equivalent.
907 return Status::OK();
908 }
909
910 namespace {
911
912 // If result is an error, extend the error message with the expected and actual
913 // literals.
EmitLiteralsInErrorMessage(const Status & result,const LiteralSlice & expected,const LiteralSlice & actual)914 Status EmitLiteralsInErrorMessage(const Status& result,
915 const LiteralSlice& expected,
916 const LiteralSlice& actual) {
917 if (result.ok()) {
918 return result;
919 }
920 return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
921 result.error_message(), ToStringTruncated(expected),
922 ToStringTruncated(actual));
923 }
924
925 } // namespace
926
Equal(const LiteralSlice & expected,const LiteralSlice & actual)927 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
928 VLOG(1) << "expected:";
929 XLA_VLOG_LINES(1, expected.ToString());
930 VLOG(1) << "actual:";
931 XLA_VLOG_LINES(1, actual.ToString());
932 Status result = EqualHelper(expected, actual, {}, nullptr);
933 return EmitLiteralsInErrorMessage(result, expected, actual);
934 }
935
Near(const LiteralSlice & expected,const LiteralSlice & actual,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)936 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
937 const ErrorSpec& error, absl::optional<bool> detailed_message,
938 const MiscompareCallback& miscompare_callback) {
939 VLOG(1) << "Expected literal:";
940 XLA_VLOG_LINES(1, expected.ToString());
941 VLOG(1) << "Actual literal:";
942 XLA_VLOG_LINES(1, actual.ToString());
943 Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
944 detailed_message, miscompare_callback);
945 return EmitLiteralsInErrorMessage(result, expected, actual);
946 }
947
ToStringTruncated(const LiteralSlice & literal)948 string ToStringTruncated(const LiteralSlice& literal) {
949 return RecursiveElementCount(literal.shape()) < 1000
950 ? literal.ToString()
951 : "[TRUNCATED, Literal with more than 1000 values]";
952 }
953
954 } // namespace literal_comparison
955 } // namespace xla
956