1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_value.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace xla {
38 
39 using absl::StrAppend;
40 
41 // Data structure used to construct the alias analysis. Thrown away after alias
42 // analysis is complete. This data structure keeps track of which sets of
43 // HloValues must be in the same HloBuffer. This is maintained as a map from a
// buffer identifier (BufferNumber) to set of HloValues.
45 //
46 // Initially each value is its own buffer. In MergeAliasedBuffers, sets of
47 // values which must share the same buffer are merged together. The end result
48 // is a partitioning of all HloValues into sets where each set needs its own
49 // HloBuffer. By performing this analysis without constructing HloBuffers on the
50 // fly, we can after-the-fact construct a vector of contiguously numbered
51 // HloBuffers after the buffer requirement has been determined.
52 class BufferValueMap {
53  public:
54   // A unique identifier for a set of colocated values which must share the same
55   // buffer. This is not necessarily the same as the HloBuffer::Id which will
56   // ultimately contain the values. The reason is that HloBuffer::Id's are
57   // contiguous, while BufferNumbers may not be. BufferNumbers may not be
58   // dense because buffers may be created and destroyed during the analysis
59   // construction process.
60   using BufferNumber = int64;
61 
BufferValueMap(const HloModule * module,const HloDataflowAnalysis & dataflow)62   explicit BufferValueMap(const HloModule* module,
63                           const HloDataflowAnalysis& dataflow)
64       : module_(module), dataflow_(dataflow) {
65     buffers_.reserve(dataflow_.values().size());
66     value_to_buffer_number_.reserve(dataflow_.values().size());
67     for (const HloValue* value : dataflow_.values()) {
68       BufferNumber buffer_number = next_buffer_number_++;
69       buffers_[buffer_number].insert(value);
70       value_to_buffer_number_[value] = buffer_number;
71     }
72   }
73 
74   // Merge together sets of HloValues which must be in the same HloBuffer
75   // because of aliasing rules (eg, in-place kWhile instruction).
MergeAliasedBuffers()76   void MergeAliasedBuffers() {
77     for (const HloValue* value : dataflow_.values()) {
78       VLOG(3) << "Merging colocated values, value: " << value->ToShortString();
79 
80       // Gather the set of buffers with aliasing rules (eg, kWhile) which this
81       // value must be contained in.
82       std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);
83 
84       BufferNumber current_buffer = value_to_buffer_number_.at(value);
85       if (aliased_buffers.empty()) {
86         // The buffer containing 'value' aliases no other buffers. If the buffer
87         // containing 'value' already only contains 'value', then no change is
88         // necessary. If the buffer containing 'value' does contain other
89         // values, then remove 'value' from the buffer and create a new buffer
90         // containing only 'value'
91         if (buffers_.at(current_buffer).size() == 1) {
92           CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
93         } else {
94           MoveValueToNewBuffer(*value);
95         }
96       } else {
97         // If multiple buffers are aliased merge these buffers together into a
98         // single buffer (arbitrarily chosen as the first buffer in the vector).
99         if (aliased_buffers.size() > 1) {
100           for (int64 i = 1; i < aliased_buffers.size(); ++i) {
101             MergeBuffers(/*from=*/aliased_buffers[i],
102                          /*to=*/aliased_buffers[0]);
103           }
104         }
105         BufferNumber new_buffer = aliased_buffers[0];
106         if (current_buffer != new_buffer) {
107           MoveValueToBuffer(*value, new_buffer);
108         }
109       }
110     }
111   }
112 
113   // Compute and return a sorted vector of all BufferNumbers. Can be used to
114   // iterate through all buffers stabily.
ComputeSortedBufferNumbers() const115   std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
116     std::vector<BufferNumber> buffer_numbers;
117     for (const auto& pair : buffers_) {
118       buffer_numbers.push_back(pair.first);
119     }
120     absl::c_sort(buffer_numbers);
121     return buffer_numbers;
122   }
123 
124   // Return a set of all the values in the given buffer.
GetValuesInBuffer(BufferNumber buffer_number) const125   const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer(
126       BufferNumber buffer_number) const {
127     return buffers_.at(buffer_number);
128   }
129 
130  private:
131   // Create a new buffer.
NewBuffer(const HloValue & value)132   void NewBuffer(const HloValue& value) {
133     BufferNumber buffer_number = next_buffer_number_++;
134     buffers_[buffer_number].insert(&value);
135     value_to_buffer_number_[&value] = buffer_number;
136   }
137 
138   // Move the given value into a new buffer containing only the value.
MoveValueToNewBuffer(const HloValue & value)139   void MoveValueToNewBuffer(const HloValue& value) {
140     BufferNumber new_buffer_number = next_buffer_number_++;
141     buffers_[new_buffer_number];
142     MoveValueToBuffer(value, new_buffer_number);
143   }
144 
145   // Move the given value into the given buffer.
MoveValueToBuffer(const HloValue & value,BufferNumber buffer_number)146   void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
147     BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
148     absl::flat_hash_set<const HloValue*>& old_value_set =
149         buffers_.at(old_buffer_number);
150     old_value_set.erase(&value);
151     if (old_value_set.empty()) {
152       buffers_.erase(old_buffer_number);
153     }
154 
155     buffers_.at(buffer_number).insert(&value);
156     value_to_buffer_number_.at(&value) = buffer_number;
157   }
158 
159   // Merge the buffer 'from' into the buffer 'to'.
MergeBuffers(BufferNumber from,BufferNumber to)160   void MergeBuffers(BufferNumber from, BufferNumber to) {
161     auto& from_value_set = buffers_.at(from);
162     buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
163     // NOTE: using a union-find algorithm to hold the colocated values might be
164     // faster.
165     for (const HloValue* value : from_value_set) {
166       value_to_buffer_number_.at(value) = to;
167     }
168     buffers_.erase(from);
169   }
170 
GetBufferForValue(const HloValue & value)171   BufferNumber GetBufferForValue(const HloValue& value) {
172     return value_to_buffer_number_.at(&value);
173   }
174 
ComputeInputOutputAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)175   void ComputeInputOutputAliasedBuffers(
176       const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
177     // Get parameter value from an aliased_input object.
178     const auto get_parameter_value =
179         [this](const HloInputOutputAliasConfig::Alias& aliased_input)
180         -> const HloValue& {
181       return dataflow_.GetUniqueValueAt(
182           module_->entry_computation()->parameter_instruction(
183               aliased_input.parameter_number),
184           aliased_input.parameter_index);
185     };
186 
187     // If the value shows up in a root instruction, alias it with parameter
188     // instruction.
189     for (const HloPosition& pos : value.positions()) {
190       if (pos.instruction == module_->entry_computation()->root_instruction()) {
191         ShapeIndex output_index = pos.index;
192 
193         auto aliased_input =
194             module_->input_output_alias_config().GetAliasedParameter(
195                 output_index);
196         if (aliased_input) {
197           aliased_buffers->push_back(
198               GetBufferForValue(get_parameter_value(*aliased_input)));
199         }
200       }
201     }
202 
203     // If the value is parameter instruction itself, alias it with itself.
204     if (value.instruction()->opcode() == HloOpcode::kParameter &&
205         value.instruction()->parent() == module_->entry_computation()) {
206       aliased_buffers->push_back(GetBufferForValue(value));
207     }
208   }
209 
ComputeWhileAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)210   void ComputeWhileAliasedBuffers(const HloValue& value,
211                                   std::vector<BufferNumber>* aliased_buffers) {
212     VLOG(3) << "Compute kWhile aliases";
213     // Value is init of a while (use is while).
214     for (const HloUse& use : value.uses()) {
215       if (use.instruction->opcode() == HloOpcode::kWhile) {
216         // Determine the while value that this shares a buffer with.
217         const HloValue& while_value =
218             dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
219         aliased_buffers->push_back(GetBufferForValue(while_value));
220         VLOG(3) << "  value is init value to a while; must share buffer with "
221                    "while value "
222                 << while_value.ToShortString();
223       }
224     }
225     // Value is a parameter of a while body/condition.
226     if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
227       const HloComputation* computation =
228           value.defining_instruction()->parent();
229       const CallGraphNode& call_graph_node =
230           dataflow_.call_graph().GetNode(computation);
231       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
232         if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
233           // Call graph must have been flattened.
234           CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
235 
236           const HloValue& while_value = dataflow_.GetUniqueValueAt(
237               callsite.instruction(), value.defining_index());
238           VLOG(3) << "  value is parameter value of the body or condition of a "
239                      "while; must share buffer with while value "
240                   << while_value.ToShortString();
241           aliased_buffers->push_back(GetBufferForValue(while_value));
242         }
243       }
244     }
245     // Value is the root of a while body.
246     for (const HloPosition& position : value.positions()) {
247       const HloComputation* computation = position.instruction->parent();
248       const CallGraphNode& call_graph_node =
249           dataflow_.call_graph().GetNode(computation);
250       if (position.instruction == computation->root_instruction()) {
251         for (const CallSite& callsite : call_graph_node.caller_callsites()) {
252           if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
253               callsite.instruction()->while_body() == computation) {
254             // Call graph must have been flattened.
255             CHECK_EQ(call_graph_node.caller_callsites().size(), 1)
256                 << "Call graph must have been flattened.";
257 
258             const HloValue& while_value = dataflow_.GetUniqueValueAt(
259                 callsite.instruction(), position.index);
260             VLOG(3) << "  value @ " << position << " is root of "
261                     << callsite.instruction()->name()
262                     << "; body root and while value root must share buffer "
263                        "among them : "
264                     << while_value.ToShortString();
265             aliased_buffers->push_back(GetBufferForValue(while_value));
266           }
267         }
268       }
269     }
270     // Value is the output of the while instruction itself.
271     if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
272       VLOG(3) << "  value is output of a while instruction";
273       aliased_buffers->push_back(GetBufferForValue(value));
274     }
275   }
276 
ComputeConditionalAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)277   void ComputeConditionalAliasedBuffers(
278       const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
279     VLOG(3) << "Compute kConditional aliases";
280     // Aliases the buffers of the true/false computations roots, with the one of
281     // the conditional.
282     for (const HloPosition& position : value.positions()) {
283       const HloComputation* computation = position.instruction->parent();
284       const CallGraphNode& call_graph_node =
285           dataflow_.call_graph().GetNode(computation);
286       if (position.instruction == computation->root_instruction()) {
287         for (const CallSite& callsite : call_graph_node.caller_callsites()) {
288           if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
289             // Call graph must have been flattened.
290             CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
291 
292             const HloValue& cond_value = dataflow_.GetUniqueValueAt(
293                 callsite.instruction(), position.index);
294             VLOG(3)
295                 << "  value @ " << position << " is root of "
296                 << callsite.instruction()->name()
297                 << "; branch computation roots must share buffer among them : "
298                 << cond_value.ToShortString();
299             aliased_buffers->push_back(GetBufferForValue(cond_value));
300           }
301         }
302       }
303     }
304     // Value is the output of the conditional instruction itself.
305     if (value.defining_instruction()->opcode() == HloOpcode::kConditional) {
306       VLOG(3) << "  value is output of a conditional instruction";
307       aliased_buffers->push_back(GetBufferForValue(value));
308     }
309   }
310 
ComputeInPlaceOperationAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)311   void ComputeInPlaceOperationAliasedBuffers(
312       const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
313     VLOG(3) << "Compute aliases for in-place operations (e.g. "
314                "kDynamicUpdateSlice and kScatter)";
315     for (const HloPosition& position : value.positions()) {
316       HloInstruction* instruction = position.instruction;
317       for (const auto& operand_and_output_index :
318            HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
319         if (position.index == operand_and_output_index.second) {
320           const HloUse& operand = operand_and_output_index.first;
321           const HloValue& operand_value = dataflow_.GetUniqueValueAt(
322               instruction->operand(operand.operand_number),
323               operand.operand_index);
324           VLOG(3) << " operand value " << operand_value.ToShortString()
325                   << " aliases.";
326           aliased_buffers->push_back(GetBufferForValue(operand_value));
327         }
328       }
329     }
330 
331     for (const HloUse& use : value.uses()) {
332       for (const auto& operand_and_output_index :
333            HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) {
334         if (use == operand_and_output_index.first) {
335           const HloValue& use_value = dataflow_.GetUniqueValueAt(
336               use.instruction, operand_and_output_index.second);
337           VLOG(3) << " use value " << use_value.ToShortString() << " aliases.";
338           aliased_buffers->push_back(GetBufferForValue(use_value));
339         }
340       }
341     }
342   }
343 
344   // Compute and return a vector of buffers that the given value must be
345   // contained in due to HLO aliasing rules.
ComputeAliasedBuffers(const HloValue & value)346   std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
347     for (const HloUse& use : value.uses()) {
348       VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
349     }
350     std::vector<BufferNumber> aliased_buffers;
351     ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
352     ComputeWhileAliasedBuffers(value, &aliased_buffers);
353     ComputeConditionalAliasedBuffers(value, &aliased_buffers);
354     ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers);
355     // Uniquify aliased buffers.
356     absl::c_sort(aliased_buffers);
357     aliased_buffers.erase(
358         std::unique(aliased_buffers.begin(), aliased_buffers.end()),
359         aliased_buffers.end());
360     return aliased_buffers;
361   }
362 
363   const HloModule* module_ = nullptr;
364 
365   // Dataflow analysis used to construct the buffer map.
366   const HloDataflowAnalysis& dataflow_;
367 
368   // A map containing the set of values contained in each buffer.
369   absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>>
370       buffers_;
371 
372   // A map indicating which buffer each value is contained in.
373   absl::flat_hash_map<const HloValue*, BufferNumber> value_to_buffer_number_;
374 
375   // The buffer number of the next buffer to be created.
376   BufferNumber next_buffer_number_ = 0;
377 };
378 
// Stores `module` for later use. The dataflow analysis and the buffers are not
// set up here; HloAliasAnalysis::Run fills them in after construction.
HloAliasAnalysis::HloAliasAnalysis(const HloModule* module) : module_(module) {}
380 
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index) const381 const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
382     const HloInstruction* instruction, const ShapeIndex& index) const {
383   std::vector<const HloBuffer*> buffers = ComputeBuffersAt(instruction, index);
384   CHECK_EQ(buffers.size(), 1);
385   return *buffers[0];
386 }
387 
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index)388 HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
389     const HloInstruction* instruction, const ShapeIndex& index) {
390   return GetBuffer(static_cast<const HloAliasAnalysis*>(this)
391                        ->GetUniqueBufferAt(instruction, index)
392                        .id());
393 }
394 
ComputeBuffersAt(const HloInstruction * instruction,const ShapeIndex & index) const395 std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
396     const HloInstruction* instruction, const ShapeIndex& index) const {
397   std::vector<const HloBuffer*> buffers;
398   for (const HloValue* value :
399        dataflow_analysis_->GetValueSet(instruction, index).values()) {
400     buffers.push_back(&GetBufferContainingValue(*value));
401   }
402 
403   // Sort and uniquify vector before returning.
404   absl::c_sort(buffers, HloBuffer::IdLessThan);
405   buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
406 
407   return buffers;
408 }
409 
InstructionBuffersAreAmbiguous(const HloInstruction * instruction) const410 bool HloAliasAnalysis::InstructionBuffersAreAmbiguous(
411     const HloInstruction* instruction) const {
412   for (const auto& pair :
413        dataflow_analysis_->GetInstructionValueSet(instruction)) {
414     const HloValueSet& value_set = pair.second;
415     const HloBuffer* buffer = nullptr;
416     for (const HloValue* value : value_set.values()) {
417       if (buffer == nullptr) {
418         buffer = &GetBufferContainingValue(*value);
419       } else if (buffer != &GetBufferContainingValue(*value)) {
420         return true;
421       }
422     }
423   }
424   return false;
425 }
426 
InstructionBuffersAreDistinct(const HloInstruction * instruction) const427 bool HloAliasAnalysis::InstructionBuffersAreDistinct(
428     const HloInstruction* instruction) const {
429   absl::flat_hash_set<const HloBuffer*> buffers_seen;
430   for (const auto& pair :
431        dataflow_analysis_->GetInstructionValueSet(instruction)) {
432     const HloValueSet& value_set = pair.second;
433     if (value_set.values().size() == 1) {
434       if (!buffers_seen
435                .insert(&GetBufferContainingValue(value_set.GetUniqueValue()))
436                .second) {
437         return false;
438       }
439     } else {
440       // It's possible for multiple values at this index to have the same
441       // HloBuffer. This does not result in non-distinctness. To account for
442       // this case, add all of the buffers at this index after checking
443       // whether each buffer exists at an earlier index. This is a corner
444       // case, however, as the number of values at an index is almost always
445       // one.
446       std::vector<const HloBuffer*> buffers_at_this_index;
447       for (const HloValue* value : value_set.values()) {
448         const HloBuffer* buffer = &GetBufferContainingValue(*value);
449         if (ContainsKey(buffers_seen, buffer)) {
450           return false;
451         }
452         buffers_at_this_index.push_back(buffer);
453       }
454       buffers_seen.insert(buffers_at_this_index.begin(),
455                           buffers_at_this_index.end());
456     }
457   }
458   return true;
459 }
460 
Verify() const461 Status HloAliasAnalysis::Verify() const {
462   // Verify consistency between the value_to_buffer_ map and
463   // HloBuffer::values().
464   for (const auto& pair : value_to_buffer_) {
465     const HloValue* value = pair.first;
466     const HloBuffer& buffer = *pair.second;
467     TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
468   }
469 
470   for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
471     const HloBuffer& buffer = buffers_[id];
472     TF_RET_CHECK(buffer.id() == id);
473 
474     HloValue::Id last_value_id = -1;
475     for (const HloValue* value : buffer.values()) {
476       TF_RET_CHECK(GetBufferContainingValue(*value) == buffer);
477 
478       // Also verify the values in HloBuffer are unique and sorted by id.
479       TF_RET_CHECK(value->id() > last_value_id);
480       last_value_id = value->id();
481     }
482   }
483 
484   return Status::OK();
485 }
486 
ToString() const487 string HloAliasAnalysis::ToString() const {
488   string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
489   StrAppend(&out, "  Buffers at each position:\n");
490   for (const HloComputation* computation : module_->computations()) {
491     for (const HloInstruction* instruction : computation->instructions()) {
492       StrAppend(&out, "    ", instruction->name(), ":\n");
493       if (instruction->shape().IsTuple()) {
494         ShapeUtil::ForEachSubshape(
495             instruction->shape(),
496             [&out, &instruction, this](const Shape&, const ShapeIndex& index) {
497               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
498               for (const HloBuffer* buffer :
499                    ComputeBuffersAt(instruction, index)) {
500                 StrAppend(&out, "        ", buffer->ToString(), "\n");
501               }
502             });
503       } else {
504         for (const HloBuffer* buffer :
505              ComputeBuffersAt(instruction, /*index=*/{})) {
506           StrAppend(&out, "      ", buffer->ToString(), "\n");
507         }
508       }
509     }
510   }
511 
512   StrAppend(&out, "  Buffers:\n");
513   for (const HloBuffer& buffer : buffers()) {
514     StrAppend(&out, "    ", buffer.ToString(), "\n");
515     StrAppend(&out, "      positions:\n");
516     for (const HloPosition& position : buffer.ComputePositions()) {
517       StrAppend(&out, "        ", position.ToString(), "\n");
518     }
519   }
520 
521   return out;
522 }
523 
524 /* static */
Run(const HloModule * module,const HloDataflowAnalysis::CanShareBuffer & can_share_buffer)525 StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
526     const HloModule* module,
527     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
528   VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
529   XLA_VLOG_LINES(2, module->ToString());
530 
531   auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
532   TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
533                       HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
534                                                /*bitcast_defines_value=*/false,
535                                                can_share_buffer));
536 
537   BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
538   buffer_map.MergeAliasedBuffers();
539 
540   // Create a vector of HloBuffers, one for each set of values in the
541   // BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
542   // buffers.
543   std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
544       buffer_map.ComputeSortedBufferNumbers();
545   alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
546   HloBuffer::Id next_id = 0;
547   for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
548     auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
549     std::vector<const HloValue*> sorted_values(value_set.begin(),
550                                                value_set.end());
551     absl::c_sort(sorted_values, HloValue::IdLessThan);
552     alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
553     for (const HloValue* value : sorted_values) {
554       alias_analysis->value_to_buffer_[value] =
555           &alias_analysis->buffers_.back();
556     }
557   }
558 
559   TF_DCHECK_OK(alias_analysis->Verify());
560 
561   HloInstruction* root = module->entry_computation()->root_instruction();
562   ShapeUtil::ForEachSubshape(
563       root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
564         for (const HloBuffer* buffer :
565              alias_analysis->ComputeBuffersAt(root, index)) {
566           alias_analysis->live_out_buffers_.insert(buffer);
567         }
568       });
569 
570   XLA_VLOG_LINES(2, alias_analysis->ToString());
571   return std::move(alias_analysis);
572 }
573 
MergeBuffers(const HloBuffer & to,const HloBuffer & from)574 void HloAliasAnalysis::MergeBuffers(const HloBuffer& to,
575                                     const HloBuffer& from) {
576   CHECK(to.id() != from.id());
577   VLOG(2) << "Merge buffer: " << from.ToString() << " into :" << to.ToString();
578 
579   CHECK(from.id() < buffers_.size());
580   CHECK(to.id() < buffers_.size());
581 
582   // Merge the values of `to` and `from`, creates a new buffer with the
583   // merged values.
584   std::vector<const HloValue*> merged_values(to.values().begin(),
585                                              to.values().end());
586 
587   merged_values.insert(merged_values.end(), from.values().begin(),
588                        from.values().end());
589   absl::c_sort(merged_values, [](const HloValue* a, const HloValue* b) {
590     return a->id() < b->id();
591   });
592 
593   buffers_[to.id()] = HloBuffer(to.id(), merged_values);
594   for (const HloValue* value : merged_values) {
595     // Update references of values.
596     value_to_buffer_[value] = &buffers_[to.id()];
597   }
598 
599   if (live_out_buffers_.count(&from) > 0) {
600     // Update live out set to erase `from` and add `to`.
601     live_out_buffers_.erase(&from);
602     live_out_buffers_.insert(&buffers_[to.id()]);
603   }
604 
605   int64 from_id = from.id();
606   if (from_id != buffers_.size() - 1) {
607     // Now `from` is invalid, move the last element of buffers to replace `from`
608     // and update references to the last element.
609     const HloBuffer& last_elem = buffers_.back();
610     buffers_[from.id()] = HloBuffer(from_id, last_elem.values());
611 
612     if (live_out_buffers_.count(&last_elem) > 0) {
613       // Update live out set to redirect the last element to its new position.
614       live_out_buffers_.erase(&last_elem);
615       live_out_buffers_.insert(&buffers_[from_id]);
616     }
617 
618     // Update references of values.
619     for (const HloValue* value : buffers_[from_id].values()) {
620       value_to_buffer_[value] = &buffers_[from_id];
621     }
622   }
623 
624   // Remove the last element.
625   buffers_.pop_back();
626 
627   CHECK(Verify().ok());
628 }
629 
HasLiveRangeInterference(const HloOrdering & ordering) const630 bool HloAliasAnalysis::HasLiveRangeInterference(
631     const HloOrdering& ordering) const {
632   for (const HloBuffer& buffer : buffers()) {
633     CHECK(!buffer.values().empty());
634     if (buffer.values().front()->shape().IsToken()) {
635       // Tokens have no on-device representation and cannot interfere.
636       for (const HloValue* value : buffer.values()) {
637         // If one of the values is a token, all values must be a token.
638         DCHECK(value->shape().IsToken());
639       }
640       continue;
641     }
642 
643     // Check that the values in the buffer are totally ordered with respect to
644     // 'ordering'. Begin by sorting the values with respect to 'ordering' with a
645     // tie-break using value ID. The tie-break is necessary because we need a
646     // strict weak order for std::sort.
647     std::vector<const HloValue*> values = buffer.values();
648     absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) {
649       if (ordering.IsDefinedBefore(*a, *b)) {
650         return true;
651       } else if (ordering.IsDefinedBefore(*b, *a)) {
652         return false;
653       } else {
654         return a->id() < b->id();
655       }
656     });
657 
658     // Walk through the ordered vector of values. First verify that the values
659     // are totally ordered with respect to 'ordering', then check that no
660     // adjacent values have overlapping live ranges. Only adjacent values must
661     // be checked because of the property of live range interference. For
662     // example, if you have values A, B, and C (in program order) contained in
663     // a buffer and A interferes with C, then necessarily A also interferes
664     // with B. So to check interference you only need to check interference
665     // between A and B, and between B and C.
666     for (int i = 1; i < values.size(); ++i) {
667       if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) {
668         VLOG(1) << values[i - 1]->ToShortString() << " and "
669                 << values[i]->ToShortString() << " are not ordered";
670         return true;
671       }
672       if (ordering.MayInterfere(*values[i - 1], *values[i],
673                                 dataflow_analysis())) {
674         VLOG(1) << "In buffer " << buffer.id() << " containing values:\n  "
675                 << absl::StrJoin(values, ", ",
676                                  [](string* out, const HloValue* value) {
677                                    StrAppend(out, value->ToShortString());
678                                  })
679 
680                 << "\nValue " << values[i - 1]->ToShortString()
681                 << " may interfere with value " << values[i]->ToShortString();
682         return true;
683       }
684     }
685   }
686 
687   return false;
688 }
689 
690 }  // namespace xla
691