/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"

#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"

namespace xla {

namespace {

// Calculate ordering for HLO, for fast online checking of whether adding
// additional dependencies would create cycles.
27 struct ComputationInstructionOrdering {
ComputationInstructionOrderingxla::__anon7d309c810111::ComputationInstructionOrdering28 explicit ComputationInstructionOrdering(const HloComputation& computation) {
29 for (const HloInstruction* instr : computation.instructions()) {
30 for (const HloInstruction* control_pred : instr->control_predecessors()) {
31 CHECK(this->InsertEdge(*control_pred, *instr))
32 << "Graph already contained a cycle";
33 }
34
35 for (int op_id = 0; op_id < instr->operand_count(); op_id++) {
36 const HloInstruction* op = instr->operand(op_id);
37 CHECK(this->InsertEdge(*op, *instr))
38 << "Graph already contained a cycle";
39 }
40 }
41 }
42
NodeIdForInstructionxla::__anon7d309c810111::ComputationInstructionOrdering43 int32 NodeIdForInstruction(const HloInstruction& instr) {
44 int32 instruction_id = instr.unique_id();
45 auto it = node_id_to_graph_id.find(instruction_id);
46
47 if (it != node_id_to_graph_id.end()) {
48 return it->second;
49 }
50 int32 node_id = graph_cycles.NewNode();
51 node_id_to_graph_id[instruction_id] = node_id;
52 return node_id;
53 }
54
55 // Returns `false` if adding an edge would have introduced a cycle. Does not
56 // add an edge in that case. Returns `true` otherwise.
InsertEdgexla::__anon7d309c810111::ComputationInstructionOrdering57 bool InsertEdge(const HloInstruction& source, const HloInstruction& dest) {
58 int32 source_id = NodeIdForInstruction(source);
59 int32 dest_id = NodeIdForInstruction(dest);
60 return graph_cycles.InsertEdge(source_id, dest_id);
61 }
62
63 absl::flat_hash_map<int32, int32> node_id_to_graph_id;
64
65 tensorflow::GraphCycles graph_cycles;
66 };

}  // namespace

AddControlEdgesForLoopWrites(HloInstruction * xla_while,HloAliasAnalysis & alias_analysis)70 static StatusOr<bool> AddControlEdgesForLoopWrites(
71 HloInstruction* xla_while, HloAliasAnalysis& alias_analysis) {
72 HloDataflowAnalysis& dataflow = alias_analysis.dataflow_analysis();
73 HloComputation* body = xla_while->while_body();
74 HloInstruction* root = body->root_instruction();
75 HloInstruction* input = body->parameter_instruction(0);
76
77 bool changed = false;
78
79 // Compute dependency ordering ourselves. The reason we don't reuse other
80 // computations is because it is hard to extract the underlying graph from
81 // those abstractions.
82 ComputationInstructionOrdering ordering(*body);
83 ShapeTree<bool> indices_to_copy(xla_while->shape());
84
85 for (auto& p : indices_to_copy) {
86 const ShapeIndex& index = p.first;
87
88 if (index.empty()) {
89 continue;
90 }
91
92 if (dataflow.GetValueSet(root, index).values().size() > 1 ||
93 dataflow.GetValueSet(input, index).values().size() > 1) {
94 VLOG(2) << "Index " << index.ToString() << " is associated with multiple "
95 << "values, not attempting to introduce stricter dependencies";
96 } else {
97 HloValue& value_at_root = dataflow.GetUniqueValueAt(root, index);
98 HloValue& value_at_input = dataflow.GetUniqueValueAt(input, index);
99
100 if (value_at_root.shape().IsTuple()) {
101 // TODO(cheshire): For simplicity we currently do not handle nested
102 // tuples, as we haven't seen them in the examples we care about.
103 continue;
104 }
105
106 // TODO(cheshire): This is too conservative and does not take aliasing
107 // into account.
108 HloInstruction* write = value_at_root.defining_instruction();
109
110 for (const HloUse& use : value_at_input.uses()) {
111 HloInstruction* read = use.instruction;
112
113 if (read != write &&
114 value_at_root != value_at_input
115
116 // TODO(cheshire): Parents sometimes differ in case of e.g. nested
117 // loops, where the value is read/written into in the inner loop.
118 // For now we skip this case for simplicity (as the inner loop
119 // performance is more important in any case)
120 && read->parent() == write->parent()) {
121 VLOG(2) << "Inside " << body->name() << ", index "
122 << index.ToString();
123 if (!ordering.InsertEdge(*read, *write)) {
124 VLOG(2) << "Not adding a control dependency from "
125 << read->ToShortString() << " to " << write->ToShortString()
126 << " as it would introduce a cycle";
127 continue;
128 }
129
130 changed |= absl::c_linear_search(read->control_successors(), write);
131
132 // Unless we want a copy, read should happen before write.
133 TF_RETURN_IF_ERROR(read->AddControlDependencyTo(write));
134 VLOG(2) << "Adding dependency: " << read->ToShortString()
135 << " before " << write->ToShortString();
136 }
137 }
138 }
139 }
140 return changed;
141 }
Run(HloModule * module)143 StatusOr<bool> LoopScheduleLinearizer::Run(HloModule* module) {
144 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
145 HloAliasAnalysis::Run(module, can_share_buffer_));
146
147 bool changed = false;
148 for (HloComputation* computation : module->MakeNonfusionComputations()) {
149 for (HloInstruction* instruction :
150 computation->MakeInstructionPostOrder()) {
151 if (instruction->opcode() == HloOpcode::kWhile) {
152 StatusOr<bool> updated_loop =
153 AddControlEdgesForLoopWrites(instruction, *alias_analysis);
154 TF_RETURN_IF_ERROR(updated_loop.status());
155 changed |= *updated_loop;
156 }
157 }
158 }
159 DumpHloModuleDuringPassIfEnabled(
160 name(), "after inserting control edges inside while loop bodies",
161 *module);
162
163 return changed;
164 }

}  // end namespace xla
