1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_value.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace xla {
38 
39 using absl::StrAppend;
40 
41 // Data structure used to construct the alias analysis. Thrown away after alias
42 // analysis is complete. This data structure keeps track of which sets of
43 // HloValues must be in the same HloBuffer. This is maintained as a map from a
44 // buffer identifier (BufferNumber) to set of HLoValues.
45 //
46 // Initially each value is its own buffer. In MergeAliasedBuffers, sets of
47 // values which must share the same buffer are merged together. The end result
48 // is a partitioning of all HloValues into sets where each set needs its own
49 // HloBuffer. By performing this analysis without constructing HloBuffers on the
50 // fly, we can after-the-fact construct a vector of contiguously numbered
51 // HloBuffers after the buffer requirement has been determined.
52 class BufferValueMap {
53  public:
54   // A unique identifier for a set of colocated values which must share the same
55   // buffer. This is not necessarily the same as the HloBuffer::Id which will
56   // ultimately contain the values. The reason is that HloBuffer::Id's are
57   // contiguous, while BufferNumbers may not be. BufferNumbers may not be
58   // dense because buffers may be created and destroyed during the analysis
59   // construction process.
60   using BufferNumber = int64;
61 
BufferValueMap(HloModule * module,const HloDataflowAnalysis & dataflow)62   explicit BufferValueMap(HloModule* module,
63                           const HloDataflowAnalysis& dataflow)
64       : module_(module), dataflow_(dataflow) {
65     buffers_.reserve(dataflow_.values().size());
66     value_to_buffer_number_.reserve(dataflow_.values().size());
67     for (const HloValue* value : dataflow_.values()) {
68       BufferNumber buffer_number = next_buffer_number_++;
69       buffers_[buffer_number].insert(value);
70       value_to_buffer_number_[value] = buffer_number;
71     }
72   }
73 
74   // Merge together sets of HloValues which must be in the same HloBuffer
75   // because of aliasing rules (eg, in-place kWhile instruction).
MergeAliasedBuffers()76   void MergeAliasedBuffers() {
77     for (const HloValue* value : dataflow_.values()) {
78       VLOG(3) << "Merging colocated values, value: " << value->ToShortString();
79 
80       // Gather the set of buffers with aliasing rules (eg, kWhile) which this
81       // value must be contained in.
82       std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);
83 
84       BufferNumber current_buffer = value_to_buffer_number_.at(value);
85       if (aliased_buffers.empty()) {
86         // The buffer containing 'value' aliases no other buffers. If the buffer
87         // containing 'value' already only contains 'value', then no change is
88         // necessary. If the buffer containing 'value' does contain other
89         // values, then remove 'value' from the buffer and create a new buffer
90         // containing only 'value'
91         if (buffers_.at(current_buffer).size() == 1) {
92           CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
93         } else {
94           MoveValueToNewBuffer(*value);
95         }
96       } else {
97         // If multiple buffers are aliased merge these buffers together into a
98         // single buffer (arbitrarily chosen as the first buffer in the vector).
99         if (aliased_buffers.size() > 1) {
100           for (int64 i = 1; i < aliased_buffers.size(); ++i) {
101             MergeBuffers(/*from=*/aliased_buffers[i],
102                          /*to=*/aliased_buffers[0]);
103           }
104         }
105         BufferNumber new_buffer = aliased_buffers[0];
106         if (current_buffer != new_buffer) {
107           MoveValueToBuffer(*value, new_buffer);
108         }
109       }
110     }
111   }
112 
113   // Compute and return a sorted vector of all BufferNumbers. Can be used to
114   // iterate through all buffers stabily.
ComputeSortedBufferNumbers() const115   std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
116     std::vector<BufferNumber> buffer_numbers;
117     for (const auto& pair : buffers_) {
118       buffer_numbers.push_back(pair.first);
119     }
120     absl::c_sort(buffer_numbers);
121     return buffer_numbers;
122   }
123 
124   // Return a set of all the values in the given buffer.
GetValuesInBuffer(BufferNumber buffer_number) const125   const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer(
126       BufferNumber buffer_number) const {
127     return buffers_.at(buffer_number);
128   }
129 
130  private:
131   // Create a new buffer.
NewBuffer(const HloValue & value)132   void NewBuffer(const HloValue& value) {
133     BufferNumber buffer_number = next_buffer_number_++;
134     buffers_[buffer_number].insert(&value);
135     value_to_buffer_number_[&value] = buffer_number;
136   }
137 
138   // Move the given value into a new buffer containing only the value.
MoveValueToNewBuffer(const HloValue & value)139   void MoveValueToNewBuffer(const HloValue& value) {
140     BufferNumber new_buffer_number = next_buffer_number_++;
141     buffers_[new_buffer_number];
142     MoveValueToBuffer(value, new_buffer_number);
143   }
144 
145   // Move the given value into the given buffer.
MoveValueToBuffer(const HloValue & value,BufferNumber buffer_number)146   void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
147     BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
148     absl::flat_hash_set<const HloValue*>& old_value_set =
149         buffers_.at(old_buffer_number);
150     old_value_set.erase(&value);
151     if (old_value_set.empty()) {
152       buffers_.erase(old_buffer_number);
153     }
154 
155     buffers_.at(buffer_number).insert(&value);
156     value_to_buffer_number_.at(&value) = buffer_number;
157   }
158 
159   // Merge the buffer 'from' into the buffer 'to'.
MergeBuffers(BufferNumber from,BufferNumber to)160   void MergeBuffers(BufferNumber from, BufferNumber to) {
161     auto& from_value_set = buffers_.at(from);
162     buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
163     // NOTE: using a union-find algorithm to hold the colocated values might be
164     // faster.
165     for (const HloValue* value : from_value_set) {
166       value_to_buffer_number_.at(value) = to;
167     }
168     buffers_.erase(from);
169   }
170 
GetBufferForValue(const HloValue & value)171   BufferNumber GetBufferForValue(const HloValue& value) {
172     return value_to_buffer_number_.at(&value);
173   }
174 
ComputeInputOutputAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)175   void ComputeInputOutputAliasedBuffers(
176       const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
177     // Get parameter value from an aliased_input object.
178     const auto get_parameter_value =
179         [this](const HloInputOutputAliasConfig::Alias& aliased_input)
180         -> const HloValue& {
181       return dataflow_.GetUniqueValueAt(
182           module_->entry_computation()->parameter_instruction(
183               aliased_input.parameter_number),
184           aliased_input.parameter_index);
185     };
186 
187     // If the value shows up in a root instruction, alias it with parameter
188     // intruction.
189     for (const HloPosition& pos : value.positions()) {
190       if (pos.instruction == module_->entry_computation()->root_instruction()) {
191         ShapeIndex output_index = pos.index;
192 
193         auto aliased_input =
194             module_->input_output_alias_config().GetAliasedParameter(
195                 output_index);
196         if (aliased_input) {
197           aliased_buffers->push_back(
198               GetBufferForValue(get_parameter_value(*aliased_input)));
199         }
200       }
201     }
202 
203     // If the value is parameter instruction itself, alias it with itself.
204     if (value.instruction()->opcode() == HloOpcode::kParameter &&
205         value.instruction()->parent() == module_->entry_computation()) {
206       aliased_buffers->push_back(GetBufferForValue(value));
207     }
208   }
209 
ComputeWhileAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)210   void ComputeWhileAliasedBuffers(const HloValue& value,
211                                   std::vector<BufferNumber>* aliased_buffers) {
212     VLOG(3) << "Compute kWhile aliases";
213     // Value is init of a while (use is while).
214     for (const HloUse& use : value.uses()) {
215       if (use.instruction->opcode() == HloOpcode::kWhile) {
216         // Determine the while value that this shares a buffer with.
217         const HloValue& while_value =
218             dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
219         aliased_buffers->push_back(GetBufferForValue(while_value));
220         VLOG(3) << "  value is init value to a while; must share buffer with "
221                    "while value "
222                 << while_value.ToShortString();
223       }
224     }
225     // Value is a parameter of a while body/condition.
226     if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
227       const HloComputation* computation =
228           value.defining_instruction()->parent();
229       const CallGraphNode& call_graph_node =
230           dataflow_.call_graph().GetNode(computation);
231       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
232         if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
233           // Call graph must have been flattened.
234           CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
235 
236           const HloValue& while_value = dataflow_.GetUniqueValueAt(
237               callsite.instruction(), value.defining_index());
238           VLOG(3) << "  value is parameter value of the body or condition of a "
239                      "while; must share buffer with while value "
240                   << while_value.ToShortString();
241           aliased_buffers->push_back(GetBufferForValue(while_value));
242         }
243       }
244     }
245     // Value is the root of a while body.
246     for (const HloPosition& position : value.positions()) {
247       const HloComputation* computation = position.instruction->parent();
248       const CallGraphNode& call_graph_node =
249           dataflow_.call_graph().GetNode(computation);
250       if (position.instruction == computation->root_instruction()) {
251         for (const CallSite& callsite : call_graph_node.caller_callsites()) {
252           if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
253               callsite.instruction()->while_body() == computation) {
254             // Call graph must have been flattened.
255             CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
256 
257             const HloValue& while_value = dataflow_.GetUniqueValueAt(
258                 callsite.instruction(), position.index);
259             VLOG(3) << "  value @ " << position << " is root of "
260                     << callsite.instruction()->name()
261                     << "; body root and while value root must share buffer "
262                        "among them : "
263                     << while_value.ToShortString();
264             aliased_buffers->push_back(GetBufferForValue(while_value));
265           }
266         }
267       }
268     }
269     // Value is the output of the while instruction itself.
270     if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
271       VLOG(3) << "  value is output of a while instruction";
272       aliased_buffers->push_back(GetBufferForValue(value));
273     }
274   }
275 
ComputeConditionalAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)276   void ComputeConditionalAliasedBuffers(
277       const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
278     VLOG(3) << "Compute kConditional aliases";
279     // Aliases the buffers of the true/false computations roots, with the one of
280     // the conditional.
281     for (const HloPosition& position : value.positions()) {
282       const HloComputation* computation = position.instruction->parent();
283       const CallGraphNode& call_graph_node =
284           dataflow_.call_graph().GetNode(computation);
285       if (position.instruction == computation->root_instruction()) {
286         for (const CallSite& callsite : call_graph_node.caller_callsites()) {
287           if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
288             // Call graph must have been flattened.
289             CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
290 
291             const HloValue& cond_value = dataflow_.GetUniqueValueAt(
292                 callsite.instruction(), position.index);
293             VLOG(3)
294                 << "  value @ " << position << " is root of "
295                 << callsite.instruction()->name()
296                 << "; branch computation roots must share buffer among them : "
297                 << cond_value.ToShortString();
298             aliased_buffers->push_back(GetBufferForValue(cond_value));
299           }
300         }
301       }
302     }
303     // Value is the output of the conditional instruction itself.
304     if (value.defining_instruction()->opcode() == HloOpcode::kConditional) {
305       VLOG(3) << "  value is output of a conditional instruction";
306       aliased_buffers->push_back(GetBufferForValue(value));
307     }
308   }
309 
310   // Compute and return a vector of buffers that the given value must be
311   // contained in due to HLO aliasing rules.
ComputeAliasedBuffers(const HloValue & value)312   std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
313     for (const HloUse& use : value.uses()) {
314       VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
315     }
316     std::vector<BufferNumber> aliased_buffers;
317     ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
318     ComputeWhileAliasedBuffers(value, &aliased_buffers);
319     ComputeConditionalAliasedBuffers(value, &aliased_buffers);
320     // Uniquify aliased buffers.
321     absl::c_sort(aliased_buffers);
322     aliased_buffers.erase(
323         std::unique(aliased_buffers.begin(), aliased_buffers.end()),
324         aliased_buffers.end());
325     return aliased_buffers;
326   }
327 
328   HloModule* module_;
329 
330   // Dataflow analysis used to construct the buffer map.
331   const HloDataflowAnalysis& dataflow_;
332 
333   // A map containing the set of values contained in each buffer.
334   absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>>
335       buffers_;
336 
337   // A map indicating which buffer each value is contained in.
338   absl::flat_hash_map<const HloValue*, BufferNumber> value_to_buffer_number_;
339 
340   // The buffer number of the next buffer to be created.
341   BufferNumber next_buffer_number_ = 0;
342 };
343 
HloAliasAnalysis(HloModule * module)344 HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {}
345 
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index) const346 const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
347     const HloInstruction* instruction, const ShapeIndex& index) const {
348   std::vector<const HloBuffer*> buffers = ComputeBuffersAt(instruction, index);
349   CHECK_EQ(buffers.size(), 1);
350   return *buffers[0];
351 }
352 
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index)353 HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
354     const HloInstruction* instruction, const ShapeIndex& index) {
355   return GetBuffer(static_cast<const HloAliasAnalysis*>(this)
356                        ->GetUniqueBufferAt(instruction, index)
357                        .id());
358 }
359 
ComputeBuffersAt(const HloInstruction * instruction,const ShapeIndex & index) const360 std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
361     const HloInstruction* instruction, const ShapeIndex& index) const {
362   std::vector<const HloBuffer*> buffers;
363   for (const HloValue* value :
364        dataflow_analysis_->GetValueSet(instruction, index).values()) {
365     buffers.push_back(&GetBufferContainingValue(*value));
366   }
367 
368   // Sort and uniquify vector before returning.
369   absl::c_sort(buffers, HloBuffer::IdLessThan);
370   buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
371 
372   return buffers;
373 }
374 
InstructionBuffersAreAmbiguous(const HloInstruction * instruction) const375 bool HloAliasAnalysis::InstructionBuffersAreAmbiguous(
376     const HloInstruction* instruction) const {
377   for (const auto& pair :
378        dataflow_analysis_->GetInstructionValueSet(instruction)) {
379     const HloValueSet& value_set = pair.second;
380     const HloBuffer* buffer = nullptr;
381     for (const HloValue* value : value_set.values()) {
382       if (buffer == nullptr) {
383         buffer = &GetBufferContainingValue(*value);
384       } else if (buffer != &GetBufferContainingValue(*value)) {
385         return true;
386       }
387     }
388   }
389   return false;
390 }
391 
InstructionBuffersAreDistinct(const HloInstruction * instruction) const392 bool HloAliasAnalysis::InstructionBuffersAreDistinct(
393     const HloInstruction* instruction) const {
394   absl::flat_hash_set<const HloBuffer*> buffers_seen;
395   for (const auto& pair :
396        dataflow_analysis_->GetInstructionValueSet(instruction)) {
397     const HloValueSet& value_set = pair.second;
398     if (value_set.values().size() == 1) {
399       if (!buffers_seen
400                .insert(&GetBufferContainingValue(value_set.GetUniqueValue()))
401                .second) {
402         return false;
403       }
404     } else {
405       // It's possible for multiple values at this index to have the same
406       // HloBuffer. This does not result in non-distictness. To account for
407       // this case, add all of the buffers at this index after checking
408       // whether each buffer exists at an earlier index. This is a corner
409       // case, however, as the number of values at an index is almost always
410       // one.
411       std::vector<const HloBuffer*> buffers_at_this_index;
412       for (const HloValue* value : value_set.values()) {
413         const HloBuffer* buffer = &GetBufferContainingValue(*value);
414         if (ContainsKey(buffers_seen, buffer)) {
415           return false;
416         }
417         buffers_at_this_index.push_back(buffer);
418       }
419       buffers_seen.insert(buffers_at_this_index.begin(),
420                           buffers_at_this_index.end());
421     }
422   }
423   return true;
424 }
425 
Verify() const426 Status HloAliasAnalysis::Verify() const {
427   // Verify consistency between the value_to_buffer_ map and
428   // HloBuffer::values().
429   for (const auto& pair : value_to_buffer_) {
430     const HloValue* value = pair.first;
431     const HloBuffer& buffer = *pair.second;
432     TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
433   }
434 
435   for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
436     const HloBuffer& buffer = buffers_[id];
437     TF_RET_CHECK(buffer.id() == id);
438 
439     HloValue::Id last_value_id = -1;
440     for (const HloValue* value : buffer.values()) {
441       TF_RET_CHECK(GetBufferContainingValue(*value) == buffer);
442 
443       // Also verify the values in HloBuffer are unique and sorted by id.
444       TF_RET_CHECK(value->id() > last_value_id);
445       last_value_id = value->id();
446     }
447   }
448 
449   return Status::OK();
450 }
451 
ToString() const452 string HloAliasAnalysis::ToString() const {
453   string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
454   StrAppend(&out, "  Buffers at each position:\n");
455   for (const HloComputation* computation : module_->computations()) {
456     for (const HloInstruction* instruction : computation->instructions()) {
457       StrAppend(&out, "    ", instruction->name(), ":\n");
458       if (instruction->shape().IsTuple()) {
459         ShapeUtil::ForEachSubshape(
460             instruction->shape(),
461             [&out, &instruction, this](const Shape&, const ShapeIndex& index) {
462               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
463               for (const HloBuffer* buffer :
464                    ComputeBuffersAt(instruction, index)) {
465                 StrAppend(&out, "        ", buffer->ToString(), "\n");
466               }
467             });
468       } else {
469         for (const HloBuffer* buffer :
470              ComputeBuffersAt(instruction, /*index=*/{})) {
471           StrAppend(&out, "      ", buffer->ToString(), "\n");
472         }
473       }
474     }
475   }
476 
477   StrAppend(&out, "  Buffers:\n");
478   for (const HloBuffer& buffer : buffers()) {
479     StrAppend(&out, "    ", buffer.ToString(), "\n");
480     StrAppend(&out, "      positions:\n");
481     for (const HloPosition& position : buffer.ComputePositions()) {
482       StrAppend(&out, "        ", position.ToString(), "\n");
483     }
484   }
485 
486   return out;
487 }
488 
489 /* static */
Run(HloModule * module,const HloDataflowAnalysis::FusionCanShareBufferFunction & fusion_can_share_buffer)490 StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
491     HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction&
492                            fusion_can_share_buffer) {
493   VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
494   XLA_VLOG_LINES(2, module->ToString());
495 
496   auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
497   TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
498                       HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
499                                                /*bitcast_defines_value=*/false,
500                                                fusion_can_share_buffer));
501 
502   BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
503   buffer_map.MergeAliasedBuffers();
504 
505   // Create a vector of HloBuffers, one for each set of values in the
506   // BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
507   // buffers.
508   std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
509       buffer_map.ComputeSortedBufferNumbers();
510   alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
511   HloBuffer::Id next_id = 0;
512   for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
513     auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
514     std::vector<const HloValue*> sorted_values(value_set.begin(),
515                                                value_set.end());
516     absl::c_sort(sorted_values, HloValue::IdLessThan);
517     alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
518     for (const HloValue* value : sorted_values) {
519       alias_analysis->value_to_buffer_[value] =
520           &alias_analysis->buffers_.back();
521     }
522   }
523 
524   TF_DCHECK_OK(alias_analysis->Verify());
525 
526   XLA_VLOG_LINES(2, alias_analysis->ToString());
527   return std::move(alias_analysis);
528 }
529 
HasLiveRangeInterference(const HloOrdering & ordering) const530 bool HloAliasAnalysis::HasLiveRangeInterference(
531     const HloOrdering& ordering) const {
532   for (const HloBuffer& buffer : buffers()) {
533     CHECK(!buffer.values().empty());
534     if (buffer.values().front()->shape().IsToken()) {
535       // Tokens have no on-device representation and cannot interfere.
536       for (const HloValue* value : buffer.values()) {
537         // If one of the values is a token, all values must be a token.
538         DCHECK(value->shape().IsToken());
539       }
540       continue;
541     }
542 
543     // Check that the values in the buffer are totally ordered with respect to
544     // 'ordering'. Begin by sorting the values with respect to 'ordering' with a
545     // tie-break using value ID. The tie-break is necessary because we need a
546     // strict weak order for std::sort.
547     std::vector<const HloValue*> values = buffer.values();
548     absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) {
549       if (ordering.IsDefinedBefore(*a, *b)) {
550         return true;
551       } else if (ordering.IsDefinedBefore(*b, *a)) {
552         return false;
553       } else {
554         return a->id() < b->id();
555       }
556     });
557 
558     // Walk through the ordered vector of values. First verify that the values
559     // are totally ordered with respect to 'ordering', then check that no
560     // adjacent values have overlapping live ranges. Only adjacent values must
561     // be checked because of the property of live range interference. For
562     // example, if you have values A, B, and C (in program order) contained in
563     // a buffer and A interferes with C, then necessarily A also interferes
564     // with B. So to check interference you only need to check interference
565     // between A and B, and between B and C.
566     for (int i = 1; i < values.size(); ++i) {
567       if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) {
568         VLOG(1) << values[i - 1]->ToShortString() << " and "
569                 << values[i]->ToShortString() << " are not ordered";
570         return true;
571       }
572       if (ordering.MayInterfere(*values[i - 1], *values[i],
573                                 dataflow_analysis())) {
574         VLOG(1) << "In buffer " << buffer.id() << " containing values:\n  "
575                 << absl::StrJoin(values, ", ",
576                                  [](string* out, const HloValue* value) {
577                                    StrAppend(out, value->ToShortString());
578                                  })
579 
580                 << "\nValue " << values[i - 1]->ToShortString()
581                 << " may interfere with value " << values[i]->ToShortString();
582         return true;
583       }
584     }
585   }
586 
587   return false;
588 }
589 
590 }  // namespace xla
591