1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
17 
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
ExecutesBefore(const HloInstruction * a,const HloInstruction * b) const35 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
36                                  const HloInstruction* b) const {
37   switch (GetExecutionConstraint(a, b)) {
38     case ExecutionConstraint::kIsSame:  // a and b are the same instruction;
39       return false;
40     case ExecutionConstraint::kRunBefore:
41     case ExecutionConstraint::kRunExclusiveBefore:
42       return true;
43     case ExecutionConstraint::kRunExclusiveAfter:
44     case ExecutionConstraint::kRunAfter:
45     case ExecutionConstraint::kUnordered:
46       return false;
47   }
48 }
49 
GetExecutionConstraint(const HloInstruction * a,const HloInstruction * b) const50 HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint(
51     const HloInstruction* a, const HloInstruction* b) const {
52   // 'a' and 'b' may be in different computations. In this case, find the
53   // callgraph ancestor instructions which call (potentially transitively) the
54   // computations containing 'a' and 'b' and use these ancestor instructions to
55   // compare order.
56   if (a == b) {
57     return ExecutionConstraint::kIsSame;
58   }
59   const HloInstruction* a_ancestor;
60   const HloInstruction* b_ancestor;
61   std::tie(a_ancestor, b_ancestor) =
62       call_graph_->NearestAncestorsInSameComputation(
63           const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
64 
65   if (a_ancestor == nullptr) {
66     VLOG(4) << "Ancestors in a common computation could not be found between"
67             << a->ToString() << "\n and \n"
68             << b->ToString() << "\n so consider them to be unordered.\n";
69     return ExecutionConstraint::kUnordered;
70   }
71   // a_ancestor and b_ancestor must be either both null or both non-null.
72   CHECK_NE(b_ancestor, nullptr);
73   CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
74 
75   // If the common ancestor is a while instruction there is an additional
76   // ordering criteria which may apply. The condition computation is considered
77   // to execute before the body computation so if 'a' is in the condition and
78   // 'b' is in the body, then 'a' executes before 'b'.
79   if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
80     const HloComputation* body = a_ancestor->while_body();
81     const HloComputation* condition = a_ancestor->while_condition();
82     if (call_graph_->InstructionIsNestedIn(a, condition) &&
83         call_graph_->InstructionIsNestedIn(b, body)) {
84       return ExecutionConstraint::kRunBefore;
85     }
86   }
87 
88   // If the common ancestor is a conditional instruction, even though the branch
89   // computations are not really ordered per-se, we define the 0th branch
90   // computation to be ordered before the 1st one, before the 2nd and so forth.
91   // This ensures that buffers can still be shared among branch computations
92   // as they will forcibly have disjoint liveness.
93   if (a_ancestor == b_ancestor &&
94       (a_ancestor->opcode() == HloOpcode::kConditional)) {
95     int a_branch = -1;
96     int b_branch = -1;
97     for (int j = 0; j < a_ancestor->branch_count(); ++j) {
98       if (call_graph_->InstructionIsNestedIn(
99               a, a_ancestor->branch_computation(j))) {
100         a_branch = j;
101       }
102       if (call_graph_->InstructionIsNestedIn(
103               b, a_ancestor->branch_computation(j))) {
104         b_branch = j;
105       }
106     }
107     // If neither a nor b is inside the branches they both are the ancestor.
108     if (a_branch == -1 && b_branch == -1) {
109       CHECK_EQ(a, a_ancestor);
110       CHECK_EQ(b, b_ancestor);
111       CHECK_EQ(a, b);
112       return ExecutionConstraint::kIsSame;
113     }
114     // If 'b' is the conditional ancestor, and 'a' is within a branch
115     // computation, 'a' executes before 'b'.
116     if (b_branch == -1) {
117       CHECK_EQ(b, a_ancestor);
118       return ExecutionConstraint::kRunBefore;
119     }
120     if (a_branch == -1) {
121       CHECK_EQ(a, a_ancestor);
122       return ExecutionConstraint::kRunAfter;
123     }
124     if (a_branch < b_branch) {
125       return ExecutionConstraint::kRunExclusiveBefore;
126     }
127     if (b_branch < a_branch) {
128       return ExecutionConstraint::kRunExclusiveAfter;
129     }
130   }
131 
132   if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) {
133     return ExecutionConstraint::kRunBefore;
134   }
135   if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) {
136     return ExecutionConstraint::kRunAfter;
137   }
138   VLOG(1) << "Cannot determine order between:" << a->ToString() << "\n"
139           << "and " << b->ToString() << " which are in the same computation\n";
140   return ExecutionConstraint::kUnordered;
141 }
142 
IsDefinedBefore(const HloValue & a,const HloValue & b) const143 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
144   // Entry parameter should always be defined before other instructions.
145   const HloModule* module = b.defining_instruction()->parent()->parent();
146   if (b.defining_instruction()->parent() == module->entry_computation() &&
147       b.defining_instruction()->opcode() == HloOpcode::kParameter) {
148     return false;
149   }
150 
151   if (a.defining_instruction()->parent() == module->entry_computation() &&
152       a.defining_instruction()->opcode() == HloOpcode::kParameter) {
153     return true;
154   }
155 
156   // Phi values require special handling. Because XLA does not have a phi
157   // instruction, the definition instruction of the phis values are
158   // placeholders: either the subcomputation parameter (body or condition) or
159   // the while instruction. However, the program point where these values are
160   // logically defined does not necessarily coincide exactly with program point
161   // of these place-holder instructions. So we explicitly define the following
162   // order for phi values:
163   //
164   //   body/condition parameter phi:
165   //     Defined before all values defined in its computation excepting other
166   //     phis.
167   //
168   //   while phi:
169   //     defined after all values defined in the condition or body.
170   //
171   auto is_body_or_condition_phi = [](const HloValue& v) {
172     return v.is_phi() &&
173            v.defining_instruction()->opcode() == HloOpcode::kParameter;
174   };
175   if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
176       call_graph_->InstructionIsNestedIn(b.defining_instruction(),
177                                          a.defining_instruction()->parent())) {
178     return true;
179   }
180   if (is_body_or_condition_phi(b) &&
181       call_graph_->InstructionIsNestedIn(a.defining_instruction(),
182                                          b.defining_instruction()->parent())) {
183     return false;
184   }
185 
186   // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
187   // executes before 'b'.
188   if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
189       (call_graph_->InstructionIsNestedIn(
190            a.defining_instruction(), b.defining_instruction()->while_body()) ||
191        call_graph_->InstructionIsNestedIn(
192            a.defining_instruction(),
193            b.defining_instruction()->while_condition()))) {
194     return true;
195   }
196   // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
197   // executes before 'b'.
198   if (b.is_phi() &&
199       b.defining_instruction()->opcode() == HloOpcode::kConditional) {
200     for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
201       if (call_graph_->InstructionIsNestedIn(
202               a.defining_instruction(),
203               b.defining_instruction()->branch_computation(j))) {
204         return true;
205       }
206     }
207   }
208   return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
209 }
210 
211 /* static */
UsesBeforeValueDefinition(absl::Span<const HloUse * const> uses,const HloValue & value,const HloDataflowAnalysis & dataflow) const212 bool HloOrdering::UsesBeforeValueDefinition(
213     absl::Span<const HloUse* const> uses, const HloValue& value,
214     const HloDataflowAnalysis& dataflow) const {
215   bool has_use_in_exclusive_branches = false;
216   bool has_escaped_use_in_conditional = false;
217   auto UseIsBeforeValueDefinition = [&](const HloUse& use) {
218     VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
219             << ", value=" << value.ToShortString() << ")";
220     switch (
221         GetExecutionConstraint(use.instruction, value.defining_instruction())) {
222       case HloOrdering::ExecutionConstraint::kIsSame:
223         // If the use is at the instruction where the value is defined, then the
224         // use is before the def if the instruction allows buffer sharing (in
225         // place computation).
226         if (dataflow.CanShareOperandBufferWithUser(
227                 use.instruction->mutable_operand(use.operand_number),
228                 use.operand_index, value.defining_instruction(),
229                 value.defining_index())) {
230           VLOG(4)
231               << "  use is value def, and instruction can share use buffer.";
232           return true;
233         }
234         break;
235       case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
236         // If the use is located in a branch that is exclusive to the branch
237         // where value is located, in order for them to interfere, there must be
238         // an execution path where the value's definition can reach the use, so
239         // that the wrong value would reach use if their live ranges are merged.
240         // If there is such a path, it would have to pass through the point
241         // where the two exclusive branches are joined --- specifically the end
242         // of the conditional operation. For the join point to reach back to the
243         // use at the other exclusive branch, there has to be a be a surrounding
244         // loop, where the result of the conditional is passed back inside the
245         // conditional through one of its parameters. This use-def conflict
246         // between the parameter of a conditional and one of its branches is
247         // caught in the has_escaped_use_in_conditinoal variable.
248         VLOG(4) << " use and value def are in exclusive branches.";
249         if (!has_escaped_use_in_conditional) {
250           has_use_in_exclusive_branches = true;
251           VLOG(4) << "Allowing them to share buffer.\n";
252           return true;
253         }
254         VLOG(4) << "value def has escaped use in conditional. \n";
255         break;
256       case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
257       case HloOrdering::ExecutionConstraint::kRunBefore:
258         VLOG(4)
259             << "  use instruction executes before value-defining instruction";
260         return true;
261       case HloOrdering::ExecutionConstraint::kRunAfter:
262       case HloOrdering::ExecutionConstraint::kUnordered:
263         break;
264     }
265 
266     // The use at a while is an input to a phi, and logically occurs before
267     // values are defined in the body. Note that the use is *not* before the
268     // value if the value is defined in the condition and is not the condition
269     // parameter, since the input of a while's live range is only ended at the
270     // start the body.
271     if (use.instruction->opcode() == HloOpcode::kWhile) {
272       const HloInstruction* xla_while = use.instruction;
273       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
274                                              xla_while->while_body())) {
275         VLOG(4) << "  use is while " << use.instruction->name()
276                 << " and def is in body";
277         return true;
278       }
279       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
280                                              xla_while->while_condition())) {
281         if (value.defining_instruction() !=
282             xla_while->while_condition()->parameter_instruction(0)) {
283           VLOG(4) << "  use is while " << use.instruction->name()
284                   << " and def is in condition and is not the parameter";
285           return false;
286         } else {
287           VLOG(4) << "  use is while " << use.instruction->name()
288                   << " and def is in condition and is the parameter";
289           return true;
290         }
291       }
292     }
293     // Similarly if the value is defined at a while, it logically occurs after
294     // any uses in the body or condition computations.
295     if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
296       CHECK(value.is_phi());
297       const HloInstruction* xla_while = value.defining_instruction();
298       if (call_graph_->InstructionIsNestedIn(use.instruction,
299                                              xla_while->while_body()) ||
300           call_graph_->InstructionIsNestedIn(use.instruction,
301                                              xla_while->while_condition())) {
302         VLOG(4) << "  value is while " << value.defining_instruction()->name()
303                 << " and use is in condition or body";
304         return true;
305       }
306     }
307     // The use at a call occurs before values that are defined in the called
308     // computation.
309     if (use.instruction->opcode() == HloOpcode::kCall) {
310       const HloInstruction* call = use.instruction;
311       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
312                                              call->to_apply())) {
313         VLOG(4) << "  use is call " << use.instruction->name()
314                 << " and def is in called computation";
315         return true;
316       }
317     }
318     if (use.instruction->opcode() == HloOpcode::kConditional) {
319       const HloInstruction* conditional = use.instruction;
320       // In general the use of a value in the conditional parameter should be
321       // considered to be before a definition in one of its branches, and
322       // therefore allowed in live range merging, if there is no
323       // surrounding loop that creates a backward control flow path that
324       // allows the definition in the branch to have its value flow backward
325       // into the conditional and then flow into another branch in the
326       // conditional that uses the value. This is reflected by checking that
327       // the use-def in exclusive branches has not been already allowed.
328       // Further, if the def value escapes its branch, we conservatively
329       // assume a backward control flow path could exist, and set
330       // has_escaped_use_in_conditinoal to disallow any later uses in
331       // exclusive branches.
332       for (int j = 0; j < conditional->branch_count(); ++j) {
333         if (call_graph_->InstructionIsNestedIn(
334                 value.defining_instruction(),
335                 conditional->branch_computation(j))) {
336           // If the use operand does not create a new value, and the value def
337           // is returned by as part of the result of the conditional, it
338           // is possible for the branch definition to flow backward through a
339           // surrounding loop and then back into the conditional parameter.
340           if (!dataflow.ValueIsDefinedAt(
341                   use.instruction->operand(use.operand_number), {})) {
342             for (auto value_use : value.uses()) {
343               VLOG(4) << "def have use:" << value_use << "\n";
344               if (value_use.instruction ==
345                   value_use.instruction->parent()->root_instruction()) {
346                 VLOG(4) << "def use is conditional root \n";
347                 has_escaped_use_in_conditional = true;
348                 break;
349               }
350             }
351           }
352           if (!has_use_in_exclusive_branches) {
353             VLOG(4) << "  use is conditional " << use.instruction->name()
354                     << " and def is in " << j << "th branch computation";
355             return true;
356           }
357         }
358       }
359       if (value.defining_instruction() == use.instruction) {
360         VLOG(4) << "  use is conditional " << use << " and def is "
361                 << value.ToShortString();
362         return true;
363       }
364     }
365 
366     VLOG(4) << "  use is not before value definition";
367     return false;
368   };
369   for (auto* use : uses) {
370     if (!UseIsBeforeValueDefinition(*use)) {
371       return false;
372     }
373   }
374   return true;
375 }
376 
LiveRangeStrictlyBefore(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const377 bool HloOrdering::LiveRangeStrictlyBefore(
378     const HloValue& a, const HloValue& b,
379     const HloDataflowAnalysis& dataflow) const {
380   VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
381           << ", b = " << b.ToShortString() << ")";
382   VLOG(4) << "Parent:" << a.instruction()->parent()->ToString() << "\n";
383   if (!IsDefinedBefore(a, b)) {
384     VLOG(4) << a << " not defined before " << b;
385     return false;
386   }
387 
388   if (a.live_out_of_module()) {
389     VLOG(4) << a << " is live out of module and not defined before " << b;
390     return false;
391   }
392 
393   // If the root instruction aliases the buffer 'a', the live range of 'a' is
394   // until the end of the computation and can never be strictly before another
395   // buffer nested in the same computation. This is needed to prevent the root
396   // instruction's buffers from being reused by later instructions even when
397   // the root is not the last instruction in the schedule.
398   for (const HloPosition& pos : a.positions()) {
399     if (pos.instruction->parent()->root_instruction() == pos.instruction &&
400         call_graph().InstructionIsNestedIn(b.instruction(),
401                                            pos.instruction->parent())) {
402       return false;
403     }
404   }
405 
406   // All uses of 'a' must be before 'b' is defined.
407   std::vector<const HloUse*> uses;
408   for (const HloUse& use : a.uses()) {
409     if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
410                                          use.instruction)) {
411       continue;
412     }
413     uses.push_back(&use);
414   }
415   if (!UsesBeforeValueDefinition(uses, b, dataflow)) {
416     VLOG(4) << "uses of " << a << "not before " << b << " is defined";
417     return false;
418   }
419 
420   if (a.instruction()->parent() == b.instruction()->parent()) {
421     for (const HloPosition& position : a.positions()) {
422       if (position.instruction ==
423           a.instruction()->parent()->root_instruction()) {
424         VLOG(4) << a << " is live out of computation and defined before " << b
425                 << " which is in same computation";
426         return false;
427       }
428     }
429   }
430 
431   return true;
432 }
433 
MayInterfere(const HloValue & a,const HloValue & b,const HloDataflowAnalysis & dataflow) const434 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
435                                const HloDataflowAnalysis& dataflow) const {
436   // Buffers without disjoint liveness may interfere.
437   return !LiveRangeStrictlyBefore(a, b, dataflow) &&
438          !LiveRangeStrictlyBefore(b, a, dataflow);
439 }
440 
PredecessorHloOrdering(const HloModule * module)441 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
442     : HloOrdering(module) {}
443 
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const444 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
445     const HloInstruction* a, const HloInstruction* b) const {
446   CHECK_EQ(a->parent(), b->parent());
447 
448   // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
449   return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
450 }
451 
// Builds a multi-line description of the ordering: 'name' on the first line,
// then, for every non-fusion computation, each instruction followed by the
// instructions of that computation from which it is reachable.
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
  std::vector<string> lines;
  lines.push_back(name);
  for (auto* computation : module_->MakeNonfusionComputations()) {
    lines.push_back(absl::StrCat("computation ", computation->name(), ":"));
    const auto post_order = computation->MakeInstructionPostOrder();
    for (auto successor : post_order) {
      lines.push_back(
          absl::StrCat("  ", successor->name(), " predecessors:"));
      // One lookup per instruction instead of one per candidate pair.
      const auto& reachability = predecessors_.at(computation);
      for (auto candidate : post_order) {
        if (!reachability->IsReachable(candidate, successor)) {
          continue;
        }
        lines.push_back(absl::StrCat("    ", candidate->name()));
      }
    }
  }
  return absl::StrJoin(lines, "\n");
}
471 
DependencyHloOrdering(const HloModule * module)472 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
473     : PredecessorHloOrdering(module) {
474   // Compute predecessor relationships between all instructions to determine
475   // ordering based on dependencies. ExecutesBefore will return true iff there
476   // exists a path in the HLO computation graph from 'a' to 'b'.
477   for (auto* computation : module->MakeNonfusionComputations()) {
478     predecessors_.emplace(computation, HloReachabilityMap::Build(computation));
479   }
480 }
481 
ToString() const482 string DependencyHloOrdering::ToString() const {
483   return ToStringHelper("DependencyHloOrdering");
484 }
485 
SequentialHloOrdering(const HloSchedule & schedule)486 SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
487     : HloOrdering(schedule.module()), schedule_(schedule) {
488   Initialize();
489 }
490 
SequentialHloOrdering(HloSchedule && schedule)491 SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
492     : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
493   Initialize();
494 }
495 
Initialize()496 void SequentialHloOrdering::Initialize() {
497   // Create a map from instruction to its order position.
498   TF_DCHECK_OK(schedule_.Verify());
499   for (const auto& computation_sequence : schedule_.sequences()) {
500     const auto& order = computation_sequence.second.instructions();
501     for (int i = 0; i < order.size(); ++i) {
502       InsertOrDie(&order_position_, order[i], i);
503     }
504   }
505 }
506 
ExecutesBeforeInSameComputation(const HloInstruction * a,const HloInstruction * b) const507 bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
508     const HloInstruction* a, const HloInstruction* b) const {
509   CHECK_EQ(a->parent(), b->parent());
510   // If either instruction is not in the order, then 'a' and 'b' are unordered.
511   if (!order_position_.contains(a) || !order_position_.contains(b)) {
512     return false;
513   }
514   return order_position_.at(a) < order_position_.at(b);
515 }
516 
SequentialOrder(const HloComputation & computation) const517 const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
518     const HloComputation& computation) const {
519   return schedule_.is_computation_scheduled(&computation)
520              ? &schedule_.sequence(&computation)
521              : nullptr;
522 }
523 
ToString() const524 string SequentialHloOrdering::ToString() const {
525   return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
526 }
527 
528 }  // namespace xla
529