1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ 18 19 #include "mlir/IR/Operation.h" // from @llvm-project 20 #include "mlir/Pass/Pass.h" // from @llvm-project 21 #include "mlir/Pass/PassManager.h" // from @llvm-project 22 23 namespace tensorflow { 24 25 // Logger for logging/dumping MLIR modules before and after passes in bridge 26 // targeting TPUs. The passes being logged can be restricted via environment 27 // variable `MLIR_BRIDGE_LOG_PASS_PATTERNS` which is interpreted as a comma- 28 // separated list of strings, and only passes whose name contains any of those 29 // strings as a substring are logged (no regex support). If 30 // `MLIR_BRIDGE_LOG_PASS_PATTERNS` is not defined, then all passes are logged. 31 class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig { 32 public: 33 explicit BridgeLoggerConfig(bool print_module_scope = false, 34 bool print_after_only_on_change = true); 35 36 // A hook that may be overridden by a derived config that checks if the IR 37 // of 'operation' should be dumped *before* the pass 'pass' has been 38 // executed. If the IR should be dumped, 'print_callback' should be invoked 39 // with the stream to dump into. 40 void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *operation, 41 PrintCallbackFn print_callback) override; 42 43 // A hook that may be overridden by a derived config that checks if the IR 44 // of 'operation' should be dumped *after* the pass 'pass' has been 45 // executed. If the IR should be dumped, 'print_callback' should be invoked 46 // with the stream to dump into. 47 void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation, 48 PrintCallbackFn print_callback) override; 49 50 private: 51 bool should_print(mlir::Pass *pass); 52 53 // Only print passes that match any of these patterns. A pass matches a 54 // pattern if its name contains the pattern as a substring. If 55 // `log_pass_patterns_` is empty, print all passes. 56 std::vector<std::string> log_pass_patterns_; 57 }; 58 59 // Logger for logging/dumping pass pipeline timings after completion. 60 class BridgeTimingConfig : public mlir::PassManager::PassTimingConfig { 61 public: 62 // Hook that control how/where is the output produced 63 void printTiming(PrintCallbackFn printCallback) override; 64 }; 65 66 } // namespace tensorflow 67 68 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ 69