1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module_dce.h"
17 
18 #include <deque>
19 #include <unordered_set>
20 
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/status.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace xla {
36 
37 namespace {
38 
RunWhileDCE(HloModule * module,HloLivenessAnalysis * liveness)39 StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
40   bool changed = false;
41   for (auto* computation : module->computations()) {
42     for (auto* instruction : computation->instructions()) {
43       if (instruction->opcode() != HloOpcode::kWhile) {
44         continue;
45       }
46 
47       const auto* xla_while = instruction;
48       auto* while_body_comp = xla_while->while_body();
49       auto* while_body_param = while_body_comp->parameter_instruction(0);
50       auto* while_body_root = while_body_comp->root_instruction();
51 
52       if (!xla_while->shape().IsTuple() ||
53           while_body_root->opcode() != HloOpcode::kTuple) {
54         // Only run DCE on tuple-shaped while loops where body root is Tuple,
55         // with no I/O instructions.
56         VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
57         continue;
58       }
59 
60       // Remove dead tuple elements.
61       const int64 tuple_element_count =
62           ShapeUtil::TupleElementCount(xla_while->shape());
63       for (int64 i = 0; i < tuple_element_count; ++i) {
64         if (liveness->IsLive(xla_while, {i})) {
65           continue;
66         }
67         VLOG(1) << "WhileDCE Dead while tuple element."
68                 << " while: " << xla_while->name() << " tuple_index: " << i;
69         // Transform while.body computation to make tuple element at
70         // 'shape_index' as simple pass-through parameter (which candidate
71         // be removed later by simplification pass).
72         HloInstruction* pass_thru_gte = while_body_comp->AddInstruction(
73             HloInstruction::CreateGetTupleElement(
74                 while_body_param->shape().tuple_shapes(i), while_body_param,
75                 i));
76         // Replace while.body.root Tuple operand at 'tuple_index' with
77         // 'pass_thru_gte', making prior operand a dead root (to be cleaned
78         // up with a subsequent DCE pass).
79         TF_RETURN_IF_ERROR(
80             while_body_root->ReplaceOperandWith(i, pass_thru_gte));
81         changed = true;
82       }
83     }
84   }
85   return changed;
86 }
87 
88 }  // namespace
89 
Run(HloModule * module)90 StatusOr<bool> HloModuleDCE::Run(HloModule* module) {
91   VLOG(2) << "Before HloModuleDCE:";
92   XLA_VLOG_LINES(3, module->ToString());
93 
94   std::unique_ptr<HloLivenessAnalysis> liveness;
95   TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module));
96 
97   // Sweep through while instructions, transforming dead while tuple element
98   // computations to pass through tuple values (creating dead roots in while
99   // body computation in the process).
100   TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed,
101                       RunWhileDCE(module, liveness.get()));
102 
103   // Run HloDCE to clean up any dead code created during HloModuleDCE.
104   HloDCE hlo_dce;
105   TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module));
106 
107   VLOG(2) << "After HloModuleDCE:";
108   XLA_VLOG_LINES(3, module->ToString());
109 
110   return hlo_module_dce_changed | hlo_dce_changed;
111 }
112 
113 }  // namespace xla
114