1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
18 
19 #include <functional>
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/iterator_util.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
33 #include "tensorflow/compiler/xla/service/hlo.pb.h"
34 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/name_uniquer.h"
37 #include "tensorflow/compiler/xla/shape_tree.h"
38 #include "tensorflow/compiler/xla/statusor.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/xla_data.pb.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 class HloModule;
48 
49 // Describes a computation at the HLO level.
50 //
51 // You can think of an HloComputation like a function.  It has some inputs
52 // (parameters) and returns exactly one value (the value of its root node).  If
53 // you want to return multiple values, you can return a tuple.
54 //
55 // The instructions inside of a computation do not have an explicit total order.
56 // Instead, they have a partial order determined by their data and control
57 // dependencies.
58 //
59 // An HloModule contains one "entry computation" -- this is like main() in a C
60 // program.  Every other computation inside of a module is attached to one or
61 // more HloInstructions, as a "nested computation".  For example, the kMap
62 // instruction has a nested computation and "applies" it to every element of its
63 // input, elementwise.  (That is, the input [x, y, z] is transformed to [f(x),
64 // f(y), f(z)].)
class HloComputation {
 public:
  // Builder class for HloComputation.
  class Builder {
   public:
    explicit Builder(const string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    // Takes ownership of `instruction`, records it as the most recently added
    // instruction (the default root for Build()), and returns a non-owning
    // pointer to it.
    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    // Applies `func` to every instruction added so far, in insertion order.
    // Stops at, and returns, the first non-OK status.
    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return Status::OK();
    }

   private:
    const string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;
  };

  // Helper class to automatically set the OpMetadata for every instruction
  // added to a computation.
  class MetadataBuilder {
   public:
    MetadataBuilder(HloComputation* computation, const OpMetadata& metadata)
        : computation_(computation), metadata_(metadata) {}

    // Stamps `metadata_` onto the instruction, then forwards it to the
    // wrapped computation, which takes ownership.
    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instruction->set_metadata(metadata_);
      return computation_->AddInstruction(std::move(instruction));
    }

