1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_value.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36
37 namespace xla {
38
39 using absl::StrAppend;
40
41 // Data structure used to construct the alias analysis. Thrown away after alias
42 // analysis is complete. This data structure keeps track of which sets of
43 // HloValues must be in the same HloBuffer. This is maintained as a map from a
44 // buffer identifier (BufferNumber) to set of HLoValues.
45 //
46 // Initially each value is its own buffer. In MergeAliasedBuffers, sets of
47 // values which must share the same buffer are merged together. The end result
48 // is a partitioning of all HloValues into sets where each set needs its own
49 // HloBuffer. By performing this analysis without constructing HloBuffers on the
50 // fly, we can after-the-fact construct a vector of contiguously numbered
51 // HloBuffers after the buffer requirement has been determined.
52 class BufferValueMap {
53 public:
54 // A unique identifier for a set of colocated values which must share the same
55 // buffer. This is not necessarily the same as the HloBuffer::Id which will
56 // ultimately contain the values. The reason is that HloBuffer::Id's are
57 // contiguous, while BufferNumbers may not be. BufferNumbers may not be
58 // dense because buffers may be created and destroyed during the analysis
59 // construction process.
60 using BufferNumber = int64;
61
BufferValueMap(HloModule * module,const HloDataflowAnalysis & dataflow)62 explicit BufferValueMap(HloModule* module,
63 const HloDataflowAnalysis& dataflow)
64 : module_(module), dataflow_(dataflow) {
65 buffers_.reserve(dataflow_.values().size());
66 value_to_buffer_number_.reserve(dataflow_.values().size());
67 for (const HloValue* value : dataflow_.values()) {
68 BufferNumber buffer_number = next_buffer_number_++;
69 buffers_[buffer_number].insert(value);
70 value_to_buffer_number_[value] = buffer_number;
71 }
72 }
73
74 // Merge together sets of HloValues which must be in the same HloBuffer
75 // because of aliasing rules (eg, in-place kWhile instruction).
MergeAliasedBuffers()76 void MergeAliasedBuffers() {
77 for (const HloValue* value : dataflow_.values()) {
78 VLOG(3) << "Merging colocated values, value: " << value->ToShortString();
79
80 // Gather the set of buffers with aliasing rules (eg, kWhile) which this
81 // value must be contained in.
82 std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);
83
84 BufferNumber current_buffer = value_to_buffer_number_.at(value);
85 if (aliased_buffers.empty()) {
86 // The buffer containing 'value' aliases no other buffers. If the buffer
87 // containing 'value' already only contains 'value', then no change is
88 // necessary. If the buffer containing 'value' does contain other
89 // values, then remove 'value' from the buffer and create a new buffer
90 // containing only 'value'
91 if (buffers_.at(current_buffer).size() == 1) {
92 CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
93 } else {
94 MoveValueToNewBuffer(*value);
95 }
96 } else {
97 // If multiple buffers are aliased merge these buffers together into a
98 // single buffer (arbitrarily chosen as the first buffer in the vector).
99 if (aliased_buffers.size() > 1) {
100 for (int64 i = 1; i < aliased_buffers.size(); ++i) {
101 MergeBuffers(/*from=*/aliased_buffers[i],
102 /*to=*/aliased_buffers[0]);
103 }
104 }
105 BufferNumber new_buffer = aliased_buffers[0];
106 if (current_buffer != new_buffer) {
107 MoveValueToBuffer(*value, new_buffer);
108 }
109 }
110 }
111 }
112
113 // Compute and return a sorted vector of all BufferNumbers. Can be used to
114 // iterate through all buffers stabily.
ComputeSortedBufferNumbers() const115 std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
116 std::vector<BufferNumber> buffer_numbers;
117 for (const auto& pair : buffers_) {
118 buffer_numbers.push_back(pair.first);
119 }
120 absl::c_sort(buffer_numbers);
121 return buffer_numbers;
122 }
123
124 // Return a set of all the values in the given buffer.
GetValuesInBuffer(BufferNumber buffer_number) const125 const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer(
126 BufferNumber buffer_number) const {
127 return buffers_.at(buffer_number);
128 }
129
130 private:
131 // Create a new buffer.
NewBuffer(const HloValue & value)132 void NewBuffer(const HloValue& value) {
133 BufferNumber buffer_number = next_buffer_number_++;
134 buffers_[buffer_number].insert(&value);
135 value_to_buffer_number_[&value] = buffer_number;
136 }
137
138 // Move the given value into a new buffer containing only the value.
MoveValueToNewBuffer(const HloValue & value)139 void MoveValueToNewBuffer(const HloValue& value) {
140 BufferNumber new_buffer_number = next_buffer_number_++;
141 buffers_[new_buffer_number];
142 MoveValueToBuffer(value, new_buffer_number);
143 }
144
145 // Move the given value into the given buffer.
MoveValueToBuffer(const HloValue & value,BufferNumber buffer_number)146 void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
147 BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
148 absl::flat_hash_set<const HloValue*>& old_value_set =
149 buffers_.at(old_buffer_number);
150 old_value_set.erase(&value);
151 if (old_value_set.empty()) {
152 buffers_.erase(old_buffer_number);
153 }
154
155 buffers_.at(buffer_number).insert(&value);
156 value_to_buffer_number_.at(&value) = buffer_number;
157 }
158
159 // Merge the buffer 'from' into the buffer 'to'.
MergeBuffers(BufferNumber from,BufferNumber to)160 void MergeBuffers(BufferNumber from, BufferNumber to) {
161 auto& from_value_set = buffers_.at(from);
162 buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
163 // NOTE: using a union-find algorithm to hold the colocated values might be
164 // faster.
165 for (const HloValue* value : from_value_set) {
166 value_to_buffer_number_.at(value) = to;
167 }
168 buffers_.erase(from);
169 }
170
GetBufferForValue(const HloValue & value)171 BufferNumber GetBufferForValue(const HloValue& value) {
172 return value_to_buffer_number_.at(&value);
173 }
174
ComputeInputOutputAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)175 void ComputeInputOutputAliasedBuffers(
176 const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
177 // Get parameter value from an aliased_input object.
178 const auto get_parameter_value =
179 [this](const HloInputOutputAliasConfig::Alias& aliased_input)
180 -> const HloValue& {
181 return dataflow_.GetUniqueValueAt(
182 module_->entry_computation()->parameter_instruction(
183 aliased_input.parameter_number),
184 aliased_input.parameter_index);
185 };
186
187 // If the value shows up in a root instruction, alias it with parameter
188 // intruction.
189 for (const HloPosition& pos : value.positions()) {
190 if (pos.instruction == module_->entry_computation()->root_instruction()) {
191 ShapeIndex output_index = pos.index;
192
193 auto aliased_input =
194 module_->input_output_alias_config().GetAliasedParameter(
195 output_index);
196 if (aliased_input) {
197 aliased_buffers->push_back(
198 GetBufferForValue(get_parameter_value(*aliased_input)));
199 }
200 }
201 }
202
203 // If the value is parameter instruction itself, alias it with itself.
204 if (value.instruction()->opcode() == HloOpcode::kParameter &&
205 value.instruction()->parent() == module_->entry_computation()) {
206 aliased_buffers->push_back(GetBufferForValue(value));
207 }
208 }
209
ComputeWhileAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)210 void ComputeWhileAliasedBuffers(const HloValue& value,
211 std::vector<BufferNumber>* aliased_buffers) {
212 VLOG(3) << "Compute kWhile aliases";
213 // Value is init of a while (use is while).
214 for (const HloUse& use : value.uses()) {
215 if (use.instruction->opcode() == HloOpcode::kWhile) {
216 // Determine the while value that this shares a buffer with.
217 const HloValue& while_value =
218 dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
219 aliased_buffers->push_back(GetBufferForValue(while_value));
220 VLOG(3) << " value is init value to a while; must share buffer with "
221 "while value "
222 << while_value.ToShortString();
223 }
224 }
225 // Value is a parameter of a while body/condition.
226 if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
227 const HloComputation* computation =
228 value.defining_instruction()->parent();
229 const CallGraphNode& call_graph_node =
230 dataflow_.call_graph().GetNode(computation);
231 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
232 if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
233 // Call graph must have been flattened.
234 CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
235
236 const HloValue& while_value = dataflow_.GetUniqueValueAt(
237 callsite.instruction(), value.defining_index());
238 VLOG(3) << " value is parameter value of the body or condition of a "
239 "while; must share buffer with while value "
240 << while_value.ToShortString();
241 aliased_buffers->push_back(GetBufferForValue(while_value));
242 }
243 }
244 }
245 // Value is the root of a while body.
246 for (const HloPosition& position : value.positions()) {
247 const HloComputation* computation = position.instruction->parent();
248 const CallGraphNode& call_graph_node =
249 dataflow_.call_graph().GetNode(computation);
250 if (position.instruction == computation->root_instruction()) {
251 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
252 if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
253 callsite.instruction()->while_body() == computation) {
254 // Call graph must have been flattened.
255 CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
256
257 const HloValue& while_value = dataflow_.GetUniqueValueAt(
258 callsite.instruction(), position.index);
259 VLOG(3) << " value @ " << position << " is root of "
260 << callsite.instruction()->name()
261 << "; body root and while value root must share buffer "
262 "among them : "
263 << while_value.ToShortString();
264 aliased_buffers->push_back(GetBufferForValue(while_value));
265 }
266 }
267 }
268 }
269 // Value is the output of the while instruction itself.
270 if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
271 VLOG(3) << " value is output of a while instruction";
272 aliased_buffers->push_back(GetBufferForValue(value));
273 }
274 }
275
ComputeConditionalAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)276 void ComputeConditionalAliasedBuffers(
277 const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
278 VLOG(3) << "Compute kConditional aliases";
279 // Aliases the buffers of the true/false computations roots, with the one of
280 // the conditional.
281 for (const HloPosition& position : value.positions()) {
282 const HloComputation* computation = position.instruction->parent();
283 const CallGraphNode& call_graph_node =
284 dataflow_.call_graph().GetNode(computation);
285 if (position.instruction == computation->root_instruction()) {
286 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
287 if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
288 // Call graph must have been flattened.
289 CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
290
291 const HloValue& cond_value = dataflow_.GetUniqueValueAt(
292 callsite.instruction(), position.index);
293 VLOG(3)
294 << " value @ " << position << " is root of "
295 << callsite.instruction()->name()
296 << "; branch computation roots must share buffer among them : "
297 << cond_value.ToShortString();
298 aliased_buffers->push_back(GetBufferForValue(cond_value));
299 }
300 }
301 }
302 }
303 // Value is the output of the conditional instruction itself.
304 if (value.defining_instruction()->opcode() == HloOpcode::kConditional) {
305 VLOG(3) << " value is output of a conditional instruction";
306 aliased_buffers->push_back(GetBufferForValue(value));
307 }
308 }
309
310 // Compute and return a vector of buffers that the given value must be
311 // contained in due to HLO aliasing rules.
ComputeAliasedBuffers(const HloValue & value)312 std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
313 for (const HloUse& use : value.uses()) {
314 VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
315 }
316 std::vector<BufferNumber> aliased_buffers;
317 ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
318 ComputeWhileAliasedBuffers(value, &aliased_buffers);
319 ComputeConditionalAliasedBuffers(value, &aliased_buffers);
320 // Uniquify aliased buffers.
321 absl::c_sort(aliased_buffers);
322 aliased_buffers.erase(
323 std::unique(aliased_buffers.begin(), aliased_buffers.end()),
324 aliased_buffers.end());
325 return aliased_buffers;
326 }
327
328 HloModule* module_;
329
330 // Dataflow analysis used to construct the buffer map.
331 const HloDataflowAnalysis& dataflow_;
332
333 // A map containing the set of values contained in each buffer.
334 absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>>
335 buffers_;
336
337 // A map indicating which buffer each value is contained in.
338 absl::flat_hash_map<const HloValue*, BufferNumber> value_to_buffer_number_;
339
340 // The buffer number of the next buffer to be created.
341 BufferNumber next_buffer_number_ = 0;
342 };
343
HloAliasAnalysis(HloModule * module)344 HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {}
345
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index) const346 const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
347 const HloInstruction* instruction, const ShapeIndex& index) const {
348 std::vector<const HloBuffer*> buffers = ComputeBuffersAt(instruction, index);
349 CHECK_EQ(buffers.size(), 1);
350 return *buffers[0];
351 }
352
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index)353 HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
354 const HloInstruction* instruction, const ShapeIndex& index) {
355 return GetBuffer(static_cast<const HloAliasAnalysis*>(this)
356 ->GetUniqueBufferAt(instruction, index)
357 .id());
358 }
359
ComputeBuffersAt(const HloInstruction * instruction,const ShapeIndex & index) const360 std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
361 const HloInstruction* instruction, const ShapeIndex& index) const {
362 std::vector<const HloBuffer*> buffers;
363 for (const HloValue* value :
364 dataflow_analysis_->GetValueSet(instruction, index).values()) {
365 buffers.push_back(&GetBufferContainingValue(*value));
366 }
367
368 // Sort and uniquify vector before returning.
369 absl::c_sort(buffers, HloBuffer::IdLessThan);
370 buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
371
372 return buffers;
373 }
374
InstructionBuffersAreAmbiguous(const HloInstruction * instruction) const375 bool HloAliasAnalysis::InstructionBuffersAreAmbiguous(
376 const HloInstruction* instruction) const {
377 for (const auto& pair :
378 dataflow_analysis_->GetInstructionValueSet(instruction)) {
379 const HloValueSet& value_set = pair.second;
380 const HloBuffer* buffer = nullptr;
381 for (const HloValue* value : value_set.values()) {
382 if (buffer == nullptr) {
383 buffer = &GetBufferContainingValue(*value);
384 } else if (buffer != &GetBufferContainingValue(*value)) {
385 return true;
386 }
387 }
388 }
389 return false;
390 }
391
InstructionBuffersAreDistinct(const HloInstruction * instruction) const392 bool HloAliasAnalysis::InstructionBuffersAreDistinct(
393 const HloInstruction* instruction) const {
394 absl::flat_hash_set<const HloBuffer*> buffers_seen;
395 for (const auto& pair :
396 dataflow_analysis_->GetInstructionValueSet(instruction)) {
397 const HloValueSet& value_set = pair.second;
398 if (value_set.values().size() == 1) {
399 if (!buffers_seen
400 .insert(&GetBufferContainingValue(value_set.GetUniqueValue()))
401 .second) {
402 return false;
403 }
404 } else {
405 // It's possible for multiple values at this index to have the same
406 // HloBuffer. This does not result in non-distictness. To account for
407 // this case, add all of the buffers at this index after checking
408 // whether each buffer exists at an earlier index. This is a corner
409 // case, however, as the number of values at an index is almost always
410 // one.
411 std::vector<const HloBuffer*> buffers_at_this_index;
412 for (const HloValue* value : value_set.values()) {
413 const HloBuffer* buffer = &GetBufferContainingValue(*value);
414 if (ContainsKey(buffers_seen, buffer)) {
415 return false;
416 }
417 buffers_at_this_index.push_back(buffer);
418 }
419 buffers_seen.insert(buffers_at_this_index.begin(),
420 buffers_at_this_index.end());
421 }
422 }
423 return true;
424 }
425
Verify() const426 Status HloAliasAnalysis::Verify() const {
427 // Verify consistency between the value_to_buffer_ map and
428 // HloBuffer::values().
429 for (const auto& pair : value_to_buffer_) {
430 const HloValue* value = pair.first;
431 const HloBuffer& buffer = *pair.second;
432 TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
433 }
434
435 for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
436 const HloBuffer& buffer = buffers_[id];
437 TF_RET_CHECK(buffer.id() == id);
438
439 HloValue::Id last_value_id = -1;
440 for (const HloValue* value : buffer.values()) {
441 TF_RET_CHECK(GetBufferContainingValue(*value) == buffer);
442
443 // Also verify the values in HloBuffer are unique and sorted by id.
444 TF_RET_CHECK(value->id() > last_value_id);
445 last_value_id = value->id();
446 }
447 }
448
449 return Status::OK();
450 }
451
ToString() const452 string HloAliasAnalysis::ToString() const {
453 string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
454 StrAppend(&out, " Buffers at each position:\n");
455 for (const HloComputation* computation : module_->computations()) {
456 for (const HloInstruction* instruction : computation->instructions()) {
457 StrAppend(&out, " ", instruction->name(), ":\n");
458 if (instruction->shape().IsTuple()) {
459 ShapeUtil::ForEachSubshape(
460 instruction->shape(),
461 [&out, &instruction, this](const Shape&, const ShapeIndex& index) {
462 StrAppend(&out, " tuple index ", index.ToString(), ":\n");
463 for (const HloBuffer* buffer :
464 ComputeBuffersAt(instruction, index)) {
465 StrAppend(&out, " ", buffer->ToString(), "\n");
466 }
467 });
468 } else {
469 for (const HloBuffer* buffer :
470 ComputeBuffersAt(instruction, /*index=*/{})) {
471 StrAppend(&out, " ", buffer->ToString(), "\n");
472 }
473 }
474 }
475 }
476
477 StrAppend(&out, " Buffers:\n");
478 for (const HloBuffer& buffer : buffers()) {
479 StrAppend(&out, " ", buffer.ToString(), "\n");
480 StrAppend(&out, " positions:\n");
481 for (const HloPosition& position : buffer.ComputePositions()) {
482 StrAppend(&out, " ", position.ToString(), "\n");
483 }
484 }
485
486 return out;
487 }
488
489 /* static */
Run(HloModule * module,const HloDataflowAnalysis::FusionCanShareBufferFunction & fusion_can_share_buffer)490 StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
491 HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction&
492 fusion_can_share_buffer) {
493 VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
494 XLA_VLOG_LINES(2, module->ToString());
495
496 auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
497 TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
498 HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
499 /*bitcast_defines_value=*/false,
500 fusion_can_share_buffer));
501
502 BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
503 buffer_map.MergeAliasedBuffers();
504
505 // Create a vector of HloBuffers, one for each set of values in the
506 // BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
507 // buffers.
508 std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
509 buffer_map.ComputeSortedBufferNumbers();
510 alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
511 HloBuffer::Id next_id = 0;
512 for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
513 auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
514 std::vector<const HloValue*> sorted_values(value_set.begin(),
515 value_set.end());
516 absl::c_sort(sorted_values, HloValue::IdLessThan);
517 alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
518 for (const HloValue* value : sorted_values) {
519 alias_analysis->value_to_buffer_[value] =
520 &alias_analysis->buffers_.back();
521 }
522 }
523
524 TF_DCHECK_OK(alias_analysis->Verify());
525
526 XLA_VLOG_LINES(2, alias_analysis->ToString());
527 return std::move(alias_analysis);
528 }
529
HasLiveRangeInterference(const HloOrdering & ordering) const530 bool HloAliasAnalysis::HasLiveRangeInterference(
531 const HloOrdering& ordering) const {
532 for (const HloBuffer& buffer : buffers()) {
533 CHECK(!buffer.values().empty());
534 if (buffer.values().front()->shape().IsToken()) {
535 // Tokens have no on-device representation and cannot interfere.
536 for (const HloValue* value : buffer.values()) {
537 // If one of the values is a token, all values must be a token.
538 DCHECK(value->shape().IsToken());
539 }
540 continue;
541 }
542
543 // Check that the values in the buffer are totally ordered with respect to
544 // 'ordering'. Begin by sorting the values with respect to 'ordering' with a
545 // tie-break using value ID. The tie-break is necessary because we need a
546 // strict weak order for std::sort.
547 std::vector<const HloValue*> values = buffer.values();
548 absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) {
549 if (ordering.IsDefinedBefore(*a, *b)) {
550 return true;
551 } else if (ordering.IsDefinedBefore(*b, *a)) {
552 return false;
553 } else {
554 return a->id() < b->id();
555 }
556 });
557
558 // Walk through the ordered vector of values. First verify that the values
559 // are totally ordered with respect to 'ordering', then check that no
560 // adjacent values have overlapping live ranges. Only adjacent values must
561 // be checked because of the property of live range interference. For
562 // example, if you have values A, B, and C (in program order) contained in
563 // a buffer and A interferes with C, then necessarily A also interferes
564 // with B. So to check interference you only need to check interference
565 // between A and B, and between B and C.
566 for (int i = 1; i < values.size(); ++i) {
567 if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) {
568 VLOG(1) << values[i - 1]->ToShortString() << " and "
569 << values[i]->ToShortString() << " are not ordered";
570 return true;
571 }
572 if (ordering.MayInterfere(*values[i - 1], *values[i],
573 dataflow_analysis())) {
574 VLOG(1) << "In buffer " << buffer.id() << " containing values:\n "
575 << absl::StrJoin(values, ", ",
576 [](string* out, const HloValue* value) {
577 StrAppend(out, value->ToShortString());
578 })
579
580 << "\nValue " << values[i - 1]->ToShortString()
581 << " may interfere with value " << values[i]->ToShortString();
582 return true;
583 }
584 }
585 }
586
587 return false;
588 }
589
590 } // namespace xla
591