1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal_comparison.h"
17 
18 #include <unistd.h>
19 
20 #include <cmath>
21 #include <vector>
22 
23 #include "absl/base/casts.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/platform/env.h"
29 
30 using absl::StrAppend;
31 using absl::StrAppendFormat;
32 using absl::StrCat;
33 
34 namespace xla {
35 namespace literal_comparison {
36 namespace {
37 
// absl::bit_cast cannot be applied to Eigen::half directly, so bit-level
// helpers first pass values through GetRawValue. For every type other than
// Eigen::half (which has its own overload) this is simply the identity.
template <typename T>
T GetRawValue(T val) {
  T raw(val);
  return raw;
}
GetRawValue(Eigen::half val)44 uint16 GetRawValue(Eigen::half val) { return val.x; }
45 
46 // Helper function for comparing a floating point type, FloatT, bitwise equal
47 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
48 // -- on miscompare, a nice error message is given in the AssertionFailure.
49 template <typename FloatT, typename UnsignedT>
CompareFloatsBitwiseEqual(FloatT lhs,FloatT rhs,absl::Span<const int64> multi_index)50 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
51                                absl::Span<const int64> multi_index) {
52   auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
53   auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
54   return ulhs == urhs;
55 }
56 
57 // Templated comparator that specializes for float equality comparison with the
58 // bitwise helper above (this is the un-specialized fallback, to just use the
59 // default gunit implementation).
60 template <typename NativeT>
CompareEqual(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)61 bool CompareEqual(NativeT lhs, NativeT rhs,
62                   absl::Span<const int64> multi_index) {
63   return lhs == rhs;
64 }
65 
66 // Specializations for floating types that do bitwise comparisons when equality
67 // comparison is requested.
68 template <>
CompareEqual(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)69 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
70                             absl::Span<const int64> multi_index) {
71   return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
72 }
73 template <>
CompareEqual(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)74 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
75                                absl::Span<const int64> multi_index) {
76   return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
77 }
78 template <>
CompareEqual(float lhs,float rhs,absl::Span<const int64> multi_index)79 bool CompareEqual<float>(float lhs, float rhs,
80                          absl::Span<const int64> multi_index) {
81   return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
82 }
83 template <>
CompareEqual(double lhs,double rhs,absl::Span<const int64> multi_index)84 bool CompareEqual<double>(double lhs, double rhs,
85                           absl::Span<const int64> multi_index) {
86   return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
87 }
88 template <>
CompareEqual(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)89 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
90                              absl::Span<const int64> multi_index) {
91   return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
92          CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
93 }
94 template <>
CompareEqual(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)95 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
96                               absl::Span<const int64> multi_index) {
97   return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
98          CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
99 }
100 
101 template <typename NativeT, typename UnsignedT>
MakeBitwiseErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)102 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
103                               absl::Span<const int64> multi_index) {
104   auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
105   auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
106   auto lhs_double = static_cast<double>(lhs);
107   auto rhs_double = static_cast<double>(rhs);
108   return InvalidArgument(
109       "floating values are not bitwise-equal; and equality testing "
110       "was requested: %s=%s=%a vs %s=%s=%a at array index %s",
111       StrCat(absl::Hex(ulhs)), RoundTripFpToString(lhs), lhs_double,
112       StrCat(absl::Hex(urhs)), RoundTripFpToString(rhs), rhs_double,
113       LiteralUtil::MultiIndexAsString(multi_index));
114 }
115 
116 template <typename NativeT>
MakeErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)117 Status MakeErrorStatus(NativeT lhs, NativeT rhs,
118                        absl::Span<const int64> multi_index) {
119   return InvalidArgument(
120       "first mismatch at array index %s:\n  expected value: %s\n  actual "
121       "value:   %s",
122       LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
123 }
124 
125 template <>
MakeErrorStatus(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)126 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
127                        absl::Span<const int64> multi_index) {
128   return MakeBitwiseErrorStatus<bfloat16, uint16>(lhs, rhs, multi_index);
129 }
130 template <>
MakeErrorStatus(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)131 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs,
132                        absl::Span<const int64> multi_index) {
133   return MakeBitwiseErrorStatus<Eigen::half, uint16>(lhs, rhs, multi_index);
134 }
135 template <>
MakeErrorStatus(float lhs,float rhs,absl::Span<const int64> multi_index)136 Status MakeErrorStatus(float lhs, float rhs,
137                        absl::Span<const int64> multi_index) {
138   return MakeBitwiseErrorStatus<float, uint32>(lhs, rhs, multi_index);
139 }
140 template <>
MakeErrorStatus(double lhs,double rhs,absl::Span<const int64> multi_index)141 Status MakeErrorStatus(double lhs, double rhs,
142                        absl::Span<const int64> multi_index) {
143   return MakeBitwiseErrorStatus<double, uint64>(lhs, rhs, multi_index);
144 }
145 template <>
MakeErrorStatus(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)146 Status MakeErrorStatus(complex64 lhs, complex64 rhs,
147                        absl::Span<const int64> multi_index) {
148   if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) {
149     return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
150   }
151   return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
152 }
153 template <>
MakeErrorStatus(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)154 Status MakeErrorStatus(complex128 lhs, complex128 rhs,
155                        absl::Span<const int64> multi_index) {
156   if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
157     return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
158   }
159   return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
160 }
161 
162 // A recursive function which iterates through every index of expected and
163 // actual literal and compares their values elementwise. Returns true if all
164 // elements are equal. Mismatched must either be:
165 //    - a literal of booleans that has the same shape as expected and actual. In
166 //      this case, each index in mismatched will be set to true if expected does
167 //      not equal actual at that index and false if there are equal.
168 //    - nullptr. In this case, the function will return once any mismatch is
169 //      found between expected and actual.
170 template <typename NativeT>
Equal(LiteralSlice expected,LiteralSlice actual,absl::Span<int64> multi_index,int64 dimension,Literal * mismatched=nullptr)171 Status Equal(LiteralSlice expected, LiteralSlice actual,
172              absl::Span<int64> multi_index, int64 dimension,
173              Literal* mismatched = nullptr) {
174   if (dimension == expected.shape().dimensions_size()) {
175     NativeT expected_value = expected.Get<NativeT>(multi_index);
176     NativeT actual_value = actual.Get<NativeT>(multi_index);
177     bool result =
178         CompareEqual<NativeT>(expected_value, actual_value, multi_index);
179     if (mismatched) {
180       mismatched->Set<bool>(multi_index, !result);
181     }
182     return result ? Status::OK()
183                   : MakeErrorStatus<NativeT>(expected_value, actual_value,
184                                              multi_index);
185   }
186 
187   Status result;
188   for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
189     multi_index[dimension] = i;
190     if (mismatched != nullptr) {
191       result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
192                                    mismatched));
193     } else {
194       TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
195                                         dimension + 1, mismatched));
196     }
197   }
198   return result;
199 }
200 
201 // Gets the total element count.  For tuples, this is not the count of tuple
202 // elements, but the sum of elements of each tuple element.
RecursiveElementCount(const Shape & shape)203 int64 RecursiveElementCount(const Shape& shape) {
204   if (shape.IsTuple()) {
205     const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
206     int64 total = 0;
207     for (int64 i = 0; i < tuple_elements; ++i) {
208       total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
209     }
210     return total;
211   } else if (shape.IsArray()) {
212     return ShapeUtil::ElementsIn(shape);
213   } else {
214     return 0;
215   }
216 }
217 
218 // Returns whether the given value is infinity.
219 template <typename NativeT>
IsInf(NativeT val)220 bool IsInf(NativeT val) {
221   return Eigen::numext::isinf(val);
222 }
223 // Returns whether the given value is nan.
224 template <typename NativeT>
IsNan(NativeT value)225 bool IsNan(NativeT value) {
226   return Eigen::numext::isnan(value);
227 }
228 
229 // Converts the given floating-point value to a string.
FpValueToString(bfloat16 value)230 string FpValueToString(bfloat16 value) {
231   return absl::StrFormat("%10.4g", static_cast<double>(value));
232 }
233 
FpValueToString(half value)234 string FpValueToString(half value) {
235   return absl::StrFormat("%11.5g", static_cast<double>(value));
236 }
237 
FpValueToString(float value)238 string FpValueToString(float value) {
239   return absl::StrFormat("%15.9g", static_cast<double>(value));
240 }
241 
FpValueToString(double value)242 string FpValueToString(double value) {
243   return absl::StrFormat("%24.17g", value);
244 }
245 
FpValueToString(complex64 value)246 string FpValueToString(complex64 value) {
247   return absl::StrCat(FpValueToString(value.real()), " + ",
248                       FpValueToString(value.imag()));
249 }
250 
FpValueToString(complex128 value)251 string FpValueToString(complex128 value) {
252   return absl::StrCat(FpValueToString(value.real()), " + ",
253                       FpValueToString(value.imag()));
254 }
255 
256 // A wrapper of std::abs to include data types that are not supported by
257 // std::abs, in particular, bfloat16 and half.
258 template <typename NativeT>
FpAbsoluteValue(NativeT value)259 double FpAbsoluteValue(NativeT value) {
260   return std::abs(value);
261 }
262 
263 template <>
FpAbsoluteValue(bfloat16 value)264 double FpAbsoluteValue(bfloat16 value) {
265   return FpAbsoluteValue<float>(static_cast<float>(value));
266 }
267 
268 template <>
FpAbsoluteValue(half value)269 double FpAbsoluteValue(half value) {
270   return FpAbsoluteValue<float>(static_cast<float>(value));
271 }
272 
273 // Helper class for comparing floating-point literals within an error bound.
274 template <typename NativeT>
275 class NearComparator {
276  public:
277   // Compares the two array literals elementwise and returns a comparison
278   // result. The comparison is ok() if all actual and expected elements are
279   // within the given error bound. In case of error, the status contains a
280   // detailed message about the discrepancy.
Compare(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)281   static Status Compare(const LiteralSlice& expected,
282                         const LiteralSlice& actual,
283                         const ShapeIndex& shape_index, ErrorSpec error,
284                         bool detailed_message,
285                         const MiscompareCallback& miscompare_callback) {
286     NearComparator<NativeT> comparator(expected, actual, shape_index, error,
287                                        detailed_message, miscompare_callback);
288     return comparator.Run();
289   }
290 
291  private:
292   // Data structure encapsulating metadata about a single element mismatch.
293   struct Mismatch {
294     NativeT actual;
295     NativeT expected;
296     double rel_error;
297     double abs_error;
298 
299     // The linear index of the failure within the shape. This linear index is
300     // from the 'actual' literal.
301     int64 linear_index;
302 
operator <xla::literal_comparison::__anon1e4119ae0111::NearComparator::Mismatch303     bool operator<(const Mismatch& other) const {
304       return rel_error < other.rel_error;
305     }
306 
ToStringxla::literal_comparison::__anon1e4119ae0111::NearComparator::Mismatch307     string ToString(const Shape& shape) const {
308       return absl::StrFormat(
309           "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
310           FpValueToString(actual), FpValueToString(expected),
311           LiteralUtil::MultiIndexAsString(
312               IndexUtil::LinearIndexToMultidimensionalIndex(shape,
313                                                             linear_index)),
314           rel_error, abs_error);
315     }
316   };
317 
NearComparator(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)318   NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
319                  const ShapeIndex& shape_index, ErrorSpec error,
320                  bool detailed_message,
321                  const MiscompareCallback& miscompare_callback)
322       : expected_(expected),
323         actual_(actual),
324         shape_index_(shape_index),
325         error_(error),
326         detailed_message_(detailed_message),
327         miscompare_callback_(miscompare_callback),
328         abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
329         abs_error_buckets_(kErrorBucketBounds.size(), 0),
330         rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
331 
332   // Runs the comparison between expected and actual literals.
Run()333   Status Run() {
334     // If the shapes mismatch, we simply fail the expectation instead of
335     // printing out data, as it's a type error rather than a value error.
336     TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
337     if (!expected_.shape().IsArray()) {
338       return InvalidArgument("Expected array shape; got %s.",
339                              ShapeUtil::HumanString(expected_.shape()));
340     }
341 
342     mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
343     mismatches_.PopulateWithValue(false);
344 
345     CompareLiterals();
346 
347     if (num_mismatches_ == 0) {
348       return Status::OK();
349     } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
350       miscompare_callback_(expected_, actual_, mismatches_, shape_index_);
351     }
352     return InvalidArgument("%s", ErrorMessage());
353   }
354 
355   // Insert the given absolute value into the absolute value bucket vector. The
356   // bounds of the buckets are given by kAbsValueBucketBounds.
UpdateAbsValueBucket(NativeT value,bool is_mismatch)357   void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
358     // Adjust the bucket containing the absolute values of the 'actual'
359     // elements.
360     const double abs_value = FpAbsoluteValue(value);
361     for (int i = 0; i < abs_value_buckets_.size(); ++i) {
362       if (i == abs_value_buckets_.size() - 1 ||
363           (abs_value >= kAbsValueBucketBounds[i] &&
364            abs_value < kAbsValueBucketBounds[i + 1])) {
365         // The first value of the pair is the count of elements in the bucket,
366         // the second is the count of mismatches in the bucket.
367         abs_value_buckets_[i].first++;
368         if (is_mismatch) {
369           abs_value_buckets_[i].second++;
370         }
371         return;
372       }
373     }
374   }
375 
376   // Insert the given error into the given error bucket vector.
UpdateErrorBucket(double error,absl::Span<int64> error_buckets)377   void UpdateErrorBucket(double error, absl::Span<int64> error_buckets) {
378     CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
379     for (int i = 0; i < error_buckets.size(); ++i) {
380       if (error >= kErrorBucketBounds[i]) {
381         error_buckets[i]++;
382       }
383     }
384   }
385 
386   // Compares the two given elements from the expected and actual literals at
387   // the given literal_index and keeps track of various mismatch statistics.
388   template <typename T>
CompareValues(T expected,T actual,int64 linear_index)389   void CompareValues(T expected, T actual, int64 linear_index) {
390     double abs_error;
391     double rel_error;
392     if (CompareEqual<T>(expected, actual, {linear_index})) {
393       abs_error = 0;
394       rel_error = 0;
395     } else if (IsNan(expected) || IsNan(actual)) {
396       if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) ||
397           (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) {
398         num_nan_mismatches_++;
399         // A nan mismatch is considered to have infinite error. rel_error is
400         // used for sorting a std::set of the top mismatches, and a nan value
401         // here will result in undefined behavior because nan's do not satisfy
402         // the strict weak ordering requirement of std containers.
403         abs_error = std::numeric_limits<float>::infinity();
404         rel_error = std::numeric_limits<float>::infinity();
405       } else {
406         abs_error = 0;
407         rel_error = 0;
408       }
409     } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) {
410       // `fewer_infs_ok` gives us the option of comparing as though `actual`
411       // were float_max/min rather than inf.
412       T actual_finite = actual > T{0} ? std::numeric_limits<T>::max()
413                                       : std::numeric_limits<T>::lowest();
414       abs_error = FpAbsoluteValue(actual_finite - expected);
415 
416       // Avoid division by 0 even though it's well-defined because ubsan can be
417       // configured to treat this as a fatal error.
418       if (expected != T{0}) {
419         rel_error = abs_error / FpAbsoluteValue(expected);
420       } else {
421         rel_error = std::numeric_limits<float>::infinity();
422       }
423     } else if (IsInf(expected) || IsInf(actual)) {
424       // If either the expected or actual value is infinity but not both,
425       // then both absolute and relative error are regarded as infinity.
426       CHECK(!CompareEqual(expected, actual, {linear_index}));
427       abs_error = std::numeric_limits<float>::infinity();
428       rel_error = std::numeric_limits<float>::infinity();
429     } else {
430       abs_error = FpAbsoluteValue(actual - expected);
431 
432       // Avoid division by 0 even though it's well-defined because ubsan can be
433       // configured to treat this as a fatal error.
434       if (expected != T{0}) {
435         rel_error = abs_error / FpAbsoluteValue(expected);
436       } else {
437         rel_error = std::numeric_limits<float>::infinity();
438       }
439     }
440     const bool is_abs_mismatch = abs_error > error_.abs;
441     const bool is_rel_mismatch = rel_error > error_.rel;
442     const bool is_mismatch = is_abs_mismatch && is_rel_mismatch;
443 
444     // Update the error of the relative bucket only if the *absolute* error
445     // bound is exceeded and vice versa.
446     if (is_abs_mismatch) {
447       num_abs_mismatches_++;
448       UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
449     }
450     if (is_rel_mismatch) {
451       num_rel_mismatches_++;
452       UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
453     }
454 
455     UpdateAbsValueBucket(actual, is_mismatch);
456 
457     if (!is_mismatch) {
458       return;
459     }
460 
461     num_mismatches_++;
462 
463     // Keep track of the kTopRelativeErrorCount relative error mismatches.
464     if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
465         rel_error > top_rel_mismatches_.begin()->rel_error) {
466       Mismatch mismatch = {actual, expected, rel_error, abs_error,
467                            linear_index};
468       top_rel_mismatches_.insert(mismatch);
469       if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
470         top_rel_mismatches_.erase(top_rel_mismatches_.begin());
471       }
472     }
473 
474     mismatches_.data<bool>()[linear_index] = true;
475   }
476 
477   // For complex types, we compare real and imaginary parts individually.
CompareValues(complex64 expected,complex64 actual,int64 linear_index)478   void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
479     const auto both_parts_mismatch = num_mismatches_ + 2;
480     CompareValues<float>(expected.real(), actual.real(), linear_index);
481     CompareValues<float>(expected.imag(), actual.imag(), linear_index);
482     if (num_mismatches_ == both_parts_mismatch) {
483       // The mismatch counter had been incremented by each CompareValues() call,
484       // which means that both real and imaginary parts of the passed-in complex
485       // values are different. However, the counter should reflect a single
486       // mismatch between these complex values.
487       num_mismatches_--;
488     }
489   }
490 
CompareValues(complex128 expected,complex128 actual,int64 linear_index)491   void CompareValues(complex128 expected, complex128 actual,
492                      int64 linear_index) {
493     const auto both_parts_mismatch = num_mismatches_ + 2;
494     CompareValues<double>(expected.real(), actual.real(), linear_index);
495     CompareValues<double>(expected.imag(), actual.imag(), linear_index);
496     if (num_mismatches_ == both_parts_mismatch) {
497       // The mismatch counter had been incremented by each CompareValues() call,
498       // which means that both real and imaginary parts of the passed-in complex
499       // values are different. However, the counter should reflect a single
500       // mismatch between these complex values.
501       num_mismatches_--;
502     }
503   }
504 
505   // Compares the two literals elementwise.
CompareLiterals()506   void CompareLiterals() {
507     // Fast path optimization for the case were layouts match.
508     if (LayoutUtil::Equal(actual_.shape().layout(),
509                           expected_.shape().layout())) {
510       absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
511       absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
512       const int64 len = expected_data.size();
513       for (int64 i = 0; i < len; ++i) {
514         CompareValues(expected_data[i], actual_data[i], i);
515       }
516       return;
517     }
518     std::vector<int64> multi_index(actual_.shape().rank(), 0);
519     CompareLiteralsSlow(0, &multi_index);
520   }
521 
522   // Slow path for CompareLiterals when 'actual' and 'expected' literals have
523   // different layouts. In this case, multidimensional indices are constructed
524   // and indexed for each element.
CompareLiteralsSlow(int64 dimension,std::vector<int64> * multi_index)525   void CompareLiteralsSlow(int64 dimension, std::vector<int64>* multi_index) {
526     if (dimension == multi_index->size()) {
527       CompareValues(expected_.Get<NativeT>(*multi_index),
528                     actual_.Get<NativeT>(*multi_index),
529                     IndexUtil::MultidimensionalIndexToLinearIndex(
530                         actual_.shape(), *multi_index));
531     } else {
532       for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) {
533         (*multi_index)[dimension] = i;
534         CompareLiteralsSlow(dimension + 1, multi_index);
535       }
536     }
537   }
538 
539   // Returns an error message string with a detailed breakdown of the
540   // mismatches. Called after calling Run().
ErrorMessage()541   string ErrorMessage() {
542     string out;
543     int64 element_count = ShapeUtil::ElementsIn(actual_.shape());
544 
545     auto percent_string = [](float a, float b) {
546       float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
547       return absl::StrFormat("%0.4f%%", pct);
548     };
549 
550     StrAppendFormat(
551         &out,
552         "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
553         "%g, rel bound %g\n",
554         num_mismatches_, percent_string(num_mismatches_, element_count),
555         ShapeUtil::HumanString(actual_.shape()),
556         ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
557     if (num_nan_mismatches_ > 0) {
558       StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
559     }
560     StrAppendFormat(&out, "Top relative error mismatches:\n");
561     for (auto it = top_rel_mismatches_.rbegin();
562          it != top_rel_mismatches_.rend(); ++it) {
563       StrAppend(&out, "  ", it->ToString(actual_.shape()), "\n");
564     }
565 
566     if (!detailed_message_) {
567       return out;
568     }
569 
570     StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
571     CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
572     for (int i = 0; i < abs_value_buckets_.size(); ++i) {
573       const int64 bucket_size = abs_value_buckets_[i].first;
574       const int64 bucket_mismatches = abs_value_buckets_[i].second;
575       string mismatch_str =
576           bucket_mismatches > 0
577               ? absl::StrFormat(", mismatches %d", bucket_mismatches)
578               : "";
579       StrAppendFormat(&out, "  %-6g <= x < %-6g : %7d (%9s)%s\n",
580                       kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
581                       bucket_size, percent_string(bucket_size, element_count),
582                       mismatch_str);
583     }
584 
585     auto print_accum_buckets = [&](const string& header, int64 total,
586                                    absl::Span<const int64> buckets) {
587       StrAppend(&out, header, ":\n");
588       StrAppendFormat(&out, "  <  %-6g : %7d (%s)\n", kErrorBucketBounds[0],
589                       total - buckets[0],
590                       percent_string(total - buckets[0], total));
591       CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
592       for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
593         StrAppendFormat(&out, "  >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
594                         buckets[i], percent_string(buckets[i], total));
595       }
596     };
597     StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
598                     error_.abs, num_abs_mismatches_,
599                     percent_string(num_abs_mismatches_, element_count));
600     print_accum_buckets(
601         "Relative error breakdown of elements exceeding abs error bound",
602         num_abs_mismatches_, rel_error_buckets_);
603     StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
604                     error_.rel, num_rel_mismatches_,
605                     percent_string(num_rel_mismatches_, element_count));
606     print_accum_buckets(
607         "Absolute error breakdown of elements exceeding rel error bound",
608         num_rel_mismatches_, abs_error_buckets_);
609     return out;
610   }
611 
612   // 'actual' and 'expected' literals being compared.
613   LiteralSlice expected_;
614   LiteralSlice actual_;
615 
616   // The shape index of the LiteralSlice that is being compared.
617   ShapeIndex shape_index_;
618 
619   // The error bounds of the comparison.
620   ErrorSpec error_;
621 
622   // Whether to include detailed breakdown of mismatches in the error message.
623   bool detailed_message_;
624 
625   // Callback to invoke on miscompare.
626   MiscompareCallback miscompare_callback_;
627 
628   // Number of element mismatches encountered so far.
629   int64 num_mismatches_ = 0;
630 
631   // Number of elements with a nan mismatch.
632   int64 num_nan_mismatches_ = 0;
633 
634   // Number of elements which exceed the absolute/relative error bound.
635   int64 num_abs_mismatches_ = 0;
636   int64 num_rel_mismatches_ = 0;
637 
638   // A Literal containing which elements did not match in the expected and
639   // actual literals. mismatches_ contains PREDs and is of the same sizes as
640   // the comparison literals.
641   Literal mismatches_;
642 
643   // The number of mismatches to report in the output, sorted by relative error
644   // magnitude.
645   static constexpr int64 kTopRelativeErrorCount = 5;
646 
647   // The set of mismatches with the largest relative error. The size of this set
648   // is bounded by kTopRelativeErrorCount.
649   std::multiset<Mismatch> top_rel_mismatches_;
650 
651   // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
652   // bounds of these buckets. abs_value_buckets_ contains a pair for each
653   // bucket: the element count and failure count.
654   static constexpr std::array<float, 7> kAbsValueBucketBounds = {
655       0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
656   std::vector<std::pair<int64, int64>> abs_value_buckets_;
657 
658   // Buckets for relative and absolute errors. The relative error buckets only
659   // contains those elements which exceed the *absolute* error bound, and vice
660   // versa. This makes it easy to see the effect of adjusting the relative (or
661   // absolute) error bound on the success of the comparison. kErrorBucketBounds
662   // are the lower bounds of the buckets in both vectors. The error buckets are
663   // a cumulative distribution so an error value may appear in more than one
664   // bucket. For example an error value of 0.003 may appear in the buckets
665   // bounded by 0.01, 0.1, and 1.0.
666   static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
667                                                               0.01, 0.1, 1};
668   std::vector<int64> abs_error_buckets_;
669   std::vector<int64> rel_error_buckets_;
670 };
671 
672 template <typename NativeT>
673 constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
674 template <typename NativeT>
675 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
676 
EqualHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const MiscompareCallback & miscompare_callback)677 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
678                    const ShapeIndex& shape_index,
679                    const MiscompareCallback& miscompare_callback) {
680   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
681 
682   Status result;
683   if (expected.shape().IsTuple()) {
684     ShapeIndex next_index = shape_index;
685     for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
686       next_index.push_back(i);
687       Status tuple_result =
688           EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
689                       next_index, miscompare_callback);
690       if (miscompare_callback) {
691         result.Update(tuple_result);
692       } else {
693         TF_RETURN_IF_ERROR(tuple_result);
694       }
695       next_index.pop_back();
696     }
697   } else {
698     std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
699     auto index = absl::MakeSpan(multi_index);
700 
701     Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
702                                                expected.shape().dimensions());
703     Literal miscompared(unequal_shape);
704     Literal* miscompared_ptr =
705         (miscompare_callback == nullptr ? nullptr : &miscompared);
706 
707     switch (expected.shape().element_type()) {
708       case PRED:
709         result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
710         break;
711       case S8:
712         result = Equal<int8>(expected, actual, index, 0, miscompared_ptr);
713         break;
714       case S16:
715         result = Equal<int16>(expected, actual, index, 0, miscompared_ptr);
716         break;
717       case S32:
718         result = Equal<int32>(expected, actual, index, 0, miscompared_ptr);
719         break;
720       case S64:
721         result = Equal<int64>(expected, actual, index, 0, miscompared_ptr);
722         break;
723       case U8:
724         result = Equal<uint8>(expected, actual, index, 0, miscompared_ptr);
725         break;
726       case U16:
727         result = Equal<uint16>(expected, actual, index, 0, miscompared_ptr);
728         break;
729       case U32:
730         result = Equal<uint32>(expected, actual, index, 0, miscompared_ptr);
731         break;
732       case U64:
733         result = Equal<uint64>(expected, actual, index, 0, miscompared_ptr);
734         break;
735       case BF16:
736         result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
737         break;
738       case F16:
739         result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
740         break;
741       case F32:
742         result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
743         break;
744       case F64:
745         result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
746         break;
747       case C64:
748         result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
749         break;
750       case C128:
751         result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
752         break;
753       case TOKEN:
754         // Tokens have no on-device representation and are trivially equal.
755         return Status::OK();
756       default:
757         LOG(FATAL) << "Unsupported primitive type: "
758                    << PrimitiveType_Name(expected.shape().element_type());
759     }
760 
761     if (!result.ok() && miscompare_callback) {
762       miscompare_callback(expected, actual, LiteralSlice(miscompared),
763                           shape_index);
764     }
765   }
766 
767   return result;
768 }
769 
770 // Helper function for comparing two literals for nearness. Handles tuple-shapes
771 // via recursion. shape_index is the ShapeIndex of expected (or actual)
772 // currently being compared.
NearHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)773 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
774                   const ShapeIndex& shape_index, const ErrorSpec& error,
775                   absl::optional<bool> detailed_message,
776                   const MiscompareCallback& miscompare_callback) {
777   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
778 
779   if (expected.shape().IsTuple()) {
780     Status return_status;
781     for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
782       const auto expected_element = LiteralSlice(expected, {i});
783       const auto actual_element = LiteralSlice(actual, {i});
784       ShapeIndex element_index = shape_index;
785       element_index.push_back(i);
786       Status element_result =
787           NearHelper(expected_element, actual_element, element_index, error,
788                      detailed_message, miscompare_callback);
789       if (!element_result.ok()) {
790         element_result = InvalidArgument("Array at shape index %s, %s",
791                                          element_index.ToString(),
792                                          element_result.error_message());
793         if (return_status.ok()) {
794           return_status = element_result;
795         } else {
796           return_status =
797               AppendStatus(return_status, element_result.error_message());
798         }
799       }
800     }
801     if (!return_status.ok() && shape_index.empty()) {
802       // Emit a top-level error message containing the top-level shape in case
803       // of mismatch.
804       int64 total_elements = RecursiveElementCount(actual.shape());
805       return_status =
806           InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
807                           ShapeUtil::HumanString(actual.shape()),
808                           total_elements, return_status.error_message());
809     }
810     return return_status;
811   }
812 
813   if (ShapeUtil::ElementIsFloating(expected.shape()) ||
814       ShapeUtil::ElementIsComplex(expected.shape())) {
815     bool use_detailed_message = detailed_message.value_or(
816         ShapeUtil::ElementsIn(expected.shape()) >= 64);
817     switch (expected.shape().element_type()) {
818       case BF16:
819         return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
820                                                  error, use_detailed_message,
821                                                  miscompare_callback);
822         break;
823       case F16:
824         return NearComparator<half>::Compare(expected, actual, shape_index,
825                                              error, use_detailed_message,
826                                              miscompare_callback);
827         break;
828       case F32:
829         return NearComparator<float>::Compare(expected, actual, shape_index,
830                                               error, use_detailed_message,
831                                               miscompare_callback);
832         break;
833       case F64:
834         return NearComparator<double>::Compare(expected, actual, shape_index,
835                                                error, use_detailed_message,
836                                                miscompare_callback);
837         break;
838       case C64:
839         return NearComparator<complex64>::Compare(expected, actual, shape_index,
840                                                   error, use_detailed_message,
841                                                   miscompare_callback);
842         break;
843       case C128:
844         return NearComparator<complex128>::Compare(
845             expected, actual, shape_index, error, use_detailed_message,
846             miscompare_callback);
847         break;
848       default:
849         LOG(FATAL) << "Unsupported primitive type in near comparator: "
850                    << PrimitiveType_Name(expected.shape().element_type())
851                    << ". Must be floating-point type.";
852     }
853   }
854 
855   // Non-floating point, non-tuple literal.
856   return EqualHelper(expected, actual, shape_index, miscompare_callback);
857 }
858 
859 }  // namespace
860 
EqualShapes(const Shape & expected,const Shape & actual)861 Status EqualShapes(const Shape& expected, const Shape& actual) {
862   if (expected.element_type() != actual.element_type()) {
863     return InvalidArgument("element type mismatch, want: %s got %s",
864                            ShapeUtil::HumanString(expected),
865                            ShapeUtil::HumanString(actual));
866   }
867   if (expected.IsTuple()) {
868     if (ShapeUtil::TupleElementCount(expected) !=
869         ShapeUtil::TupleElementCount(actual)) {
870       return InvalidArgument(
871           "want tuple element count: %d got tuple element count: %d",
872           ShapeUtil::TupleElementCount(expected),
873           ShapeUtil::TupleElementCount(actual));
874     }
875     for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
876       Status result =
877           EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
878       if (!result.ok()) {
879         return AppendStatus(result, StrCat("mismatch in tuple index", i));
880       }
881     }
882   } else if (expected.IsArray()) {
883     if (expected.rank() != actual.rank()) {
884       return InvalidArgument("want rank of %s got rank of %s",
885                              ShapeUtil::HumanString(expected),
886                              ShapeUtil::HumanString(actual));
887     }
888     if (expected.element_type() != actual.element_type()) {
889       return InvalidArgument("mismatch in primitive type %s vs %s",
890                              PrimitiveType_Name(expected.element_type()),
891                              PrimitiveType_Name(actual.element_type()));
892     }
893     if (expected.dimensions_size() != actual.dimensions_size()) {
894       return InvalidArgument("want dimensions_size %d got dimensions_size %d",
895                              expected.dimensions_size(),
896                              actual.dimensions_size());
897     }
898     for (int i = 0; i < expected.dimensions_size(); ++i) {
899       if (expected.dimensions(i) != actual.dimensions(i)) {
900         return InvalidArgument(
901             "mismatch in dimension #%d expected: %s actual: %s", i,
902             ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
903       }
904     }
905   }
906   // Non-array, non-tuple shapes are trivially equivalent.
907   return Status::OK();
908 }
909 
910 namespace {
911 
912 // If result is an error, extend the error message with the expected and actual
913 // literals.
// Returns `result` unchanged when it is OK. Otherwise returns an
// InvalidArgument status whose message is the original error message followed
// by the ToStringTruncated() forms of `expected` and `actual`.
Status EmitLiteralsInErrorMessage(const Status& result,
                                  const LiteralSlice& expected,
                                  const LiteralSlice& actual) {
  if (!result.ok()) {
    const string expected_text = ToStringTruncated(expected);
    const string actual_text = ToStringTruncated(actual);
    return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
                           result.error_message(), expected_text, actual_text);
  }
  return result;
}
924 
925 }  // namespace
926 
Equal(const LiteralSlice & expected,const LiteralSlice & actual)927 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
928   VLOG(1) << "expected:";
929   XLA_VLOG_LINES(1, expected.ToString());
930   VLOG(1) << "actual:";
931   XLA_VLOG_LINES(1, actual.ToString());
932   Status result = EqualHelper(expected, actual, {}, nullptr);
933   return EmitLiteralsInErrorMessage(result, expected, actual);
934 }
935 
Near(const LiteralSlice & expected,const LiteralSlice & actual,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)936 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
937             const ErrorSpec& error, absl::optional<bool> detailed_message,
938             const MiscompareCallback& miscompare_callback) {
939   VLOG(1) << "Expected literal:";
940   XLA_VLOG_LINES(1, expected.ToString());
941   VLOG(1) << "Actual literal:";
942   XLA_VLOG_LINES(1, actual.ToString());
943   Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
944                              detailed_message, miscompare_callback);
945   return EmitLiteralsInErrorMessage(result, expected, actual);
946 }
947 
ToStringTruncated(const LiteralSlice & literal)948 string ToStringTruncated(const LiteralSlice& literal) {
949   return RecursiveElementCount(literal.shape()) < 1000
950              ? literal.ToString()
951              : "[TRUNCATED, Literal with more than 1000 values]";
952 }
953 
954 }  // namespace literal_comparison
955 }  // namespace xla
956