1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_
18 
#include <string>

#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/core/platform/status.h"
25 
26 // Error utilities for MLIR when interacting with code using Status returns.
27 namespace mlir {
28 
29 // TensorFlow's Status is used for error reporting back to callers.
30 using ::tensorflow::Status;
31 
32 // Diagnostic handler that collects all the diagnostics reported and can produce
33 // a Status to return to callers. This is for the case where MLIR functions are
34 // called from a function that will return a Status: MLIR code still uses the
35 // default error reporting, and the final return function can return the Status
36 // constructed from the diagnostics collected.
37 class StatusScopedDiagnosticHandler : public SourceMgrDiagnosticHandler {
38  public:
39   // Constructs a diagnostic handler in a context. If propagate is true, then
40   // diagnostics reported are also propagated back to the original diagnostic
41   // handler.
42   explicit StatusScopedDiagnosticHandler(MLIRContext* context,
43                                          bool propagate = false);
44   // On destruction error consumption is verified.
45   ~StatusScopedDiagnosticHandler();
46 
47   // Returns whether any errors were reported.
48   bool ok() const;
49 
50   // Returns Status corresponding to the diagnostics reported. This consumes the
51   // diagnostics reported and returns a Status of type Unknown. It is required
52   // to consume the error status, if there is one, before destroying the object.
53   Status ConsumeStatus();
54 
55   // Returns the combination of the passed in status and consumed diagnostics.
56   // This consumes the diagnostics reported and either appends the diagnostics
57   // to the error message of 'status' (if 'status' is already an error state),
58   // or returns an Unknown status (if diagnostics reported), otherwise OK.
59   Status Combine(Status status);
60 
61  private:
62   LogicalResult handler(Diagnostic* diag);
63 
64   // String stream to assemble the final error message.
65   std::string diag_str_;
66   llvm::raw_string_ostream diag_stream_;
67 
68   // A SourceMgr to use for the base handler class.
69   llvm::SourceMgr source_mgr_;
70 
71   // Whether to propagate diagnostics to the old diagnostic handler.
72   bool propagate_;
73 };
74 }  // namespace mlir
75 
76 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_
77