/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_

#include <iosfwd>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Abstract base class for layout constraints. These constraint objects are
// gathered together in a LayoutConstraints object.
class LayoutConstraint {
 public:
  LayoutConstraint(bool mandatory, bool dfs)
      : mandatory_(mandatory), dfs_(dfs) {}
  virtual ~LayoutConstraint() = default;

  virtual string ToString() const = 0;

  // True if this constraint cannot be overwritten by a different constraint.
  bool mandatory() const { return mandatory_; }

  // When true, the constraint is propagated in DFS order. When false, it is
  // propagated in BFS order.
  bool dfs() const { return dfs_; }

 private:
  bool mandatory_;
  bool dfs_;
};

std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);

// Layout constraint on a single LogicalBuffer. This constrains the layout of an
// array produced by a particular instruction.
class BufferLayoutConstraint : public LayoutConstraint {
 public:
  BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory, bool dfs);

  const LogicalBuffer& buffer() const { return *buffer_; }
  const Layout& layout() const { return layout_; }

  string ToString() const override;

 private:
  Layout layout_;
  const LogicalBuffer* buffer_;
};

// Constraint on the layout of the operand of an instruction. The constrained
// shape can be arbitrarily shaped (array or tuple). This is a constraint on the
// use of a shaped value and is not a hard constraint on the instruction(s)
// which define the value, as copies may be inserted between the definition and
// the use.
class OperandLayoutConstraint : public LayoutConstraint {
 public:
  OperandLayoutConstraint(const ShapeLayout& shape_layout,
                          const HloInstruction* instruction, int64 operand_no,
                          bool mandatory, bool dfs);

  const ShapeLayout& shape_layout() const { return shape_layout_; }
  const HloInstruction* instruction() const { return instruction_; }
  const int64 operand_no() const { return operand_no_; }
  const HloInstruction* operand() const {
    return instruction_->operand(operand_no_);
  }

  string ToString() const override;

 private:
  ShapeLayout shape_layout_;
  const HloInstruction* instruction_;
  int64 operand_no_;
};

// Constraint on the layout of the result of the entry computation.
class ResultLayoutConstraint : public LayoutConstraint {
 public:
  explicit ResultLayoutConstraint(const ShapeLayout& shape_layout,
                                  bool dfs = false)
      : LayoutConstraint(/*mandatory=*/true, dfs),
        shape_layout_(shape_layout) {}

  const ShapeLayout& shape_layout() const { return shape_layout_; }
  string ToString() const override;

 private:
  const ShapeLayout shape_layout_;
};

// Class encapsulating the layout constraints of the values in an HLO
// computation.
class LayoutConstraints {
 public:
  LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis,
                    HloComputation* computation);
  ~LayoutConstraints() = default;

  const HloComputation* computation() const { return computation_; }
  HloComputation* computation() { return computation_; }
  const TuplePointsToAnalysis& points_to_analysis() const {
    return points_to_analysis_;
  }

  // Return a vector containing the constraints which have been added to the
  // LayoutConstraints object since the construction of the object or since the
  // last time ConsumeAddedConstraints() was called. This is used to identify
  // newly added constraints when propagating layouts.
  std::vector<const LayoutConstraint*> ConsumeAddedConstraints() {
    std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_));
    added_constraints_.clear();
    return ret_vec;
  }
  void ClearAddedConstraints() { added_constraints_.clear(); }

  // Returns the layout of a LogicalBuffer, the layout of the operand of the
  // instruction, or the layout of the result of the computation, respectively,
  // if it has been constrained. Otherwise returns nullptr.
  const Layout* BufferLayout(const LogicalBuffer& buffer) const;
  const BufferLayoutConstraint* GetBufferLayoutConstraint(
      const LogicalBuffer& buffer) const;
  const ShapeLayout* OperandLayout(const HloInstruction* instruction,
                                   int64 operand_no) const;
  const OperandLayoutConstraint* GetOperandLayoutConstraint(
      const HloInstruction* instruction, int64 operand_no) const;
  const ShapeLayout* ResultLayout() const;

  // Add a constraint on the layout of a LogicalBuffer, the layout of the
  // operand of the instruction, or the layout of the result of the computation,
  // respectively.
  Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory = true, bool dfs = true);
  Status SetOperandLayout(const Shape& shape_with_layout,
                          const HloInstruction* instruction, int64 operand_no,
                          bool mandatory = true, bool dfs = true);
  Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true);

  // Convenience wrapper around SetOperandLayout for setting the layout of an
  // operand using a Layout object. The operand must be array-shaped.
  Status SetArrayOperandLayout(const Layout& layout,
                               const HloInstruction* instruction,
                               int64 operand_no, bool mandatory = true,
                               bool dfs = true);

  // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers
  // created by the instruction to the layouts in the given shape. The
  // instruction must define every logical buffer in its output.
  Status SetInstructionLayout(const Shape& shape_with_layout,
                              const HloInstruction* instruction,
                              bool mandatory = true, bool dfs = true);

  // Returns true if any buffer in the given operand is forwarded to the output
  // of the given instruction. For example, the Tuple instruction forwards the
  // buffers of its operands and would return true for each of its operands.
  bool OperandBufferForwarded(const HloInstruction* instruction,
                              int64 operand_no) const;

  // Returns the set of logical buffers (by LogicalBuffer::Id) which do not
  // yet have a layout constraint.
  const std::set<LogicalBuffer::Id>& unconstrained_buffer_ids() const {
    return unconstrained_buffer_ids_;
  }

  string ToString() const;

 private:
  // Find a buffer set in the buffer-set cache. This is useful since we would
  // otherwise create the flattened buffer set for the same instruction many
  // times, which is often slow.
  PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const;

  // The set of BufferLayoutConstraints applied to the computation.
  std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint>
      buffer_constraints_;

  // The set of OperandLayoutConstraints applied to the computation.
  using OperandConstraintKey = std::pair<const HloInstruction*, int64>;
  std::map<OperandConstraintKey, OperandLayoutConstraint> operand_constraints_;

  // The result constraint for the computation (can be null).
  std::unique_ptr<ResultLayoutConstraint> result_constraint_;

  // A vector which holds constraints as they are added. Can be cleared with
  // ClearAddedConstraints.
  std::vector<const LayoutConstraint*> added_constraints_;

  // Points-to analysis for the module. Used to propagate constraints through
  // the HLO graph.
  const TuplePointsToAnalysis& points_to_analysis_;

  // Array-shaped buffers which have not yet been constrained.
  std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;

  mutable absl::flat_hash_map<const HloInstruction*,
                              std::unique_ptr<PointsToSet::BufferSet>>
      buffer_sets_cache_;

  HloComputation* computation_;
};
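
// Example: a minimal sketch of how a pass might seed constraints and then pick
// up the newly added ones for propagation. The `constraints` pointer and the
// `dot` instruction are assumed to come from the surrounding pass; the shape
// below is purely illustrative.
//
//   // Pin the output of `dot` to a row-major f32[128,64] layout.
//   TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
//       ShapeUtil::MakeShapeWithLayout(F32, {128, 64},
//                                      /*minor_to_major=*/{1, 0}),
//       dot));
//
//   // Constraints added above can then be drained for further processing.
//   for (const LayoutConstraint* constraint :
//        constraints->ConsumeAddedConstraints()) {
//     VLOG(2) << constraint->ToString();
//   }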

// Contains constraints on the layouts of channels; that is, of Send and Recv
// instructions.
class ChannelLayoutConstraints {
 public:
  // Construct an empty constraint set.
  ChannelLayoutConstraints() {}

  // Returns true if channel_id has a layout constraint.
  bool IsChannelConstrained(int64 channel_id) const {
    return constraints_.contains(channel_id);
  }

  // Given `shape`, apply the layout for `channel_id`. `channel_id` must already
  // be constrained.
  Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const {
    auto it = constraints_.find(channel_id);
    CHECK(it != constraints_.end()) << "Channel " << channel_id;
    *shape.mutable_layout() = it->second;
    return shape;
  }

  // Returns the layout constraint for `channel_id`, which must already be
  // constrained.
  const Layout& LayoutForChannel(int64 channel_id) const {
    auto it = constraints_.find(channel_id);
    CHECK(it != constraints_.end()) << "Channel " << channel_id;
    return it->second;
  }

  // Adds a new layout constraint for `channel_id`. Returns nullptr if no
  // constraint was previously set for `channel_id`, or if the existing
  // constraint matches `layout`; otherwise returns the conflicting layout
  // which has already been set for the channel.
  const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) {
    auto it = constraints_.emplace(std::make_pair(channel_id, layout));
    if (it.second) {
      return nullptr;
    }
    return LayoutUtil::Equal(layout, it.first->second) ? nullptr
                                                       : &it.first->second;
  }

 private:
  absl::flat_hash_map<int64, Layout> constraints_;
};
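
// Example: a minimal usage sketch. A first call to ConstrainChannel records the
// layout and returns nullptr; a later call with a conflicting layout returns
// the previously recorded layout so the caller can insert a copy. The channel
// id (42) and the shape below are purely illustrative.
//
//   ChannelLayoutConstraints channel_constraints;
//   Layout layout = LayoutUtil::MakeLayout({1, 0});
//   CHECK(channel_constraints.ConstrainChannel(42, layout) == nullptr);
//
//   // Apply the recorded layout to a shape used on the other end of the
//   // channel.
//   Shape shape = ShapeUtil::MakeShape(F32, {16, 8});
//   shape = channel_constraints.LayoutShapeForChannel(shape, 42);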

// HLO pass which assigns layouts to all instructions in the HLO module while
// satisfying all necessary invariants and minimizing cost.
class LayoutAssignment : public HloModulePass {
 public:
  // entry_computation_layout is modified to populate a layout for the result in
  // the case that no particular layout is requested.
  //
  // instruction_can_change_layout_func is a function object that determines
  // whether an instruction can change layouts. An instruction not being able to
  // change layout means that it requires operands with the same rank as the
  // output to have the same layout as the output.
  //
  // channel_constraints is both an input and an output. Any sends or recvs that
  // are present in channel_constraints will be laid out as constrained. Any
  // unconstrained sends or recvs will be laid out as locally optimal and their
  // layout will be added as a constraint to channel_constraints.
  //
  // If channel_constraints is nullptr, no kSend or kRecv instructions may be
  // present in any module passed to `Run`.
  explicit LayoutAssignment(
      ComputationLayout* entry_computation_layout,
      std::function<bool(const HloInstruction*)>
          instruction_can_change_layout_func = InstructionCanChangeLayout,
      ChannelLayoutConstraints* channel_constraints = nullptr);
  ~LayoutAssignment() override {}
  absl::string_view name() const override { return "layout-assignment"; }

  // Assign layouts to the given module. Returns whether the module was changed
  // (any layouts were changed).
  StatusOr<bool> Run(HloModule* module) override;
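
  // Example: a minimal sketch of running this pass directly, assuming `module`
  // is an HloModule* and the module's entry computation layout should be used
  // as the layout contract (TF_ASSIGN_OR_RETURN comes from the status macros):
  //
  //   LayoutAssignment layout_assignment(
  //       module->mutable_entry_computation_layout());
  //   TF_ASSIGN_OR_RETURN(bool changed, layout_assignment.Run(module));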

  // Determines whether an instruction can change layouts. An instruction not
  // being able to change layout means that it requires operands with the same
  // rank as the output to have the same layout as the output.
  static bool InstructionCanChangeLayout(const HloInstruction* instruction);

  // In case of an array shape returns true iff it is at most rank 1. In case of
  // a tuple shape returns true iff all leaf shapes are at most rank 1.
  static bool IsAtMostRank1(const Shape& shape);

 protected:
  // These methods, invoked by PropagateConstraints, propagate a layout
  // constraint to its neighbors (i.e. operands and users) in order to minimize
  // the cost of the instructions being constrained on. New constraints are
  // added to the given constraint set.
  //
  // Backends can override these methods with backend-specific propagation
  // rules.
  virtual Status PropagateBufferConstraint(
      const BufferLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);
  virtual Status PropagateOperandConstraint(
      const OperandLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);
  virtual Status PropagateResultConstraint(
      const ResultLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);

  virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) {
    return LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
  }
  // Called after layouts of an instruction have been finalized to allow
  // subclasses to check for platform specific assumptions.
  virtual Status Verify(const HloInstruction* instruction) {
    return Status::OK();
  }

  // Propagates a buffer layout constraint into the operands that use it.
  Status PropagateBufferConstraintToUses(
      const BufferLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);

  // Propagates a layout constraint on the use of the result of the given
  // instruction to the definitions of the LogicalBuffers which make up the
  // result.
  Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout,
                                      const HloInstruction* instruction,
                                      LayoutConstraints* constraints);

  // Propagates the memory space defined in the entry computation to the called
  // computations.
  Status PropagateMemorySpace(HloModule* module);

  // Chooses a layout of operand `operand_no` of `instruction` that minimizes
  // the cost of `instruction`. `output_layout` is the layout of `instruction`.
  // Returns null if it can't decide the best layout.
  // Precondition: `instruction` and the operand are array-shaped.
  virtual std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout(
      const Layout& output_layout, const HloInstruction* instruction,
      int64 operand_no);
  // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of
  // `user` that minimizes its cost on that operand.  Returns null if it can't
  // decide the best layout.
  // Precondition: `user` and the operand are array-shaped.
  virtual std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout(
      const Layout& operand_layout, const HloInstruction* user,
      int64 operand_no);

 private:
  // Initializes the layout assignment object for a new Run() call.
  Status Init();

  // Adds constraints which must be satisfied for correctness on all
  // backends. Called once prior to propagating constraints.
  Status AddMandatoryConstraints(const ComputationLayout* computation_layout,
                                 ChannelLayoutConstraints* channel_constraints,
                                 HloComputation* computation,
                                 LayoutConstraints* constraints);

  // This method can be overridden to add backend-specific constraints to the
  // layout of the instructions of a computation. This method is called after
  // all mandatory constraints have been added via AddMandatoryConstraints
  // and before propagating constraints.
  virtual Status AddBackendConstraints(LayoutConstraints* constraints) {
    return Status::OK();
  }
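
  // Example: a minimal sketch of a backend-specific override. The derived pass
  // class (`MyLayoutAssignment`) and the rule below (pin operand 0 of every dot
  // to the default layout) are hypothetical; real backends implement their own
  // rules.
  //
  //   Status MyLayoutAssignment::AddBackendConstraints(
  //       LayoutConstraints* constraints) {
  //     for (HloInstruction* instruction :
  //          constraints->computation()->instructions()) {
  //       if (instruction->opcode() == HloOpcode::kDot) {
  //         TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
  //             LayoutUtil::GetDefaultLayoutForShape(
  //                 instruction->operand(0)->shape()),
  //             instruction, /*operand_no=*/0));
  //       }
  //     }
  //     return Status::OK();
  //   }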

  // Construct constraints and assign layouts to all instructions in the
  // computation satisfying the given ComputationLayout, if not nullptr.
  // Otherwise the ComputationLayout will be calculated by propagating the
  // computation instruction constraints.
  // Layout constraints are added, then propagated until all LogicalBuffers in
  // the computation are constrained.
  Status RunOnComputation(ComputationLayout* computation_layout,
                          HloComputation* computation,
                          ChannelLayoutConstraints* channel_constraints);

  // Assign layouts to the instructions of a computation which satisfy the given
  // layout constraints. Copies may be added to satisfy the constraints. The
  // given LayoutConstraints must have layout constraints for every logical
  // buffer in the computation.
  Status AssignLayouts(const LayoutConstraints& constraints,
                       HloComputation* computation);

  // Propagates layout constraints from a set of initial constraints in order to
  // minimize the local cost of the computation. This propagation is *not*
  // required for correctness.
  Status PropagateConstraints(LayoutConstraints* constraints);

  Status PropagateBufferConstraintToOperands(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints);

  // Check that all layouts in the module have been set and satisfy all
  // necessary conditions.
  Status CheckLayouts(HloModule* module);

  // Computes the ComputationLayout of the given computation based on the
  // layouts assigned to parameters and root instruction, and inserts it into
  // the computation_layouts_ map.
  Status CalculateComputationLayout(HloComputation* computation);

  // Clears all the layouts which can be cleared within a computation.
  Status ClearComputationLayouts(HloComputation* computation);

  // Clears the side effects of a previous pass, like added copy instructions.
  Status ClearPreviousPassSideEffects(HloModule* module);

  // Propagates the layouts computed by the layout assignment pass on the given
  // computation to the computation layout passed in to this API.
  // This API propagates missing layouts, and also checks that the layouts the
  // caller specified have been respected, by comparing them with the layouts of
  // the computation's parameters and root instruction.
  Status PropagateComputationLayouts(HloComputation* computation,
                                     ComputationLayout* computation_layout);

  // The pointer to the ComputationLayout passed as constructor parameter.
  ComputationLayout* entry_computation_layout_;

  // A copy of entry_computation_layout_ used to reset it to the initial values
  // during the multiple passes done by the layout assignment operation.
  ComputationLayout saved_entry_computation_layout_;

 protected:
  // Sets up the copy instruction according to the characteristics (sharding,
  // metadata, ...) of the reference instruction. The index argument is used
  // when the instruction is a tuple, and in such case the index represents
  // the location from which the copy instruction was created.
  // If the index is empty, the whole sharding will be propagated, even in case
  // the instruction has a tuple sharding.
  static void SetupCopiedInstruction(const HloInstruction& instruction,
                                     HloInstruction* copy,
                                     const ShapeIndex& index);

  // Creates and returns a copy of the given instruction with a different
  // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
  // instruction producing the copy is returned.
  StatusOr<HloInstruction*> CreateCopyWithNewLayout(
      const Shape& shape_with_layout, HloInstruction* instruction);

  // Creates a copy of the given operand if the operand's layout does not match
  // the given layout. This copy replaces the use in the given instruction.
  // Tuple operands will be deep-copied.
  virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
                                            HloInstruction* instruction,
                                            int64 operand_no);

  // Registers a copy instruction added by the layout assignment pass.
  void RegisterAddedCopy(HloInstruction* copy) {
    CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
    added_copies_.insert(copy);
  }

  // Adds a copy for the operand of an instruction, unless such operand is
  // already a copy and has a single user (which is necessarily the instruction
  // itself).
  Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);

  // Apply the channel layout constraints by populating the channel_constraints
  // data structure passed in at constructor time. Adds copies in case the two
  // ends of a channel ended up with different layouts.
  Status ConstrainChannelLayouts(HloComputation* computation,
                                 ChannelLayoutConstraints* channel_constraints);

  // Resets the input ChannelLayoutConstraints to the original copy received
  // from the constructor input.
  void ResetChannelConstraints() {
    if (channel_layout_constraints_ != nullptr) {
      *channel_layout_constraints_ = channel_constraints_;
    }
  }

  // Adds constraints related to host Send/Recv instructions.
  Status BuildHostChannelConstraints(HloComputation* computation);

  // Map containing the layouts of all computations assigned so far.
  // Computations are handled in topological order, with a computation handled
  // before its caller instructions so the layouts of caller instructions can
  // be set to match the computation.
  std::map<HloComputation*, ComputationLayout> computation_layouts_;

  // Map from branch computations to the result layout they should apply.
  std::map<HloComputation*, ComputationLayout> conditional_mismatch_;

  // Every copy added to the module by the layout assignment pass is registered
  // here.
  absl::flat_hash_set<HloInstruction*> added_copies_;

  // The pointer to the channel layout constraints passed in with the
  // constructor. If not nullptr, this is an input/output argument.
  ChannelLayoutConstraints* channel_layout_constraints_ = nullptr;

  // A copy of the input layout constraints used to reset the above pointer in
  // case we have to undo operations due to the multiple passes over the
  // computations/instructions.
  ChannelLayoutConstraints channel_constraints_;

  // Layout constraints for send/recv instructions which communicate with the
  // host.
  ChannelLayoutConstraints host_channel_constraints_;

  // Module points-to analysis, which can be updated for cloned computations.
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // The set of HLO instructions which lacked any layout constraint, thus
  // receiving propagated default layouts.
  absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;

  std::function<bool(const HloInstruction*)>
      instruction_can_change_layout_func_;

  // CallGraph of the module, used to track callsites of each computation.
  std::unique_ptr<CallGraph> call_graph_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_