1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17 
18 #include <ostream>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 
ToString() const41 string BufferAlias::ToString() const {
42   return absl::StrCat("BufferAlias(", instruction_->name(), "[",
43                       absl::StrJoin(index_, ","), "])");
44 }
45 
operator <<(std::ostream & out,const BufferAlias & buffer_alias)46 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
47   out << buffer_alias.ToString();
48   return out;
49 }
50 
IsAmbiguous() const51 bool PointsToSet::IsAmbiguous() const {
52   bool ambiguous = false;
53   ForEachElement(
54       [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
55         ambiguous |= points_to.size() > 1;
56       });
57   return ambiguous;
58 }
59 
IsDistinct() const60 bool PointsToSet::IsDistinct() const {
61   bool distinct = true;
62   absl::flat_hash_set<const LogicalBuffer*> all_points_to;
63   ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
64     for (auto& buffer : points_to) {
65       if (all_points_to.contains(buffer)) {
66         distinct = false;
67       }
68       all_points_to.insert(buffer);
69     }
70   });
71   return distinct;
72 }
73 
size() const74 size_t PointsToSet::size() const {
75   // Because pointed-to elements may be duplicated we have to create a flattened
76   // set and return the size.
77   return CreateFlattenedSet().size();
78 }
79 
CreateFlattenedSet() const80 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
81   BufferSet flat_set;
82   ForEachElement(
83       [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
84         flat_set.insert(buffers.begin(), buffers.end());
85       });
86   return flat_set;
87 }
88 
ContainsBuffer(const LogicalBuffer & buffer) const89 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
90   bool found = false;
91   ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
92                                    const BufferList& pointed_to_buffers) {
93     if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
94       found = true;
95     }
96   });
97   return found;
98 }
99 
ContainsBufferAtIndex(const LogicalBuffer & buffer,const ShapeIndex & index) const100 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
101                                         const ShapeIndex& index) const {
102   const auto& pointed_to_buffers = element(index);
103   return absl::c_linear_search(pointed_to_buffers, &buffer);
104 }
105 
AddPointedToBuffer(const LogicalBuffer & buffer,const ShapeIndex & index)106 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
107                                      const ShapeIndex& index) {
108   if (ContainsBufferAtIndex(buffer, index)) {
109     return;
110   }
111   mutable_element(index)->push_back(&buffer);
112 }
113 
tuple_sources(const ShapeIndex & index) const114 const PointsToSet::SourceSet& PointsToSet::tuple_sources(
115     const ShapeIndex& index) const {
116   return tree_.element(index).tuple_sources;
117 }
118 
add_tuple_source(const ShapeIndex & index,HloInstruction * tuple)119 void PointsToSet::add_tuple_source(const ShapeIndex& index,
120                                    HloInstruction* tuple) {
121   tree_.mutable_element(index)->tuple_sources.insert(tuple);
122 }
123 
124 namespace {
125 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)126 void GatherFusionInstructions(
127     HloInstruction* instruction,
128     std::vector<HloInstruction*>* fusion_instructions) {
129   CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
130   for (auto* fused : instruction->fused_instructions()) {
131     if (fused->opcode() == HloOpcode::kFusion) {
132       GatherFusionInstructions(fused, fusion_instructions);
133     }
134   }
135   fusion_instructions->push_back(instruction);
136 }
137 
138 }  // namespace
139 
140 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(const HloModule * module)141 TuplePointsToAnalysis::Run(const HloModule* module) {
142   auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
143   std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis(
144       module, logical_buffer_analysis.ConsumeValueOrDie()));
145   TF_RETURN_IF_ERROR(analysis->Analyze());
146   return std::move(analysis);
147 }
148 
Analyze()149 Status TuplePointsToAnalysis::Analyze() {
150   per_instruction_.clear();
151   per_instruction_.reserve(module_->instruction_count());
152 
153   logical_buffer_aliases_.clear();
154   logical_buffer_aliases_.resize(
155       logical_buffer_analysis_->num_logical_buffers());
156 
157   std::vector<HloInstruction*> fusion_instructions;
158   for (auto* computation : module_->MakeNonfusionComputations()) {
159     TF_RETURN_IF_ERROR(computation->Accept(this));
160     TF_RETURN_IF_ERROR(
161         PopulateDefinedBuffersAndAliases(computation->instructions()));
162     for (auto* instruction : computation->instructions()) {
163       if (instruction->opcode() == HloOpcode::kFusion) {
164         GatherFusionInstructions(instruction, &fusion_instructions);
165       }
166     }
167   }
168   // Run points-to analysis on fusion instructions in 'computation'.
169   for (auto* instruction : fusion_instructions) {
170     TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
171     TF_RETURN_IF_ERROR(
172         PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
173   }
174 
175   XLA_VLOG_LINES(3, ToString());
176 
177   return Status::OK();
178 }
179 
180 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype(
181     std::declval<HloComputation>().instructions())& instructions) {
182   for (auto* instruction : instructions) {
183     PerInstruction* pi = PerInst(instruction);
184     TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
185         instruction, &pi->instruction_defined_buffers));
186 
187     const PointsToSet& points_to_set = GetPointsToSet(instruction);
188     points_to_set.ForEachElement(
189         [this, &instruction](
190             const ShapeIndex& index,
__anon878f1c720602( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) 191             const PointsToSet::BufferList& pointed_to_buffers) {
192           for (const LogicalBuffer* buffer : pointed_to_buffers) {
193             logical_buffer_aliases_[buffer->id()].emplace_back(instruction,
194                                                                index);
195           }
196         });
197   }
198   return Status::OK();
199 }
200 
DefaultAction(HloInstruction * hlo_instruction)201 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
202   // Create trivial points-to set for instruction. Each points-to set at index i
203   // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
204   // that this instruction is the source of all buffers in its own output.
205   PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
206   points_to_set.ForEachMutableElement(
207       [this, hlo_instruction](const ShapeIndex& index,
208                               PointsToSet::BufferList* buffers) {
209         buffers->push_back(
210             &logical_buffer_analysis_->GetBuffer(hlo_instruction, index));
211       });
212 
213   if (hlo_instruction->shape().IsTuple()) {
214     // If the hlo instruction is a tuple-shaped, then trivially the instruction
215     // itself is the source of the tuple.
216     points_to_set.add_tuple_source({}, hlo_instruction);
217   }
218 
219   return Status::OK();
220 }
221 
HandleGetTupleElement(HloInstruction * get_tuple_element)222 Status TuplePointsToAnalysis::HandleGetTupleElement(
223     HloInstruction* get_tuple_element) {
224   // GetTupleElement forwards a pointer to a particular element of the tuple
225   // operand.
226   int64 element_index = get_tuple_element->tuple_index();
227 
228   PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
229   const PointsToSet& operand_points_to_set =
230       *PerInst(get_tuple_element->operand(0))->points_to_set;
231 
232   // Copy the points-to set (and tuple sources) at index {element_index} of the
233   // operand to the points-to set for this GetTupleElement instruction.
234   points_to_set.ForEachMutableElement(
235       [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
236         // Construct an index into the operand by prepending element_index to
237         // the index for the GetTupleElement instruction's points-to set.
238         ShapeIndex src_index;
239         src_index.push_back(element_index);
240         for (auto element : target_index) {
241           src_index.push_back(element);
242         }
243 
244         *points_to = operand_points_to_set.element(src_index);
245         for (HloInstruction* tuple :
246              operand_points_to_set.tuple_sources(src_index)) {
247           points_to_set.add_tuple_source(target_index, tuple);
248         }
249       });
250 
251   return Status::OK();
252 }
253 
HandleCopy(HloInstruction * copy)254 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
255   // A kCopy instruction performs a shallow copy of the operand. The top-level
256   // buffer (index={}) is newly created, but all other buffers (in the case of a
257   // tuple shape) come from the operand
258   PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
259   points_to_set.mutable_element(/*index=*/{})->clear();
260   points_to_set.AddPointedToBuffer(
261       logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}),
262       /*index=*/{});
263 
264   return Status::OK();
265 }
266 
HandleBitcast(HloInstruction * bitcast)267 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
268   // A kBitcast instruction aliases its operand. That is, the buffer of its
269   // result *is* the buffer of its operand, so just copy the operands points-to
270   // set.
271   CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
272   return Status::OK();
273 }
274 
HandleDomain(HloInstruction * domain)275 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
276   // A kDomain instruction aliases its operand. That is, the buffer of its
277   // result *is* the buffer of its operand, so just copy the operands points-to
278   // set.
279   CreateCopiedPointsToSet(domain, domain->operand(0));
280   return Status::OK();
281 }
282 
HandleAddDependency(HloInstruction * add_dependency)283 Status TuplePointsToAnalysis::HandleAddDependency(
284     HloInstruction* add_dependency) {
285   // AddDependency just forwards the value of its zero-th operand.
286   CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
287   return Status::OK();
288 }
289 
HandleRecvDone(HloInstruction * recv_done)290 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
291   // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
292   // output. The other indices ({} and {1}) define their own buffers.
293   PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
294   points_to_set.AddPointedToBuffer(
295       logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
296       /*index=*/{});
297   points_to_set.AddPointedToBuffer(
298       logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
299       /*index=*/{1});
300 
301   const PointsToSet& operand_points_to_set =
302       GetPointsToSet(recv_done->operand(0));
303 
304   // Recursively copy the points to set of the operand tuple {0} to the output
305   // element {0}.
306   points_to_set.ForEachMutableElement(
307       [&points_to_set, &operand_points_to_set](
308           const ShapeIndex& index, PointsToSet::BufferList* buffers) {
309         if (index.empty() || index[0] != 0) {
310           return;
311         }
312         *buffers = operand_points_to_set.element(index);
313         for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
314           points_to_set.add_tuple_source(index, tuple_source);
315         }
316       });
317   return Status::OK();
318 }
319 
HandleCopyStart(HloInstruction * copy_start)320 Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) {
321   // CopyStart forwards its aliased operand to {1}.
322   PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start);
323   const PointsToSet& operand_points_to_set =
324       GetPointsToSet(copy_start->operand(0));
325 
326   points_to_set.ForEachMutableElement(
327       [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
328         if (target_index == ShapeIndex({1})) {
329           *buffers = operand_points_to_set.element(/*index=*/{});
330         } else {
331           buffers->push_back(
332               &logical_buffer_analysis_->GetBuffer(copy_start, target_index));
333         }
334       });
335 
336   for (HloInstruction* tuple :
337        operand_points_to_set.tuple_sources(/*index=*/{})) {
338     points_to_set.add_tuple_source(/*index=*/{1}, tuple);
339   }
340 
341   return Status::OK();
342 }
343 
HandleCopyDone(HloInstruction * copy_done)344 Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) {
345   // CopyDone forwards its aliased operand.
346   PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done);
347   const PointsToSet& operand_points_to_set =
348       GetPointsToSet(copy_done->operand(0));
349   operand_points_to_set.ForEachElement(
350       [&points_to_set, &operand_points_to_set](
351           const ShapeIndex& src_index,
352           const PointsToSet::BufferList& points_to) {
353         if (src_index == ShapeIndex({0})) {
354           const ShapeIndex target_index = {};
355           *points_to_set.mutable_element(target_index) = points_to;
356 
357           for (HloInstruction* tuple :
358                operand_points_to_set.tuple_sources(src_index)) {
359             points_to_set.add_tuple_source(target_index, tuple);
360           }
361         }
362       });
363 
364   return Status::OK();
365 }
366 
HandleSend(HloInstruction * send)367 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
368   // Send creates a tuple of {aliased operand, U32 context, token}.
369   PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
370 
371   // Creates the points to set for the tuple and its element at {1}.
372   auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
373   top_buffer->push_back(
374       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
375   points_to_set.add_tuple_source({}, send);
376 
377   auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
378   context_buffer->push_back(
379       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
380 
381   auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
382   token_buffer->push_back(
383       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
384 
385   // Recursively copy the points to set of the operand to output tuple {0}.
386   const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
387   operand_points_to_set.ForEachElement(
388       [&points_to_set, &operand_points_to_set](
389           const ShapeIndex& src_index,
390           const PointsToSet::BufferList& points_to) {
391         ShapeIndex target_index({0});
392         for (auto element : src_index) {
393           target_index.push_back(element);
394         }
395         *points_to_set.mutable_element(target_index) = points_to;
396 
397         for (HloInstruction* tuple :
398              operand_points_to_set.tuple_sources(src_index)) {
399           points_to_set.add_tuple_source(target_index, tuple);
400         }
401       });
402 
403   return Status::OK();
404 }
405 
HandleTuple(HloInstruction * tuple)406 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
407   absl::Span<HloInstruction* const> operands(tuple->operands());
408   PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
409   points_to_set.AddPointedToBuffer(
410       logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
411       /*index=*/{});
412 
413   // A tuple contains references to all input operands and transitively any
414   // references in those operands.
415   for (int64 i = 0; i < operands.size(); ++i) {
416     const PointsToSet& operand_points_to_set =
417         *PerInst(operands[i])->points_to_set;
418 
419     // Copy the points-to set (and tuple sources) of the operand into the
420     // respective subtree of the tuple instructions points-to set.
421     operand_points_to_set.ForEachElement(
422         [&points_to_set, &operand_points_to_set, i](
423             const ShapeIndex& src_index,
424             const PointsToSet::BufferList& points_to) {
425           ShapeIndex target_index;
426           target_index.push_back(i);
427           for (auto element : src_index) {
428             target_index.push_back(element);
429           }
430 
431           *points_to_set.mutable_element(target_index) = points_to;
432 
433           for (HloInstruction* tuple :
434                operand_points_to_set.tuple_sources(src_index)) {
435             points_to_set.add_tuple_source(target_index, tuple);
436           }
437         });
438   }
439 
440   points_to_set.add_tuple_source({}, tuple);
441 
442   return Status::OK();
443 }
444 
HandleTupleSelect(HloInstruction * tuple_select)445 Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
446   // Select allocates a new buffer and then shallow copies the on_true or
447   // on_false buffer into this new buffer. Which side is chosen cannot be
448   // determined statically so conservatively set the points-to set to the union
449   // of these on_true and on_false operands.
450   //
451   // First create a copy of the on_true points-to set (and tuple sources), then
452   // add in elements of the on_false points-to set (tuple sources).
453   auto on_true = tuple_select->operand(1);
454   auto on_false = tuple_select->operand(2);
455   PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true);
456   const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
457   points_to_set.ForEachMutableElement(
458       [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
459         for (const LogicalBuffer* false_buffer :
460              false_points_to_set.element(index)) {
461           points_to_set.AddPointedToBuffer(*false_buffer, index);
462         }
463 
464         for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) {
465           points_to_set.add_tuple_source(index, tuple);
466         }
467       });
468 
469   // Select creates a new (top-level) buffer to store its result, so its
470   // respective element in the points-to set should contain only itself.
471   points_to_set.mutable_element({})->clear();
472   points_to_set.AddPointedToBuffer(
473       logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}),
474       /*index=*/{});
475   return Status::OK();
476 }
477 
HandleCustomCall(HloInstruction * custom_call)478 Status TuplePointsToAnalysis::HandleCustomCall(HloInstruction* custom_call) {
479   auto ccall = Cast<HloCustomCallInstruction>(custom_call);
480   PointsToSet& points_to_set = CreateEmptyPointsToSet(custom_call);
481   absl::flat_hash_map<ShapeIndex, std::pair<int64, ShapeIndex>> aliased_outputs;
482   for (const auto& pair : ccall->output_to_operand_aliasing()) {
483     aliased_outputs.emplace(pair.first, pair.second);
484   }
485   points_to_set.ForEachMutableElement([&](const ShapeIndex& index,
486                                           PointsToSet::BufferList* buffers) {
487     auto it = aliased_outputs.find(index);
488     if (it == aliased_outputs.end()) {
489       points_to_set.AddPointedToBuffer(
490           logical_buffer_analysis_->GetBuffer(custom_call, index), index);
491     } else {
492       const PointsToSet& input_set =
493           *PerInst(ccall->operand(it->second.first))->points_to_set;
494       for (const LogicalBuffer* input_buffer :
495            input_set.element(it->second.second)) {
496         points_to_set.AddPointedToBuffer(*input_buffer, index);
497       }
498 
499       for (HloInstruction* tuple : input_set.tuple_sources(it->second.second)) {
500         points_to_set.add_tuple_source(index, tuple);
501       }
502     }
503   });
504   points_to_set.add_tuple_source({}, custom_call);
505   return Status::OK();
506 }
507 
GetPointsToSet(const HloInstruction * hlo_instruction) const508 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
509     const HloInstruction* hlo_instruction) const {
510   return *PerInst(hlo_instruction)->points_to_set;
511 }
512 
CreateEmptyPointsToSet(const HloInstruction * instruction)513 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
514     const HloInstruction* instruction) {
515   PerInstruction* pi = PerInst(instruction);
516   CHECK(pi->points_to_set == nullptr)
517       << "instruction should not have been present in the map.";
518   auto set = absl::make_unique<PointsToSet>(&instruction->shape());
519   pi->points_to_set = std::move(set);
520   // Return *set using the iterator returned by emplace.
521   return *pi->points_to_set;
522 }
523 
InstructionDefinesBufferAtIndex(const HloInstruction * instruction,const ShapeIndex & index) const524 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
525     const HloInstruction* instruction, const ShapeIndex& index) const {
526   const auto& buffers = GetPointsToSet(instruction).element(index);
527   return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
528 }
529 
VerifyBuffer(const LogicalBuffer & buffer) const530 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
531   if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
532     return FailedPrecondition(
533         "LogicalBuffer %s is ill-defined: instruction %s does not define a "
534         "buffer at that index",
535         buffer.ToString(), buffer.instruction()->name());
536   }
537 
538   if (buffer.id() < 0 ||
539       buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
540     return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
541                               buffer.ToString(), buffer.id());
542   }
543   if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
544       GetBuffer(buffer.id()).index() != buffer.index()) {
545     return FailedPrecondition(
546         "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
547         buffer.ToString(), GetBuffer(buffer.id()).ToString());
548   }
549 
550   return Status::OK();
551 }
552 
GetBuffer(LogicalBuffer::Id id) const553 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
554     LogicalBuffer::Id id) const {
555   CHECK_GE(id, 0);
556   CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers());
557   return logical_buffer_analysis_->GetBuffer(id);
558 }
559 
GetBufferDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const560 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
561     const HloInstruction* instruction, const ShapeIndex& index) const {
562   const auto& buffers = GetPointsToSet(instruction).element(index);
563   if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
564     return FailedPrecondition(
565         "instruction %s does not define buffer at index {%s}",
566         instruction->name(), absl::StrJoin(index, ","));
567   }
568   return buffers[0];
569 }
570 
571 const TuplePointsToAnalysis::BufferAliasVector&
GetBufferAliases(const LogicalBuffer & buffer) const572 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const {
573   return logical_buffer_aliases_.at(buffer.id());
574 }
575 
576 const TuplePointsToAnalysis::BufferDefinitionVector&
GetBuffersDefinedByInstruction(const HloInstruction * instruction) const577 TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
578     const HloInstruction* instruction) const {
579   return PerInst(instruction)->instruction_defined_buffers;
580 }
581 
GatherBuffersDefinedByInstruction(const HloInstruction * instruction,TuplePointsToAnalysis::BufferDefinitionVector * buffers)582 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
583     const HloInstruction* instruction,
584     TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
585   GetPointsToSet(instruction)
586       .ForEachElement([buffers, instruction](
587                           const ShapeIndex& index,
588                           const PointsToSet::BufferList& source_buffers) {
589         // Add buffers which 'instruction' is the source of.
590         CHECK(!source_buffers.empty());
591         if (source_buffers.size() == 1 &&
592             source_buffers[0]->instruction() == instruction) {
593           // If this instruction is the source of this buffer the
594           // indices must match.
595           DCHECK(source_buffers[0]->index() == index);
596           buffers->push_back(source_buffers[0]);
597         } else {
598           // If the points-to set includes more than one buffer then
599           // necessarily this instruction did not produce the
600           // buffer.
601           for (const LogicalBuffer* source_buffer : source_buffers) {
602             DCHECK(source_buffer->instruction() != instruction);
603           }
604         }
605       });
606   return Status::OK();
607 }
608 
CreateCopiedPointsToSet(const HloInstruction * instruction,const HloInstruction * src)609 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
610     const HloInstruction* instruction, const HloInstruction* src) {
611   // PointsToSet doesn't have a copy constructor so copy over element-by-element
612   // from src PointsToSet.
613   PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
614   const PointsToSet& src_points_to_set = GetPointsToSet(src);
615   dst_points_to_set.ForEachMutableElement(
616       [&dst_points_to_set, &src_points_to_set](
617           const ShapeIndex& index, PointsToSet::BufferList* buffers) {
618         *buffers = src_points_to_set.element(index);
619         for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
620           dst_points_to_set.add_tuple_source(index, tuple_source);
621         }
622       });
623   return *PerInst(instruction)->points_to_set;
624 }
625 
ToString() const626 string TuplePointsToAnalysis::ToString() const {
627   string output =
628       absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
629   for (const auto* computation : module_->MakeNonfusionComputations()) {
630     const char* entry =
631         computation == module_->entry_computation() ? "entry " : "";
632     absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
633     for (const HloInstruction* instruction :
634          computation->MakeInstructionPostOrder()) {
635       InstructionToString(instruction, &output);
636       if (instruction->opcode() == HloOpcode::kFusion) {
637         for (auto* fused : instruction->fused_instructions()) {
638           InstructionToString(fused, &output);
639         }
640       }
641     }
642   }
643 
644   absl::StrAppend(&output, "LogicalBuffers:\n");
645   for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
646     absl::StrAppend(&output, "  buffer ", b->ToString(), ":\n");
647     for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
648       absl::StrAppend(&output, "    alias ", alias.ToString(), "\n");
649     }
650   }
651   return output;
652 }
653 
InstructionToString(const HloInstruction * instruction,string * output) const654 void TuplePointsToAnalysis::InstructionToString(
655     const HloInstruction* instruction, string* output) const {
656   const string prefix = instruction->IsFused() ? "    " : "";
657   absl::StrAppend(output, prefix, "  instruction ",
658                   instruction->ToShortString(), ":\n");
659   const PointsToSet& points_to_set = GetPointsToSet(instruction);
660   points_to_set.ForEachElement([&prefix, &output](
661                                    const ShapeIndex& index,
662                                    const PointsToSet::BufferList& points_to) {
663     absl::StrAppend(output, prefix, "    {", absl::StrJoin(index, ","), "}: ",
664                     absl::StrJoin(points_to, ", ",
665                                   [](string* out, const LogicalBuffer* source) {
666                                     out->append(source->ToString());
667                                   }),
668                     "\n");
669   });
670 }
671 
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const672 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
673     const HloInstruction* operand, const ShapeIndex& index,
674     const HloInstruction* user) const {
675   CHECK(user->IsUserOf(operand))
676       << "user: " << user->ToString() << " operand: " << operand->ToString();
677   if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
678     // GetTupleElement instructions only access the top-level buffer of their
679     // operand.
680     return true;
681   } else if (user->IsLoopFusion()) {
682     // Find fusion parameter associated with 'operand'.
683     auto it = absl::c_find_if(
684         user->fused_parameters(), [&](HloInstruction* fused_param) {
685           return user->operand(fused_param->parameter_number()) == operand;
686         });
687     CHECK(it != user->fused_parameters().end());
688     // Iterate through all users of all buffer aliases of the buffer in the
689     // points-to set of fusion parameter at 'index'.
690     // Return false if any uses are detected at 'index', returns true otherwise.
691     const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie();
692     for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
693       for (HloInstruction* alias_user : alias.instruction()->users()) {
694         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
695                                     alias_user)) {
696           continue;
697         }
698         // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
699         return false;
700       }
701     }
702     // Return true: found no uses of 'operand' at 'index' in 'user'.
703     return true;
704   }
705   return false;
706 }
707 
708 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
709 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
710 // where 'user' is a user of an alias of 'instruction' at 'index', and
711 // 'operand_index' is the operand index at which the alias appears in the
712 // operand list of 'user'.
713 std::vector<std::pair<HloInstruction*, int64>>
GetAllUsesOfInstructionAtIndex(HloInstruction * instruction,const ShapeIndex & index) const714 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
715     HloInstruction* instruction, const ShapeIndex& index) const {
716   std::vector<std::pair<HloInstruction*, int64>> uses;
717   const PointsToSet::BufferList& points_to =
718       GetPointsToSet(instruction).element(index);
719   for (const LogicalBuffer* buffer : points_to) {
720     for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
721       for (HloInstruction* alias_user : alias.instruction()->users()) {
722         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
723                                     alias_user)) {
724           continue;
725         }
726         for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
727           uses.emplace_back(alias_user, op_idx);
728         }
729       }
730     }
731   }
732   return uses;
733 }
734 
735 // Returns true if there is exactly one use of 'operand' at 'operand_index'
736 // in 'fusion.fused_instructions', where the singleton use is the fused
737 // root at operand index 'use_operand_index'. Returns false otherwise.
738 //
739 // REQUIRES: 'fusion' opcode is a kFusion instruction.
HasUniqueFusedUseOfOperandAt(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * fusion,const int64 use_operand_index) const740 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
741     HloInstruction* operand, const ShapeIndex& operand_index,
742     HloInstruction* fusion, const int64 use_operand_index) const {
743   CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
744   // Check that 'operand' is unique in the operand list of 'fusion'.
745   if (fusion->OperandIndices(operand).size() > 1) {
746     return false;
747   }
748   // Find fusion parameter associated with 'operand'.
749   const auto& fused_params = fusion->fused_parameters();
750   auto fused_param_it =
751       absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
752         return fusion->operand(fused_param->parameter_number()) == operand;
753       });
754   if (fused_param_it == fused_params.end()) {
755     return false;
756   }
757   auto* fused_param = *fused_param_it;
758   // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
759   auto fused_param_uses =
760       GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
761   // Return true iff there is exactly one use of 'operand' at 'index', and
762   // this singleton use is the fused root (at index in 'use_operand_indices').
763   return fused_param_uses.size() == 1 &&
764          fused_param_uses[0].first == fusion->fused_expression_root() &&
765          fused_param_uses[0].second == use_operand_index;
766 }
767 }  // namespace xla
768