/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
17 
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
ExecutesBefore(const HloInstruction * a,const HloInstruction * b) const35 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
36                                  const HloInstruction* b) const {
37   // 'a' and 'b' may be in different computations. In this case, find the
38   // callgraph ancestor instructions which call (potentially transitively) the
39   // computations containing 'a' and 'b' and use these ancestor instructions to
40   // compare order.
41   const HloInstruction* a_ancestor;
42   const HloInstruction* b_ancestor;
43   std::tie(a_ancestor, b_ancestor) =
44       call_graph_->NearestAncestorsInSameComputation(
45           const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
46 
47   if (a_ancestor == nullptr) {
48     // Ancestors in a common computation could not be found so consider the
49     // instructions 'a' and 'b' to be unordered.
50     return false;
51   }
52   // a_ancestor and b_ancestor must be either both null or both non-null.
53   CHECK_NE(b_ancestor, nullptr);
54   CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
55 
56   // If the common ancestor is a while instruction there is an additional
57   // ordering criteria which may apply. The condition computation is considered
58   // to execute before the body computation so if 'a' is in the condition and
59   // 'b' is in the body, then 'a' executes before 'b'.
60   if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
61     const HloComputation* body = a_ancestor->while_body();
62     const HloComputation* condition = a_ancestor->while_condition();
63     if (call_graph_->InstructionIsNestedIn(a, condition) &&
64         call_graph_->InstructionIsNestedIn(b, body)) {
65       return true;
66     }
67   }
68 
69   // If the common ancestor is a conditional instruction, even though the branch
70   // computations are not really ordered per-se, we define the 0th branch
71   // computation to be ordered before the 1st one, before the 2nd and so forth.
72   // This ensures that buffers can still be shared among branch computations
73   // as they will forcibly have disjoint liveness.
74   if (a_ancestor == b_ancestor &&
75       (a_ancestor->opcode() == HloOpcode::kConditional)) {
76     int a_branch = -1;
77     int b_branch = -1;
78     for (int j = 0; j < a_ancestor->branch_count(); ++j) {
79       if (call_graph_->InstructionIsNestedIn(
80               a, a_ancestor->branch_computation(j))) {
81         a_branch = j;
82       }
83       if (call_graph_->InstructionIsNestedIn(
84               b, a_ancestor->branch_computation(j))) {
85         b_branch = j;
86       }
87     }
88     if (a_branch != -1 && a_branch < b_branch) {
89       return true;
90     }
91     // If 'b' is the conditional ancestor, and 'a' is within a branch
92     // computation, 'a' executes before 'b'.
93     if (b == a_ancestor && a_branch != -1) {
94       return true;
95     }
96   }
97 
98   return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
99 }
100 
IsDefinedBefore(const HloValue & a,const HloValue & b) const101 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
102   // Entry parameter should always be defined before other instructions.
103   const HloModule* module = b.defining_instruction()->parent()->parent();
104   if (b.defining_instruction()->parent() == module->entry_computation() &&
105       b.defining_instruction()->opcode() == HloOpcode::kParameter) {
106     return false;
107   }
108 
109   if (a.defining_instruction()->parent() == module->entry_computation() &&
110       a.defining_instruction()->opcode() == HloOpcode::kParameter) {
111     return true;
112   }
113 
114   // Phi values require special handling. Because XLA does not have a phi
115   // instruction, the definition instruction of the phis values are
116   // placeholders: either the subcomputation parameter (body or condition) or
117   // the while instruction. However, the program point where these values are
118   // logically defined does not necessarily coincide exactly with program point
119   // of these place-holder instructions. So we explicitly define the following
120   // order for phi values:
121   //
122   //   body/condition parameter phi:
123   //     Defined before all values defined in its computation excepting other
124   //     phis.
125   //
126   //   while phi:
127   //     defined after all values defined in the condition or body.
128   //
129   auto is_body_or_condition_phi = [](const HloValue& v) {
130     return v.is_phi() &&
131            v.defining_instruction()->opcode() == HloOpcode::kParameter;
132   };
133   if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
134       call_graph_->InstructionIsNestedIn(b.defining_instruction(),
135                                          a.defining_instruction()->parent())) {
136     return true;
137   }
138   if (is_body_or_condition_phi(b) &&
139       call_graph_->InstructionIsNestedIn(a.defining_instruction(),
140                                          b.defining_instruction()->parent())) {
141     return false;
142   }
143 
144   // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
145   // executes before 'b'.
146   if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
147       (call_graph_->InstructionIsNestedIn(
148            a.defining_instruction(), b.defining_instruction()->while_body()) ||
149        call_graph_->InstructionIsNestedIn(
150            a.defining_instruction(),
151            b.defining_instruction()->while_condition()))) {
152     return true;
153   }
154   // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
155   // executes before 'b'.
156   if (b.is_phi() &&
157       b.defining_instruction()->opcode() == HloOpcode::kConditional) {
158     for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
159       if (call_graph_->InstructionIsNestedIn(
160               a.defining_instruction(),
161               b.defining_instruction()->branch_computation(j))) {
162         return true;
163       }
164     }
165   }
166   return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
167 }
168 
169 /* static */
UseIsBeforeValueDefinition(const HloUse & use,const HloValue & value,const HloDataflowAnalysis & dataflow) const170 bool HloOrdering::UseIsBeforeValueDefinition(
171     const HloUse& use, const HloValue& value,
172     const HloDataflowAnalysis& dataflow) const {
173   VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
174           << ", value=" << value.ToShortString() << ")";
175   if (ExecutesBefore(use.instruction, value.defining_instruction())) {
176     VLOG(4) << "  use instruction executes before value-defining instruction";
177     return true;
178   }
179 
180   // If the use is at the instruction where the value is defined, then the use
181   // is before the def if the instruction allows buffer sharing (in place
182   // computation).
183   if (use.instruction == value.defining_instruction() &&
184       dataflow.CanShareOperandBufferWithUser(
185           use.instruction->mutable_operand(use.operand_number),
186           use.operand_index, value.defining_instruction(),
187           value.defining_index())) {
188     VLOG(4) << "  use is value def, and instruction can share use buffer";
189     return true;
190   }
191 
192   // The use at a while is an input to a phi, and logically occurs before values
193   // are defined in the body or condition computations.
194   if (use.instruction->opcode() == HloOpcode::kWhile) {
195     const HloInstruction* xla_while = use.instruction;
196     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
197                                            xla_while->while_body()) ||
198         call_graph_->InstructionIsNestedIn(value.defining_instruction(),
199                                            xla_while->while_condition())) {
200       VLOG(4) << "  use is while " << use.instruction->name()
201               << " and def is in condition or body";
202       return true;
203     }
204   }
205 
206   // Similarly if the value is defined at a while, it logically occurs after any
207   // uses in the body or condition computations.
208   if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
209     CHECK(value.is_phi());
210     const HloInstruction* xla_while = value.defining_instruction();
211     if (call_graph_->InstructionIsNestedIn(use.instruction,
212                                            xla_while->while_body()) ||
213         call_graph_->InstructionIsNestedIn(use.instruction,
214                                            xla_while->while_condition())) {
215       VLOG(4) << "  value is while " << value.defining_instruction()->name()
216               << " and use is in condition or body";
217       return true;
218     }
219   }
220 
221   // The use at a call occurs before values that are defined in the called
222   // computation.
223   if (use.instruction->opcode() == HloOpcode::kCall) {
224     const HloInstruction* call = use.instruction;
225     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
226                                            call->to_apply())) {
227       VLOG(4) << "  use is call " << use.instruction->name()
228               << " and def is in called computation";
229       return true;
230     }
231   }
232 
233   if (use.instruction->opcode() == HloOpcode::kConditional) {
234     const HloInstruction* conditional = use.instruction;
235     for (int j = 0; j < conditional->branch_count(); ++j) {
236       if (call_graph_->InstructionIsNestedIn(
237               value.defining_instruction(),
238               conditional->branch_computation(j))) {
239         VLOG(4) << "  use is conditional " << use.instruction->name()
240                 << " and def is in " << j << "th branch computation";
241         return true;
242       }
243     }
244     if (value.defining_instruction() == use.instruction) {
245       VLOG(4) << "  use is conditional " << use << " and def is "
246               << value.ToShortString();
247       return true;
248     }
249   }
250 
251   VLOG(4) << "  use is not before value";
252   return false;
253 }
254 
LiveRangeStrictlyBefore(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const255 bool HloOrdering::LiveRangeStrictlyBefore(
256     const HloValue& a, const HloValue& b,
257     const HloDataflowAnalysis& dataflow) const {
258   VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
259           << ", b = " << b.ToShortString() << ")";
260   if (!IsDefinedBefore(a, b)) {
261     VLOG(4) << a << " not defined before " << b;
262     return false;
263   }
264 
265   if (a.live_out_of_module()) {
266     VLOG(4) << a << " is live out of module and defined before " << b;
267     return false;
268   }
269 
270   // All uses of 'a' must be before 'b' is defined.
271   for (const HloUse& use : a.uses()) {
272     if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
273                                          use.instruction)) {
274       continue;
275     }
276     if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
277       VLOG(4) << "use of " << a << " (" << use << ") not before " << b
278               << " is defined";
279       return false;
280     }
281   }
282 
283   if (a.instruction()->parent() == b.instruction()->parent()) {
284     for (const HloPosition& position : a.positions()) {
285       if (position.instruction ==
286           a.instruction()->parent()->root_instruction()) {
287         VLOG(4) << a << " is live out of computation and defined before " << b
288                 << " which is in same computation";
289         return false;
290       }
291     }
292   }
293 
294   return true;
295 }
296 
MayInterfere(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const297 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
298                                const HloDataflowAnalysis& dataflow) const {
299   // Buffers without disjoint liveness may interfere.
300   return !LiveRangeStrictlyBefore(a, b, dataflow) &&
301          !LiveRangeStrictlyBefore(b, a, dataflow);
302 }
303 
PredecessorHloOrdering(const HloModule * module)304 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
305     : HloOrdering(module) {}
306 
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const307 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
308     const HloInstruction* a, const HloInstruction* b) const {
309   CHECK_EQ(a->parent(), b->parent());
310 
311   // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
312   return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
313 }
314 
// Builds a multi-line description of the ordering. The first line is 'name';
// then, for every non-fusion computation of the module, each instruction (in
// post order) is listed followed by the instructions of that computation from
// which the reachability map reports it to be reachable.
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
  std::vector<string> lines;
  lines.push_back(name);
  for (auto* computation : module_->MakeNonfusionComputations()) {
    lines.push_back(absl::StrFormat("computation %s:", computation->name()));
    const auto post_order = computation->MakeInstructionPostOrder();
    for (auto instruction : post_order) {
      lines.push_back(
          absl::StrFormat("  %s predecessors:", instruction->name()));
      // Fetch the computation's reachability map once per instruction instead
      // of once per candidate.
      const auto& reachability = predecessors_.at(computation);
      for (auto candidate : post_order) {
        if (!reachability->IsReachable(candidate, instruction)) {
          continue;
        }
        lines.push_back(absl::StrFormat("    %s", candidate->name()));
      }
    }
  }
  return absl::StrJoin(lines, "\n");
}
334 
DependencyHloOrdering(const HloModule * module)335 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
336     : PredecessorHloOrdering(module) {
337   // Compute predecessor relationships between all instructions to determine
338   // ordering based on dependencies. ExecutesBefore will return true iff there
339   // exists a path in the HLO computation graph from 'a' to 'b'.
340   for (auto* computation : module->MakeNonfusionComputations()) {
341     predecessors_.emplace(computation, HloReachabilityMap::Build(computation));
342   }
343 }
344 
ToString() const345 string DependencyHloOrdering::ToString() const {
346   return ToStringHelper("DependencyHloOrdering");
347 }
348 
SequentialHloOrdering(const HloSchedule & schedule)349 SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
350     : HloOrdering(schedule.module()), schedule_(schedule) {
351   Initialize();
352 }
353 
SequentialHloOrdering(HloSchedule && schedule)354 SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
355     : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
356   Initialize();
357 }
358 
Initialize()359 void SequentialHloOrdering::Initialize() {
360   // Create a map from instruction to its order position.
361   TF_DCHECK_OK(schedule_.Verify());
362   for (const auto& computation_sequence : schedule_.sequences()) {
363     const auto& order = computation_sequence.second.instructions();
364     for (int i = 0; i < order.size(); ++i) {
365       InsertOrDie(&order_position_, order[i], i);
366     }
367   }
368 }
369 
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const370 bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
371     const HloInstruction* a, const HloInstruction* b) const {
372   CHECK_EQ(a->parent(), b->parent());
373   // If either instruction is not in the order, then 'a' and 'b' are unordered.
374   if (!order_position_.contains(a) || !order_position_.contains(b)) {
375     return false;
376   }
377   return order_position_.at(a) < order_position_.at(b);
378 }
379 
SequentialOrder(const HloComputation & computation) const380 const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
381     const HloComputation& computation) const {
382   return schedule_.is_computation_scheduled(&computation)
383              ? &schedule_.sequence(&computation)
384              : nullptr;
385 }
386 
ToString() const387 string SequentialHloOrdering::ToString() const {
388   return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
389 }
390 
391 }  // namespace xla
392