1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
17 
18 #include <atomic>
19 
20 #include "absl/strings/str_split.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
26 
27 namespace tensorflow {
28 
// Sequence number prepended to every per-pass dump filename. It is bumped
// once per dump so the resulting files sort in the order they were written.
static std::atomic<int> log_counter{0};
31 
BridgeLoggerConfig(bool print_module_scope,bool print_after_only_on_change)32 BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope,
33                                        bool print_after_only_on_change)
34     : mlir::PassManager::IRPrinterConfig(print_module_scope,
35                                          print_after_only_on_change) {
36   const char* log_pass_patterns = getenv("MLIR_BRIDGE_LOG_PASS_PATTERNS");
37   if (log_pass_patterns) {
38     log_pass_patterns_ =
39         absl::StrSplit(log_pass_patterns, ',', absl::SkipWhitespace());
40   }
41 }
42 
43 // Logs op to file with name of format
44 // `<log_counter>_mlir_bridge_<pass_name>_<file_suffix>.mlir`.
Log(BridgeLoggerConfig::PrintCallbackFn print_callback,mlir::Pass * pass,mlir::Operation * op,llvm::StringRef file_suffix)45 inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
46                        mlir::Pass* pass, mlir::Operation* op,
47                        llvm::StringRef file_suffix) {
48   std::string pass_name = pass->getName().str();
49 
50   // Add 4-digit counter as prefix so the order of the passes is obvious.
51   std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++,
52                                    pass_name, file_suffix);
53 
54   std::unique_ptr<llvm::raw_ostream> os;
55   std::string filepath;
56   if (CreateFileForDumping(name, &os, &filepath).ok()) {
57     print_callback(*os);
58     LOG(INFO) << "Dumped MLIR module to " << filepath;
59   }
60 }
61 
printBeforeIfEnabled(mlir::Pass * pass,mlir::Operation * operation,PrintCallbackFn print_callback)62 void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass,
63                                               mlir::Operation* operation,
64                                               PrintCallbackFn print_callback) {
65   if (should_print(pass)) Log(print_callback, pass, operation, "before");
66 }
67 
printAfterIfEnabled(mlir::Pass * pass,mlir::Operation * operation,PrintCallbackFn print_callback)68 void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass,
69                                              mlir::Operation* operation,
70                                              PrintCallbackFn print_callback) {
71   if (should_print(pass)) Log(print_callback, pass, operation, "after");
72 }
73 
should_print(mlir::Pass * pass)74 bool BridgeLoggerConfig::should_print(mlir::Pass* pass) {
75   if (log_pass_patterns_.empty()) return true;
76 
77   std::string pass_name = pass->getName().str();
78   for (const auto& pattern : log_pass_patterns_) {
79     if (pass_name.find(pattern) != std::string::npos) {
80       // pattern matches pass
81       return true;
82     }
83   }
84   // no pattern matches pass
85   VLOG(1) << "Not logging pass " << pass_name
86           << " because it does not match any pattern in "
87              "MLIR_BRIDGE_LOG_PASS_PATTERNS";
88   return false;
89 }
90 
printTiming(PrintCallbackFn printCallback)91 void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) {
92   std::string name = "mlir_bridge_pass_timing.txt";
93   std::unique_ptr<llvm::raw_ostream> os;
94   std::string filepath;
95   if (CreateFileForDumping(name, &os, &filepath).ok()) printCallback(*os);
96 }
97 
98 }  // namespace tensorflow
99