1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17 
18 #include <functional>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
26 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/platform/logging.h"
31 
32 namespace xla {
33 
34 namespace {
35 
// Opens a new pass record in `module`'s metadata (RecordPassStart) and labels
// it with the name of the pass and of the pipeline that is running it.
// Any failure while labelling the record is fatal (TF_CHECK_OK).
void RecordPassStartMetadata(HloModule& module, const std::string& pass_name,
                             const std::string& pipeline_name) {
  module.metadata()->RecordPassStart();
  // An HloPassMetadata was just created by RecordPassStart() above, so setting
  // its names is expected to always succeed; a non-OK Status here would mean a
  // broken invariant, hence the crash rather than an error return.
  TF_CHECK_OK(module.metadata()->set_current_pass_name(pass_name));
  TF_CHECK_OK(module.metadata()->set_current_pass_pipeline_name(pipeline_name));
}
43 
RecordPassStartMetadata(HloModuleGroup & module_group,const std::string & pass_name,const std::string & pipeline_name)44 void RecordPassStartMetadata(HloModuleGroup& module_group,
45                              const std::string& pass_name,
46                              const std::string& pipeline_name) {
47   for (HloModule* module : module_group.modules()) {
48     RecordPassStartMetadata(*module, pass_name, pipeline_name);
49   }
50 }
51 
AttemptRecordPassEndMetadata(HloModule & module,const std::string & pass_name,bool module_changed)52 Status AttemptRecordPassEndMetadata(HloModule& module,
53                                     const std::string& pass_name,
54                                     bool module_changed) {
55   // Module id is set here instead of RecordPassStartMetadata because it may
56   // change in the middle of the pass, and we want the final id.
57   TF_RETURN_IF_ERROR(
58       module.metadata()->set_current_pass_module_id(module.unique_id()));
59   TF_RETURN_IF_ERROR(
60       module.metadata()->set_current_pass_module_changed(module_changed));
61   TF_RETURN_IF_ERROR(module.metadata()->RecordPassEnd());
62   return Status::OK();
63 }
64 
RecordPassEndMetadata(HloModule & module,const std::string & pass_name,bool module_changed)65 void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
66                            bool module_changed) {
67   Status status =
68       AttemptRecordPassEndMetadata(module, pass_name, module_changed);
69   if (!status.ok()) {
70     LOG(FATAL) << status;
71   }
72 }
73 
AttemptRecordPassEndMetadata(HloModuleGroup & module_group,const std::string & pass_name,bool module_changed)74 Status AttemptRecordPassEndMetadata(HloModuleGroup& module_group,
75                                     const std::string& pass_name,
76                                     bool module_changed) {
77   for (HloModule* module : module_group.modules()) {
78     for (HloModule* other_module : module_group.modules()) {
79       TF_RETURN_IF_ERROR(
80           module->metadata()->add_current_pass_module_group_module_id(
81               other_module->unique_id()));
82     }
83     TF_RETURN_IF_ERROR(
84         AttemptRecordPassEndMetadata(*module, pass_name, module_changed));
85   }
86   return Status::OK();
87 }
88 
RecordPassEndMetadata(HloModuleGroup & module_group,const std::string & pass_name,bool module_changed)89 void RecordPassEndMetadata(HloModuleGroup& module_group,
90                            const std::string& pass_name, bool module_changed) {
91   Status status =
92       AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
93   if (!status.ok()) {
94     LOG(FATAL) << status;
95   }
96 }
97 
SetInstructionMetadata(HloModule & module)98 void SetInstructionMetadata(HloModule& module) {
99   StatusOr<int64> pass_id = module.metadata()->current_pass_id();
100   if (!pass_id.ok()) {
101     LOG(FATAL) << pass_id.status();
102   }
103   for (xla::HloComputation* computation : module.computations()) {
104     for (xla::HloInstruction* instruction : computation->instructions()) {
105       if (instruction->metadata().creation_pass_id() == 0) {
106         instruction->set_creation_pass_id(*pass_id);
107       }
108       if (instruction->metadata().logical_creation_pass_id() == 0) {
109         instruction->set_logical_creation_pass_id(*pass_id);
110       }
111     }
112   }
113 }
114 
SetInstructionMetadata(HloModuleGroup & module_group)115 void SetInstructionMetadata(HloModuleGroup& module_group) {
116   for (HloModule* module : module_group.modules()) {
117     SetInstructionMetadata(*module);
118   }
119 }
120 
121 }  // namespace
122 
123 template <typename HloT>
RunInvariantCheckers(HloT * hlo,absl::string_view after_pass_name)124 Status HloPassPipeline::RunInvariantCheckers(
125     HloT* hlo, absl::string_view after_pass_name) {
126   for (auto& invariant_checker : invariant_checkers_) {
127     VLOG(1) << "    Invariant checker " << invariant_checker->name();
128     StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
129     VLOG(1) << "    Invariant checker done " << invariant_checker->name();
130     if (!changed_status.ok()) {
131       VLOG(2) << "Failed invariant check:";
132       XLA_VLOG_LINES(2, hlo->ToString());
133       return Status(changed_status.status().code(),
134                     absl::StrCat(changed_status.status().error_message(),
135                                  "\n\nFailed after ", after_pass_name));
136     }
137     TF_RET_CHECK(!changed_status.ValueOrDie())
138         << "invariant checkers must not change the graph";
139   }
140   return Status::OK();
141 }
142 
143 template <typename HloT>
RunPassesInternal(HloT * hlo,absl::Span<HloPassInterface * const> passes)144 StatusOr<bool> HloPassPipeline::RunPassesInternal(
145     HloT* hlo, absl::Span<HloPassInterface* const> passes) {
146   static constexpr absl::string_view kPipelineStart = "pipeline-start";
147   static constexpr absl::string_view kPipelineEnd = "pipeline-end";
148   std::string pipeline_name = std::string(name());
149 
150   TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
151 
152   RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
153   SetInstructionMetadata(*hlo);
154   MaybeDumpHloAndSaveFilenames(*hlo,
155                                /*after_pass_name=*/kPipelineStart,
156                                /*before_pass_name=*/passes.empty()
157                                    ? kPipelineEnd
158                                    : passes.front()->name());
159   RecordPassEndMetadata(*hlo, std::string(kPipelineStart),
160                         /*module_changed=*/false);
161 
162   bool changed = false;
163   for (int i = 0; i < passes.size(); i++) {
164     HloPassInterface* pass = passes[i];
165     XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name()));
166     std::string pass_name = std::string(pass->name());
167     VLOG(1) << "  HLO pass " << pass_name;
168     VLOG(2) << "  Module hash " << hlo->Hash();
169     if (!pass->IsPassPipeline()) {
170       compilation_stats_->StartPass(pass_name);
171     }
172     RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
173     TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
174     SetInstructionMetadata(*hlo);
175     MaybeDumpHloAndSaveFilenames(*hlo,
176                                  /*after_pass_name=*/pass_name,
177                                  /*before_pass_name=*/i + 1 >= passes.size()
178                                      ? kPipelineEnd
179                                      : passes[i + 1]->name());
180     RecordPassEndMetadata(*hlo, pass_name, pass_changed);
181     changed |= pass_changed;
182     if (pass_changed) {
183       VLOG(3) << "  Pass caused changes " << pass->name();
184     }
185     TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
186     if (!pass->IsPassPipeline()) {
187       compilation_stats_->EndPass(pass_name);
188     }
189   }
190   return changed;
191 }
192 
GetEnabledPasses(const DebugOptions & debug_options)193 std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
194     const DebugOptions& debug_options) {
195   if (debug_options.xla_disable_all_hlo_passes()) {
196     VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes.";
197     return {};
198   }
199 
200   absl::flat_hash_set<string> disabled_pass_names(
201       debug_options.xla_disable_hlo_passes().begin(),
202       debug_options.xla_disable_hlo_passes().end());
203 
204   absl::flat_hash_set<string> enabled_pass_names(
205       debug_options.xla_enable_hlo_passes_only().begin(),
206       debug_options.xla_enable_hlo_passes_only().end());
207 
208   if (!disabled_pass_names.empty()) {
209     VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
210             << absl::StrJoin(disabled_pass_names, ", ");
211   }
212 
213   if (!enabled_pass_names.empty()) {
214     VLOG(1) << "Passes enabled by --xla_enable_hlo_passes_only: "
215             << absl::StrJoin(enabled_pass_names, ", ");
216   }
217 
218   CHECK(disabled_pass_names.empty() || enabled_pass_names.empty());
219 
220   std::vector<HloPassInterface*> enabled_passes;
221   if (!enabled_pass_names.empty()) {
222     for (auto& pass : passes_) {
223       if (enabled_pass_names.contains(pass->name())) {
224         enabled_passes.push_back(pass.get());
225       }
226     }
227   } else {
228     for (auto& pass : passes_) {
229       if (!disabled_pass_names.contains(pass->name())) {
230         enabled_passes.push_back(pass.get());
231       }
232     }
233   }
234   return enabled_passes;
235 }
236 
MaybeDumpHloAndSaveFilenames(HloModule & module,absl::string_view after_pass_name,absl::string_view before_pass_name)237 void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
238     HloModule& module, absl::string_view after_pass_name,
239     absl::string_view before_pass_name) {
240   for (const std::string& filename : DumpHloModuleBetweenPassesIfEnabled(
241            name(), before_pass_name, after_pass_name, module)) {
242     Status status = module.metadata()->add_current_pass_dump_filename(filename);
243     if (!status.ok()) {
244       LOG(FATAL) << status;
245     }
246   }
247 }
248 
MaybeDumpHloAndSaveFilenames(HloModuleGroup & module_group,absl::string_view after_pass_name,absl::string_view before_pass_name)249 void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
250     HloModuleGroup& module_group, absl::string_view after_pass_name,
251     absl::string_view before_pass_name) {
252   for (HloModule* module : module_group.modules()) {
253     MaybeDumpHloAndSaveFilenames(*module, after_pass_name, before_pass_name);
254   }
255 }
256 
Run(HloModule * module)257 StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
258   run_called_ = true;
259 
260   VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
261           << name();
262 
263   return RunPassesInternal(module,
264                            GetEnabledPasses(module->config().debug_options()));
265 }
266 
RunOnModuleGroup(HloModuleGroup * module_group)267 StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
268   run_called_ = true;
269 
270   VLOG(1) << "Running HLO pass pipeline on module group "
271           << module_group->name() << ": " << name();
272 
273   if (module_group->modules().empty()) {
274     VLOG(1) << "Module group is empty. Nothing to do.";
275     return false;
276   }
277 
278   return RunPassesInternal(
279       module_group,
280       GetEnabledPasses(module_group->module(0).config().debug_options()));
281 }
282 
283 }  // namespace xla
284