/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
17 
18 #include "tensorflow/compiler/xla/service/dump.h"
19 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
20 
21 namespace xla {
22 
23 namespace {
24 
25 // Calculate ordering for HLO, for fast online checking of whether adding
26 // additional dependencies would create cycles.
27 struct ComputationInstructionOrdering {
ComputationInstructionOrderingxla::__anon7d309c810111::ComputationInstructionOrdering28   explicit ComputationInstructionOrdering(const HloComputation& computation) {
29     for (const HloInstruction* instr : computation.instructions()) {
30       for (const HloInstruction* control_pred : instr->control_predecessors()) {
31         CHECK(this->InsertEdge(*control_pred, *instr))
32             << "Graph already contained a cycle";
33       }
34 
35       for (int op_id = 0; op_id < instr->operand_count(); op_id++) {
36         const HloInstruction* op = instr->operand(op_id);
37         CHECK(this->InsertEdge(*op, *instr))
38             << "Graph already contained a cycle";
39       }
40     }
41   }
42 
NodeIdForInstructionxla::__anon7d309c810111::ComputationInstructionOrdering43   int32 NodeIdForInstruction(const HloInstruction& instr) {
44     int32 instruction_id = instr.unique_id();
45     auto it = node_id_to_graph_id.find(instruction_id);
46 
47     if (it != node_id_to_graph_id.end()) {
48       return it->second;
49     }
50     int32 node_id = graph_cycles.NewNode();
51     node_id_to_graph_id[instruction_id] = node_id;
52     return node_id;
53   }
54 
55   // Returns `false` if adding an edge would have introduced a cycle. Does not
56   // add an edge in that case. Returns `true` otherwise.
InsertEdgexla::__anon7d309c810111::ComputationInstructionOrdering57   bool InsertEdge(const HloInstruction& source, const HloInstruction& dest) {
58     int32 source_id = NodeIdForInstruction(source);
59     int32 dest_id = NodeIdForInstruction(dest);
60     return graph_cycles.InsertEdge(source_id, dest_id);
61   }
62 
63   absl::flat_hash_map<int32, int32> node_id_to_graph_id;
64 
65   tensorflow::GraphCycles graph_cycles;
66 };
67 
68 }  // namespace
69 
AddControlEdgesForLoopWrites(HloInstruction * xla_while,HloAliasAnalysis & alias_analysis)70 static StatusOr<bool> AddControlEdgesForLoopWrites(
71     HloInstruction* xla_while, HloAliasAnalysis& alias_analysis) {
72   HloDataflowAnalysis& dataflow = alias_analysis.dataflow_analysis();
73   HloComputation* body = xla_while->while_body();
74   HloInstruction* root = body->root_instruction();
75   HloInstruction* input = body->parameter_instruction(0);
76 
77   bool changed = false;
78 
79   // Compute dependency ordering ourselves. The reason we don't reuse other
80   // computations is because it is hard to extract the underlying graph from
81   // those abstractions.
82   ComputationInstructionOrdering ordering(*body);
83   ShapeTree<bool> indices_to_copy(xla_while->shape());
84 
85   for (auto& p : indices_to_copy) {
86     const ShapeIndex& index = p.first;
87 
88     if (index.empty()) {
89       continue;
90     }
91 
92     if (dataflow.GetValueSet(root, index).values().size() > 1 ||
93         dataflow.GetValueSet(input, index).values().size() > 1) {
94       VLOG(2) << "Index " << index.ToString() << " is associated with multiple "
95               << "values, not attempting to introduce stricter dependencies";
96     } else {
97       HloValue& value_at_root = dataflow.GetUniqueValueAt(root, index);
98       HloValue& value_at_input = dataflow.GetUniqueValueAt(input, index);
99 
100       if (value_at_root.shape().IsTuple()) {
101         // TODO(cheshire): For simplicity we currently do not handle nested
102         // tuples, as we haven't seen them in the examples we care about.
103         continue;
104       }
105 
106       // TODO(cheshire): This is too conservative and does not take aliasing
107       // into account.
108       HloInstruction* write = value_at_root.defining_instruction();
109 
110       for (const HloUse& use : value_at_input.uses()) {
111         HloInstruction* read = use.instruction;
112 
113         if (read != write &&
114             value_at_root != value_at_input
115 
116             // TODO(cheshire): Parents sometimes differ in case of e.g. nested
117             // loops, where the value is read/written into in the inner loop.
118             // For now we skip this case for simplicity (as the inner loop
119             // performance is more important in any case)
120             && read->parent() == write->parent()) {
121           VLOG(2) << "Inside " << body->name() << ", index "
122                   << index.ToString();
123           if (!ordering.InsertEdge(*read, *write)) {
124             VLOG(2) << "Not adding a control dependency from "
125                     << read->ToShortString() << " to " << write->ToShortString()
126                     << " as it would introduce a cycle";
127             continue;
128           }
129 
130           changed |= absl::c_linear_search(read->control_successors(), write);
131 
132           // Unless we want a copy, read should happen before write.
133           TF_RETURN_IF_ERROR(read->AddControlDependencyTo(write));
134           VLOG(2) << "Adding dependency: " << read->ToShortString()
135                   << " before " << write->ToShortString();
136         }
137       }
138     }
139   }
140   return changed;
141 }
142 
Run(HloModule * module)143 StatusOr<bool> LoopScheduleLinearizer::Run(HloModule* module) {
144   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
145                       HloAliasAnalysis::Run(module, can_share_buffer_));
146 
147   bool changed = false;
148   for (HloComputation* computation : module->MakeNonfusionComputations()) {
149     for (HloInstruction* instruction :
150          computation->MakeInstructionPostOrder()) {
151       if (instruction->opcode() == HloOpcode::kWhile) {
152         StatusOr<bool> updated_loop =
153             AddControlEdgesForLoopWrites(instruction, *alias_analysis);
154         TF_RETURN_IF_ERROR(updated_loop.status());
155         changed |= *updated_loop;
156       }
157     }
158   }
159   DumpHloModuleDuringPassIfEnabled(
160       name(), "after inserting control edges inside while loop bodies",
161       *module);
162 
163   return changed;
164 }
165 
166 }  // end namespace xla
167