1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
18 
#include <string>
#include <vector>

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
22 
23 namespace tensorflow {
24 
25 // Logger for logging/dumping MLIR modules before and after passes in bridge
26 // targeting TPUs. The passes being logged can be restricted via environment
27 // variable `MLIR_BRIDGE_LOG_PASS_PATTERNS` which is interpreted as a comma-
28 // separated list of strings, and only passes whose name contains any of those
29 // strings as a substring are logged (no regex support). If
30 // `MLIR_BRIDGE_LOG_PASS_PATTERNS` is not defined, then all passes are logged.
31 class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
32  public:
33   explicit BridgeLoggerConfig(bool print_module_scope = false,
34                               bool print_after_only_on_change = true);
35 
36   // A hook that may be overridden by a derived config that checks if the IR
37   // of 'operation' should be dumped *before* the pass 'pass' has been
38   // executed. If the IR should be dumped, 'print_callback' should be invoked
39   // with the stream to dump into.
40   void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
41                             PrintCallbackFn print_callback) override;
42 
43   // A hook that may be overridden by a derived config that checks if the IR
44   // of 'operation' should be dumped *after* the pass 'pass' has been
45   // executed. If the IR should be dumped, 'print_callback' should be invoked
46   // with the stream to dump into.
47   void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
48                            PrintCallbackFn print_callback) override;
49 
50  private:
51   bool should_print(mlir::Pass *pass);
52 
53   // Only print passes that match any of these patterns. A pass matches a
54   // pattern if its name contains the pattern as a substring. If
55   // `log_pass_patterns_` is empty, print all passes.
56   std::vector<std::string> log_pass_patterns_;
57 };
58 
59 // Logger for logging/dumping pass pipeline timings after completion.
60 class BridgeTimingConfig : public mlir::PassManager::PassTimingConfig {
61  public:
62   // Hook that control how/where is the output produced
63   void printTiming(PrintCallbackFn printCallback) override;
64 };
65 
66 }  // namespace tensorflow
67 
68 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
69