1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Library for comparing literals without taking a dependency on testing 17 // libraries. 18 19 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ 20 #define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ 21 22 #include "tensorflow/compiler/xla/error_spec.h" 23 #include "tensorflow/compiler/xla/literal.h" 24 #include "tensorflow/core/lib/core/status.h" 25 26 namespace xla { 27 namespace literal_comparison { 28 29 // Returns ok if the given shapes have the same rank, dimension sizes, and 30 // primitive types. 31 Status EqualShapes(const Shape& expected, const Shape& actual); 32 33 // Returns ok if the expected and actual literals are (bitwise) equal for all 34 // elements in the literal. Also, asserts that the rank, dimensions sizes, and 35 // primitive type are equal. 36 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual); 37 38 using MiscompareCallback = 39 std::function<void(const LiteralSlice& expected, const LiteralSlice& actual, 40 const LiteralSlice& mismatches)>; 41 42 // Inspects whether the expected and actual literals are within the given error 43 // bound for all elements. Also, inspects whether the rank, dimensions sizes, 44 // and dimension bounds are equivalent. 45 // 46 // Tuples are matched recursively. 47 // 48 // When comparing tensors of non-floating-point type, this inspects for exact 49 // equality, ignoring the ErrorSpec. 50 // 51 // If the shape of the literals is neither a complex/floating-point tensor nor a 52 // tuple which contains a complex/floating-point tensor, Near() is equivalent to 53 // Equal(). We don't raise an error in this case, because we want to allow 54 // callers to call Near() even if they have no preconceptions about the shapes 55 // being compared. 56 // 57 // If detailed_message is true, then the error message in the assertion result 58 // will contain a more detailed breakdown of mismatches. By default, we display 59 // a detailed message only for "large" inputs. 60 Status Near(const LiteralSlice& expected, const LiteralSlice& actual, 61 const ErrorSpec& error, absl::optional<bool> detailed_message, 62 const MiscompareCallback& miscompare_callback); 63 64 // Calling ToString on a literal with over 100 million elements takes around 65 // 3 minutes. The utility of printing a literal with >1000 elements is 66 // questionable, especially when writing the Literal proto to disk is orders 67 // of magnitude faster. 68 string ToStringTruncated(const LiteralSlice& literal); 69 70 } // namespace literal_comparison 71 } // namespace xla 72 73 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ 74