   private:
    HloComputation* computation_;
    OpMetadata metadata_;
  };

  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
                                 const std::string& new_name = "");

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveParameter(int64 param_no);

  // Remove unused parameters from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveUnusedParametersFromFusedComputation();

  // Remove unused parameters from the computation. Unlike
  // RemoveUnusedParametersFromFusedComputation, this function can be used
  // to remove parameters from non-fusion computations.
  Status RemoveUnusedParametersFromAnyComputation();

  // Adds a new parameter instruction to a fusion computation.
  //
  // This should be a new parameter. Instruction will be appended to parameters
  // and inserted to the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Adds a new parameter instruction to the entry computation and update
  // the parent module config to reflect the change.
  //
  // This should be a new parameter. Instruction will be appended to parameters
  // and inserted to the instruction list.
  HloInstruction* AddEntryComputationParameter(
      std::unique_ptr<HloInstruction> instruction);

  // Replaces an old parameter with a new parameter. Adds the new parameter
  // instruction to the entry computation.
  Status ReplaceEntryComputationParameter(
      int64 param_no, HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Removes an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call. The instruction will be
  // removed even if it is marked as not removable.
  Status ForceRemoveInstruction(HloInstruction* instruction);

  // Remove an instruction (including side effecting ones) from the computation
  // and also transitively any operand that has no side effect and no users post
  // removing an instruction. The instruction must have no users. Instruction is
  // deallocated with this call. If given, the cleanup routine is executed on a
  // removed instruction before its deallocation.
  Status RemoveInstructionAndUnusedOperands(
      HloInstruction* instruction,
      std::function<void(HloInstruction*)> cleanup = nullptr);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation. In addition it must have
  // the same shape as the result of the computation for non fusion
  // computations, except if accept_different_shape is set to true.
  void set_root_instruction(HloInstruction* new_root_instruction,
                            bool accept_different_shape = false);

  // Return the root instruction of the computation. The root instruction is the
  // instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64 num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64 param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  // Returns all parameter instructions, indexed by parameter number.
  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation based
  // on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  string ToString(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed computation
  //     calls.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      const HloComputationProto& proto,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map,
      bool prohibit_empty_literal = true);

  using InstructionSequence = tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;

  using ConstInstructionSequence =
      tensorflow::gtl::iterator_range<UnwrappingIterator<
          std::list<std::unique_ptr<HloInstruction>>::const_iterator>>;

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  ConstInstructionSequence instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  InstructionSequence instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation. In
  // this order, definitions of values always appear before their uses.
  std::vector<HloInstruction*> MakeInstructionPostOrder() const;

  // Returns the number of instructions currently in this computation.
  int64 instruction_count() const { return instruction_iterators_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if computation
  // A calls computation B (eg, via a map instruction) then A will appear after
  // B in the list.
  std::vector<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
  // into a library call. Instructions must be in reverse topological order
  // (root of the fused expression first). Replaces all uses of the original
  // root instruction with the fusion instruction. The original instructions are
  // removed if they have no uses after fusion (this is necessarily true for at
  // least the root).
  HloInstruction* CreateFusionInstruction(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are added
  // to the computation. For array-shaped values, this method trivially returns
  // a kCopy instruction. For tuple-shaped instructions, the copy is performed
  // with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);

  // As above, but uses a custom function to copy the leaf nodes, which could
  // create alternative HLOs other than kCopy, or even pass-throughs.
  StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
      HloInstruction* instruction,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result with layout).
  ProgramShape ComputeProgramShape(bool include_ids = true) const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool Equal(const HloComputation& other, bool is_layout_sensitive) const {
    return EqualInternal(other, is_layout_sensitive,
                         /*ignore_channel_id_values=*/false);
  }

  // Same as Equal() but ignores channel ID value mismatches on instructions, as
  // long as the two instructions both have channel IDs or neither has a channel
  // ID.
  bool EqualIgnoringChannelIdValues(const HloComputation& other,
                                    bool is_layout_sensitive) const {
    return EqualInternal(other, is_layout_sensitive,
                         /*ignore_channel_id_values=*/true);
  }

  // Return whether `*this` and `other` are functionally equivalent.
  // Note: layout-sensitive comparison.
  bool operator==(const HloComputation& other) const {
    return Equal(other, true);
  }

  // Replaces old instruction with newly created instruction. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replaces an old instruction with a newly created instruction, and adds the
  // new instruction as an entry computation's parameter. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewEntryComputationParameter(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace old instruction with new instruction.  Updates uses and root
  // instruction. Removes old instruction from computation. Precondition:
  // old_instruction and new_instruction must have the compatible shapes.
  // If |new_instruction| doesn't have any sharding information it will
  // receive the sharding information of |old_instruction|.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last, and
  // the visitor's FinishVisit method is called once upon completion (with the
  // root instruction as the argument).
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       absl::Span<HloInstruction* const> order) const;

  // Returns a deep copy of this computation including all instructions.
  // If the clone context is specified, it will be populated with the cloned
  // object mappings, and its module() will be used to add new computations
  // into.
  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
                                        HloCloneContext* context = nullptr);

  // Like Clone(), but if an instruction is present in `replacements`, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If replacements maps a key to nullptr, we remove that instruction from the
  // new computation.  If an element of `replacements` references an instruction
  // that's not already in the computation, it's cloned and added to the new
  // computation.
  //
  // 'extra_parameters' allows to specify additional parameters that should be
  // added to the computation.
  //
  // All relevant instructions are cloned, *including* unique_ptr in the
  // `replacements` map.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      absl::flat_hash_map<const HloInstruction*,
                          std::unique_ptr<HloInstruction>>
          replacements,
      absl::Span<const HloInstruction* const> extra_parameters = {},
      HloCloneContext* context = nullptr, const string& suffix = "clone",
      const HloInstruction* new_root = nullptr);

  // Convenience overloads for CloneWithReplacements.  You want to do
  //
  //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
  //
  // but that doesn't work because std::initializer_list is not movable.  These
  // overloads let you do
  //
  //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});   // OK
  //
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
      HloCloneContext* context = nullptr, const string& suffix = "clone");

  // Returns true if the given instruction can be removed from the computation.
  // Parameter instructions cannot be removed without violating invariants of
  // the HLO computation with the exception of fusion computation. A parameter
  // instruction is removable for a fusion computation.
  //
  // Note that IsSafelyRemovable() is a necessary condition to remove an
  // instruction rather than a sufficient condition. For example, instructions
  // with side-effect (e.g., Send, Infeed) may be removed from a computation,
  // but the transformation must guarantee the invariants relevant to the
  // instructions still hold (e.g., Send and Recv must be removed together to
  // make each channel complete).
  bool IsSafelyRemovable(const HloInstruction* instruction);

  // Returns a map from channel-id to the group of instructions associated with
  // the channel. These instructions will be considered as a single node for
  // dependency purposes. Send and RecvDone are in the group, and AllReduces
  // with the same channel id are in the group.
  using ChannelDependencyGroup =
      absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
  ChannelDependencyGroup ComputeChannelDependencies() const;

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns if this computation is a fusion computation.
  bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

  // Returns if this computation is the entry computation of the module.
  bool IsEntryComputation() const;

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    fusion_instruction_ = fusion_instruction;
  }

  // Clear the unique ID of the computation so that it can be re-assigned, such
  // as for the purpose of compacting the unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // The id of this computation should be unique within the module.
  void SetUniqueId(int64 id) {
    CHECK_EQ(unique_id_, -1);  // Must not already have an id assigned.
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Returns the instruction in this computation that has name `name`.  Returns
  // null if there is no such instruction.
  HloInstruction* GetInstructionWithName(absl::string_view name);

  int64 unique_id() const { return unique_id_; }

  // Deallocate instructions that are marked by "RemoveInstruction". The two
  // stage clean up process is designed such that HloPass can have stable
  // internal pointers to HloInstructions while we create and remove
  // HloInstructions in a pass.
  void Cleanup() { to_be_deleted_.clear(); }

  // Returns true if a given instruction is marked dead in this computation.
  bool IsMarkedAsDead(const HloInstruction* inst);

 private:
  explicit HloComputation(
      const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Internal helper for comparison with different options.
  bool EqualInternal(const HloComputation& other, bool is_layout_sensitive,
                     bool ignore_channel_id_values) const;

  // Fuses HLOs in instructions_to_fuse into fusion_instruction.
  //
  // Pre-condition: fusion_instruction's opcode is kFusion.
  void FuseInstructionsInto(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction* fusion_instruction);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, ShapeIndex* index,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  enum VisitState { kVisiting, kVisited };
  void ComputeInstructionPostOrder(
      const HloComputation::ChannelDependencyGroup& channel_dependency_group,
      std::vector<HloInstruction*>* post_order, HloInstruction* root,
      absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;

  Status RemoveUnusedParametersImpl(bool allow_non_fusion);

  Status RemoveInstructionImpl(HloInstruction* instruction,
                               bool ignore_safety_check);

  string name_;
  int64 unique_id_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction.  Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
  InstructionList instructions_;
  absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  // Removed instructions are moved into to_be_deleted_ first and then
  // deallocated when Cleanup is called.
  std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;

  std::vector<HloInstruction*> param_instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
};
565 
566 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor)567 Status HloComputation::Accept(
568     DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
569   // Visit unreachable roots. Beware that the visitor might delete the currently
570   // visited root, which would invalidate iterators if the unreachable roots
571   // weren't computed ahead of time.
572   for (HloInstruction* root : CollectUnreachableRoots()) {
573     VLOG(3) << "Traversing unreachable root: " << root->ToString();
574     // Call FinishVisit only at the end.
575     TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
576   }
577   // Visit the computation root instruction last.
578   return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
579 }
580 
581 // Explicit instantiations.
582 template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
583 template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;
584 
585 template <typename HloInstructionPtr>
AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr> * visitor,absl::Span<HloInstruction * const> order)586 Status HloComputation::AcceptOrdered(
587     DfsHloVisitorBase<HloInstructionPtr>* visitor,
588     absl::Span<HloInstruction* const> order) const {
589   VLOG(3) << "Accepting visitor with order.";
590   for (HloInstruction* root : CollectUnreachableRoots()) {
591     TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
592   }
593   TF_RET_CHECK(order.size() == instruction_count());
594   absl::flat_hash_set<const HloInstruction*> visited;
595   for (const HloInstruction* instruction : order) {
596     VLOG(3) << "Visiting ordered: " << instruction->ToString();
597     TF_RET_CHECK(instruction_iterators_.contains(instruction))
598         << "Instruction " << instruction->name() << " is not in computation "
599         << name();
600     TF_RET_CHECK(!visited.contains(instruction))
601         << "Instruction " << instruction->name()
602         << " appears more than once in order";
603     HloInstruction* mutable_instruction =
604         const_cast<HloInstruction*>(instruction);
605     TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
606     TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
607     visitor->SetVisited(*mutable_instruction);
608     TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
609     visited.insert(instruction);
610   }
611   TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
612   return Status::OK();
613 }
614 
615 // Explicit instantiations.
616 template Status HloComputation::AcceptOrdered(
617     DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
618 template Status HloComputation::AcceptOrdered(
619     ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;
620 
621 }  // namespace xla
622 
623 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
624