1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28 
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_format.h"
32 #include "absl/strings/str_join.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/layout_util.h"
35 #include "tensorflow/compiler/xla/map_util.h"
36 #include "tensorflow/compiler/xla/service/computation_layout.h"
37 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
38 #include "tensorflow/compiler/xla/service/hlo_computation.h"
39 #include "tensorflow/compiler/xla/service/hlo_dce.h"
40 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
41 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/logical_buffer.h"
44 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
45 #include "tensorflow/compiler/xla/shape_layout.h"
46 #include "tensorflow/compiler/xla/shape_util.h"
47 #include "tensorflow/compiler/xla/status_macros.h"
48 #include "tensorflow/compiler/xla/statusor.h"
49 #include "tensorflow/compiler/xla/types.h"
50 #include "tensorflow/compiler/xla/util.h"
51 #include "tensorflow/compiler/xla/xla_data.pb.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/core/status.h"
54 #include "tensorflow/core/platform/logging.h"
55 #include "tensorflow/core/platform/protobuf.h"
56 
57 namespace xla {
58 
operator <<(std::ostream & out,const LayoutConstraint & constraint)59 std::ostream& operator<<(std::ostream& out,
60                          const LayoutConstraint& constraint) {
61   out << constraint.ToString();
62   return out;
63 }
64 
BufferLayoutConstraint(const Layout & layout,const LogicalBuffer & buffer,bool mandatory,bool dfs)65 BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
66                                                const LogicalBuffer& buffer,
67                                                bool mandatory, bool dfs)
68     : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
69   CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
70 }
71 
ToString() const72 string BufferLayoutConstraint::ToString() const {
73   return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
74                          LayoutUtil::HumanString(layout_));
75 }
76 
OperandLayoutConstraint(const ShapeLayout & shape_layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)77 OperandLayoutConstraint::OperandLayoutConstraint(
78     const ShapeLayout& shape_layout, const HloInstruction* instruction,
79     int64 operand_no, bool mandatory, bool dfs)
80     : LayoutConstraint(mandatory, dfs),
81       shape_layout_(shape_layout),
82       instruction_(instruction),
83       operand_no_(operand_no) {
84   CHECK(shape_layout_.LayoutIsSet());
85   CHECK(ShapeUtil::Compatible(shape_layout.shape(),
86                               instruction->operand(operand_no)->shape()))
87       << shape_layout.shape() << " is not compatible with "
88       << instruction->operand(operand_no)->shape() << " (for operand "
89       << operand_no << " of instruction " << instruction->ToString() << ")";
90 }
91 
ToString() const92 string OperandLayoutConstraint::ToString() const {
93   return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
94                          instruction_->name(), operand_no_,
95                          shape_layout_.ToString());
96 }
97 
ToString() const98 string ResultLayoutConstraint::ToString() const {
99   return absl::StrFormat("ResultLayoutConstraint: %s",
100                          shape_layout_.ToString());
101 }
102 
LayoutConstraints(const TuplePointsToAnalysis & points_to_analysis,HloComputation * computation)103 LayoutConstraints::LayoutConstraints(
104     const TuplePointsToAnalysis& points_to_analysis,
105     HloComputation* computation)
106     : points_to_analysis_(points_to_analysis), computation_(computation) {
107   // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
108   for (HloInstruction* inst : computation_->instructions()) {
109     points_to_analysis_.GetPointsToSet(inst).ForEachElement(
110         [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
111           for (const LogicalBuffer* buffer : buffers) {
112             // The points to analysis is computed per module, restrict
113             // constraints to array buffers in this computation.
114             if (buffer->IsArray() &&
115                 buffer->instruction()->parent() == computation) {
116               unconstrained_buffer_ids_.insert(buffer->id());
117             }
118           }
119         });
120   }
121 }
122 
GetBufferSet(const HloInstruction * instruction) const123 PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
124     const HloInstruction* instruction) const {
125   auto it = buffer_sets_cache_.find(instruction);
126   if (it != buffer_sets_cache_.end()) {
127     return it->second.get();
128   }
129   auto& buffer_set =
130       buffer_sets_cache_
131           .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
132           .first->second;
133   const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
134   points_to_set.ForEachElement(
135       [&buffer_set](const ShapeIndex& /*index*/,
136                     const PointsToSet::BufferList& buffers) {
137         buffer_set->insert(buffers.begin(), buffers.end());
138       });
139   return buffer_set.get();
140 }
141 
OperandBufferForwarded(const HloInstruction * instruction,int64 operand_no) const142 bool LayoutConstraints::OperandBufferForwarded(
143     const HloInstruction* instruction, int64 operand_no) const {
144   // The operand is potentially forwarded if the intersection of points-to sets
145   // of the operand and the instruction is non-empty.
146   PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
147   PointsToSet::BufferSet* operand_buffers =
148       GetBufferSet(instruction->operand(operand_no));
149   return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
150     return operand_buffers->count(b) > 0;
151   });
152 }
153 
SetBufferLayout(const Layout & layout,const LogicalBuffer & buffer,bool mandatory,bool dfs)154 Status LayoutConstraints::SetBufferLayout(const Layout& layout,
155                                           const LogicalBuffer& buffer,
156                                           bool mandatory, bool dfs) {
157   VLOG(3) << "SetBufferLayout : " << buffer << " : "
158           << LayoutUtil::HumanString(layout);
159 
160   TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
161   if (!buffer.IsArray()) {
162     return FailedPrecondition(
163         "Layout of buffer %s cannot be constrained because buffer is not "
164         "array-shaped, has shape: %s",
165         buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
166   }
167   TF_RETURN_IF_ERROR(
168       LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
169 
170   auto iter = buffer_constraints_.find(&buffer);
171   if (iter != buffer_constraints_.end()) {
172     const BufferLayoutConstraint& curr_constraint = iter->second;
173     if (LayoutUtil::Equal(curr_constraint.layout(), layout)) {
174       // New constraint matches existing constraint. Nothing to do.
175       return Status::OK();
176     }
177     if (curr_constraint.mandatory()) {
178       return FailedPrecondition(
179           "Buffer %s already has the layout constraint %s, cannot add "
180           "incompatible constraint %s",
181           buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
182           LayoutUtil::HumanString(layout));
183     }
184     iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
185   } else {
186     TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
187         << buffer.ToString();
188     iter = buffer_constraints_
189                .insert(std::make_pair(
190                    &buffer,
191                    BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
192                .first;
193   }
194   added_constraints_.push_back(&iter->second);
195   return Status::OK();
196 }
197 
SetOperandLayout(const Shape & shape_with_layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)198 Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
199                                            const HloInstruction* instruction,
200                                            int64 operand_no, bool mandatory,
201                                            bool dfs) {
202   VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
203           << operand_no << " : "
204           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
205 
206   const OperandLayoutConstraint* curr_shape_layout =
207       GetOperandLayoutConstraint(instruction, operand_no);
208   if (curr_shape_layout != nullptr) {
209     if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
210             shape_with_layout)) {
211       // New constraint matches existing constraint. Nothing to do.
212       return Status::OK();
213     }
214     if (curr_shape_layout->mandatory()) {
215       return FailedPrecondition(
216           "Operand %d of instruction %s already has a layout constraint "
217           "%s, cannot add incompatible constraint %s",
218           operand_no, instruction->name(),
219           curr_shape_layout->shape_layout().ToString(),
220           ShapeUtil::HumanStringWithLayout(shape_with_layout));
221     }
222   }
223 
224   // If any buffers in the operand occur in the output of the instruction, then
225   // return an error. This case is not handled because such a constraint changes
226   // layouts beyond this immediate use and is complicated to handle.
227   if (OperandBufferForwarded(instruction, operand_no)) {
228     return FailedPrecondition(
229         "Cannot constraint layout of operand %d of instruction %s "
230         "because instruction forwards operand's LogicalBuffer(s)",
231         operand_no, instruction->name());
232   }
233 
234   auto key = std::make_pair(instruction, operand_no);
235   auto iter = operand_constraints_.find(key);
236   if (iter == operand_constraints_.end()) {
237     auto pair = std::make_pair(
238         key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
239                                      instruction, operand_no, mandatory, dfs));
240     iter = operand_constraints_.insert(pair).first;
241   } else {
242     iter->second =
243         OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
244                                 operand_no, mandatory, dfs);
245   }
246   added_constraints_.push_back(&iter->second);
247 
248   return Status::OK();
249 }
250 
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)251 Status LayoutConstraints::SetArrayOperandLayout(
252     const Layout& layout, const HloInstruction* instruction, int64 operand_no,
253     bool mandatory, bool dfs) {
254   const HloInstruction* operand = instruction->operand(operand_no);
255   TF_RET_CHECK(operand->shape().IsArray());
256   Shape shape(operand->shape());
257   *shape.mutable_layout() = layout;
258   TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
259   return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
260 }
261 
SetResultLayout(const Shape & shape_with_layout,bool dfs)262 Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
263                                           bool dfs) {
264   VLOG(3) << "SetResultLayout : "
265           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
266 
267   const ShapeLayout* curr_shape_layout = ResultLayout();
268   if (curr_shape_layout != nullptr) {
269     if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
270       return FailedPrecondition(
271           "Result of computation %s already has the layout constraint %s, "
272           "cannot add incompatible constraint %s",
273           computation_->name(), curr_shape_layout->ToString(),
274           ShapeUtil::HumanStringWithLayout(shape_with_layout));
275     }
276     // New constraint matches existing constraint. Nothing to do.
277     return Status::OK();
278   }
279 
280   result_constraint_.reset(
281       new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
282   added_constraints_.push_back(result_constraint_.get());
283 
284   return Status::OK();
285 }
286 
SetInstructionLayout(const Shape & shape_with_layout,const HloInstruction * instruction,bool mandatory,bool dfs)287 Status LayoutConstraints::SetInstructionLayout(
288     const Shape& shape_with_layout, const HloInstruction* instruction,
289     bool mandatory, bool dfs) {
290   VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
291           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
292 
293   if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
294     return FailedPrecondition(
295         "Instruction %s of shape %s cannot be assigned incompatible layout %s",
296         instruction->name(), ShapeUtil::HumanString(instruction->shape()),
297         ShapeUtil::HumanStringWithLayout(shape_with_layout));
298   }
299 
300   // Create a BufferLayoutConstraint for each array shape in the output of the
301   // instruction.
302   return ShapeUtil::ForEachSubshapeWithStatus(
303       shape_with_layout,
304       [this, instruction, mandatory](const Shape& subshape,
305                                      const ShapeIndex& index) -> Status {
306         // The precondition for this method is that the instruction defines all
307         // buffers in its output.
308         auto buffers =
309             points_to_analysis_.GetPointsToSet(instruction).element(index);
310         CHECK_EQ(1, buffers.size());
311         CHECK_EQ(buffers[0]->instruction(), instruction);
312 
313         if (subshape.IsArray()) {
314           return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
315         } else {
316           return Status::OK();
317         }
318       });
319 }
320 
BufferLayout(const LogicalBuffer & buffer) const321 const Layout* LayoutConstraints::BufferLayout(
322     const LogicalBuffer& buffer) const {
323   if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
324     return &constraint->layout();
325   }
326   return nullptr;
327 }
328 
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const329 const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
330     const LogicalBuffer& buffer) const {
331   auto it = buffer_constraints_.find(&buffer);
332   return it == buffer_constraints_.end() ? nullptr : &it->second;
333 }
334 
OperandLayout(const HloInstruction * instruction,int64 operand_no) const335 const ShapeLayout* LayoutConstraints::OperandLayout(
336     const HloInstruction* instruction, int64 operand_no) const {
337   if (const auto* constraint =
338           GetOperandLayoutConstraint(instruction, operand_no)) {
339     return &constraint->shape_layout();
340   }
341   return nullptr;
342 }
343 
GetOperandLayoutConstraint(const HloInstruction * instruction,int64 operand_no) const344 const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
345     const HloInstruction* instruction, int64 operand_no) const {
346   auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
347   return it == operand_constraints_.end() ? nullptr : &it->second;
348 }
349 
ResultLayout() const350 const ShapeLayout* LayoutConstraints::ResultLayout() const {
351   return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
352 }
353 
ToString() const354 string LayoutConstraints::ToString() const {
355   string output;
356   absl::StrAppend(&output, "LayoutConstraints for computation ",
357                   computation_->name(), ":\n");
358   for (auto* instruction : computation_->MakeInstructionPostOrder()) {
359     absl::StrAppend(&output, "  ", instruction->ToShortString(), "\n");
360     for (int64 i = 0; i < instruction->operand_count(); ++i) {
361       if (OperandLayout(instruction, i) != nullptr) {
362         absl::StrAppend(&output, "    operand (", i,
363                         "): ", OperandLayout(instruction, i)->ToString(), "\n");
364       }
365     }
366     for (const LogicalBuffer* buffer :
367          points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
368       if (BufferLayout(*buffer) != nullptr) {
369         absl::StrAppend(&output, "    ", buffer->ToString(), " : ",
370                         LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
371       }
372     }
373   }
374 
375   if (ResultLayout() != nullptr) {
376     absl::StrAppend(&output, "  => ", ResultLayout()->ToString(), "\n");
377   }
378   return output;
379 }
380 
381 namespace {
382 
IsHostSendRecv(const HloInstruction * instruction)383 bool IsHostSendRecv(const HloInstruction* instruction) {
384   const HloSendRecvInstruction* send_recv_instr =
385       DynCast<HloSendRecvInstruction>(instruction);
386   return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
387 }
388 
389 }  // namespace
390 
BuildHostChannelConstraints(HloComputation * computation)391 Status LayoutAssignment::BuildHostChannelConstraints(
392     HloComputation* computation) {
393   for (auto* instruction : computation->instructions()) {
394     const HloSendRecvInstruction* send_recv_instr =
395         DynCast<HloSendRecvInstruction>(instruction);
396     if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
397       continue;
398     }
399 
400     // For host transfers the Send and Recv instruction carry the layout.
401     if (instruction->opcode() == HloOpcode::kSend ||
402         instruction->opcode() == HloOpcode::kRecv) {
403       const Shape& data_shape =
404           ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
405       TF_RET_CHECK(data_shape.IsArray());
406       TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
407       const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
408           send_recv_instr->channel_id(), data_shape.layout());
409       TF_RET_CHECK(prev_layout == nullptr)
410           << "Cannot constrain host transfer layout as it was set to "
411           << LayoutUtil::HumanString(*prev_layout) << ": "
412           << send_recv_instr->ToString();
413     }
414   }
415   return Status::OK();
416 }
417 
418 namespace {
419 
IsLayoutConstrainedCustomCall(HloInstruction * instruction)420 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
421   const HloCustomCallInstruction* custom_call =
422       DynCast<HloCustomCallInstruction>(instruction);
423   return custom_call != nullptr && custom_call->layout_constrained();
424 }
425 
426 }  // namespace
427 
AddMandatoryConstraints(const ComputationLayout * computation_layout,ChannelLayoutConstraints * channel_constraints,HloComputation * computation,LayoutConstraints * constraints)428 Status LayoutAssignment::AddMandatoryConstraints(
429     const ComputationLayout* computation_layout,
430     ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
431     LayoutConstraints* constraints) {
432   VLOG(3) << "Adding mandatory layout constraints to computation "
433           << computation->name();
434 
435   auto get_channel_constraints = [&](const HloInstruction* instruction) {
436     return IsHostSendRecv(instruction) ? &host_channel_constraints_
437                                        : channel_constraints;
438   };
439 
440   // Constrain layouts of instructions which define values with pre-existing
441   // layouts.
442   for (auto* instruction : computation->instructions()) {
443     if (instruction->opcode() == HloOpcode::kInfeed) {
444       // Infeed layouts must match the layout of the original inserted
445       // instruction.
446       // TODO(b/31425034): Change infeeds to be more like parameters, with
447       // shapes in the ComputationLayout.
448       TF_RETURN_IF_ERROR(
449           constraints->SetInstructionLayout(instruction->shape(), instruction));
450     } else if (instruction->opcode() == HloOpcode::kOutfeed) {
451       // Constrain the input to the Outfeed instruction to be the expected
452       // layout of the Outfeed.
453       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
454           instruction->outfeed_shape(), instruction, 0));
455     } else if (instruction->opcode() == HloOpcode::kParameter) {
456       if (computation_layout != nullptr) {
457         const ShapeLayout& parameter_layout =
458             computation_layout->parameter_layout(
459                 instruction->parameter_number());
460         if (parameter_layout.LayoutIsSet()) {
461           // Parameter layouts must match the respective layout in
462           // ComputationLayout, if there is one.
463           TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
464               parameter_layout.shape(), instruction));
465         }
466       }
467     } else if (IsLayoutConstrainedCustomCall(instruction)) {
468       const HloCustomCallInstruction* custom_call =
469           DynCast<HloCustomCallInstruction>(instruction);
470       TF_RETURN_IF_ERROR(
471           constraints->SetInstructionLayout(custom_call->shape(), custom_call));
472       for (int64 i = 0; i < custom_call->operand_count(); ++i) {
473         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
474             custom_call->operand_shapes_with_layout()[i], custom_call, i));
475       }
476     } else if (instruction->opcode() == HloOpcode::kSend ||
477                instruction->opcode() == HloOpcode::kRecv) {
478       CHECK(get_channel_constraints(instruction))
479           << "Multi-module layout assignment requires ChannelLayoutConstraints";
480       int64 channel_id = instruction->channel_id();
481       if (!get_channel_constraints(instruction)
482                ->IsChannelConstrained(channel_id)) {
483         continue;
484       }
485       if (instruction->opcode() == HloOpcode::kSend) {
486         // TODO(b/68493863): Change to use SetOperandLayout().
487         const Shape send_buffer_shape = instruction->operand(0)->shape();
488         TF_RET_CHECK(send_buffer_shape.IsArray());
489         Shape new_buffer_shape =
490             get_channel_constraints(instruction)
491                 ->LayoutShapeForChannel(send_buffer_shape,
492                                         instruction->channel_id());
493         TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
494             new_buffer_shape, instruction->operand(0)));
495       } else {
496         const Shape recv_buffer_shape =
497             ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
498         TF_RET_CHECK(recv_buffer_shape.IsArray());
499         TF_ASSIGN_OR_RETURN(
500             const LogicalBuffer* buffer,
501             constraints->points_to_analysis().GetBufferDefinedAt(instruction,
502                                                                  {0}));
503         Shape new_shape = get_channel_constraints(instruction)
504                               ->LayoutShapeForChannel(
505                                   recv_buffer_shape, instruction->channel_id());
506         TF_RETURN_IF_ERROR(
507             constraints->SetBufferLayout(new_shape.layout(), *buffer));
508       }
509     } else if (instruction->IsCrossModuleAllReduce()) {
510       CHECK(get_channel_constraints(instruction))
511           << "Multi-module layout assignment requires ChannelLayoutConstraints";
512       int64 all_reduce_id = instruction->all_reduce_id().value();
513       if (!get_channel_constraints(instruction)
514                ->IsChannelConstrained(all_reduce_id)) {
515         continue;
516       }
517       // TODO(b/68493863): Change to use SetOperandLayout().
518       const Shape& buffer_shape = instruction->operand(0)->shape();
519       TF_RET_CHECK(buffer_shape.IsArray());
520       Shape new_buffer_shape =
521           get_channel_constraints(instruction)
522               ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
523       TF_RETURN_IF_ERROR(
524           constraints->SetInstructionLayout(new_buffer_shape, instruction));
525     }
526   }
527 
528   // Constrain layouts of instructions which call computations which have
529   // already been assigned layouts. Instructions which call computations in a
530   // parallel element-wise context (eg, map or reduce) do not need layout
531   // constraints because they operate on scalars.
532   for (auto* instruction : computation->instructions()) {
533     if (instruction->opcode() == HloOpcode::kCall) {
534       // kCall instruction operands and output must match the ComputationLayout
535       // of the called computation.
536       const ComputationLayout& called_computation_layout =
537           FindOrDie(computation_layouts_, instruction->to_apply());
538       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
539           called_computation_layout.result_layout().shape(), instruction));
540       TF_RET_CHECK(instruction->operand_count() ==
541                    called_computation_layout.parameter_count());
542       for (int64 i = 0; i < instruction->operand_count(); ++i) {
543         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
544             called_computation_layout.parameter_layout(i).shape(), instruction,
545             i));
546       }
547     } else if (instruction->opcode() == HloOpcode::kWhile) {
548       // Layout of input and output of kWhile instruction must be equal and must
549       // match both input and output of body computation. Also, the input of
550       // condition computation must match kWhile layout.
551       HloComputation* body = instruction->while_body();
552       HloComputation* condition = instruction->while_condition();
553       const HloInstruction* init = instruction->operand(0);
554       ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
555       ComputationLayout& condition_layout =
556           FindOrDie(computation_layouts_, condition);
557 
558       // Check a few invariants irrespective of layout.
559       CHECK_EQ(1, instruction->operand_count());
560       CHECK_EQ(1, body->num_parameters());
561       CHECK_EQ(1, condition->num_parameters());
562       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
563                                    body_layout.parameter_shape(0)));
564       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
565                                    condition_layout.parameter_shape(0)));
566       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
567 
568       if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
569         VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
570                 << " while=" << instruction->name()
571                 << " shape=" << body_layout.result_layout().ToString();
572         *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
573       }
574       if (condition_layout.parameter_layout(0) !=
575           body_layout.parameter_layout(0)) {
576         VLOG(2) << "Reset %while condition parameter layout: cond="
577                 << condition->name() << " while=" << instruction->name()
578                 << " shape=" << body_layout.parameter_layout(0).ToString();
579         *condition_layout.mutable_parameter_layout(0) =
580             body_layout.parameter_layout(0);
581       }
582 
583       // Constrain the output and the operand of the while instruction to match
584       // the computations.
585       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
586           body_layout.result_shape(), instruction));
587       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
588           body_layout.result_shape(), instruction, 0));
589     } else if (instruction->opcode() == HloOpcode::kConditional) {
590       // Find the conditional branch with the most instructions and force all
591       // other computations to match that layout. A potentially better decison
592       // could count the number FLOPs or how constrained the layouts are.
593       int64 largest_branch = 0;
594       int64 largest_instruction_count =
595           instruction->branch_computation(0)->instruction_count();
596       for (int j = 1; j < instruction->branch_count(); ++j) {
597         const int64 instruction_count =
598             instruction->branch_computation(j)->instruction_count();
599         if (instruction_count > largest_instruction_count) {
600           largest_branch = j;
601           largest_instruction_count = instruction_count;
602         }
603       }
604       ComputationLayout& best_branch_computation_layout =
605           FindOrDie(computation_layouts_,
606                     instruction->branch_computation(largest_branch));
607       for (int k = 0; k < instruction->branch_count(); ++k) {
608         // Visit the best branch first.
609         int j = (k + largest_branch) % instruction->branch_count();
610         TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
611         ComputationLayout& branch_computation_layout =
612             FindOrDie(computation_layouts_, instruction->branch_computation(j));
613 
614         DCHECK(ShapeUtil::Compatible(
615             instruction->operand(j + 1)->shape(),
616             branch_computation_layout.parameter_shape(0)));
617         if (best_branch_computation_layout.result_layout() !=
618             branch_computation_layout.result_layout()) {
619           // We assign layouts in DFS fashion, so the largest_branch and current
620           // branch computations might have negotiated a different layout. But
621           // for the case instruction POV the layout must match, so we run again
622           // on the branch j computation, this time with proper computation
623           // layout.
624           VLOG(2) << "Reset %conditional branch " << j
625                   << " computation result layout: branch_computation="
626                   << instruction->branch_computation(j)->name()
627                   << " case=" << instruction->name() << " shape="
628                   << best_branch_computation_layout.result_layout().ToString();
629           *branch_computation_layout.mutable_result_layout() =
630               best_branch_computation_layout.result_layout();
631         }
632         if (k == 0) {
633           TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
634               best_branch_computation_layout.result_shape(), instruction));
635         }
636         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
637             branch_computation_layout.parameter_shape(0), instruction, j + 1,
638             /*mandatory=*/true));
639       }
640     }
641   }
642   // Finally set the result layout to match ComputationLayout, if there is one.
643   if (computation_layout != nullptr) {
644     const ShapeLayout& result_layout = computation_layout->result_layout();
645     if (result_layout.LayoutIsSet()) {
646       TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
647     }
648   }
649   return Status::OK();
650 }
651 
652 namespace {
653 
654 // The operands of a call must match the layouts of parameters in the
655 // ComputationLayout, and the call instruction itself must match the result
656 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)657 Status CheckCallLayout(HloInstruction* call,
658                        const ComputationLayout& computation_layout) {
659   HloComputation* computation = call->to_apply();
660   TF_RET_CHECK(computation->num_parameters() == call->operand_count());
661   for (int64 i = 0; i < computation->num_parameters(); ++i) {
662     TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
663         call->operand(i)->shape()));
664   }
665   TF_RET_CHECK(
666       computation_layout.result_layout().MatchesLayoutInShape(call->shape()));
667   return Status::OK();
668 }
669 
670 // Operands of layout-constrained custom calls must match the expected
671 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)672 Status CheckCustomCallLayout(HloInstruction* instruction) {
673   if (IsLayoutConstrainedCustomCall(instruction)) {
674     const HloCustomCallInstruction* custom_call =
675         DynCast<HloCustomCallInstruction>(instruction);
676     for (int64 i = 0; i < custom_call->operand_count(); ++i) {
677       TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
678           custom_call->operand(i)->shape(),
679           custom_call->operand_shapes_with_layout()[i]));
680     }
681   }
682   return Status::OK();
683 }
684 
685 // For a while instruction, all the following layouts must be the same:
686 //   (1) init operand
687 //   (2) condition computation parameter
688 //   (3) body computation parameter
689 //   (4) body computation result
690 //   (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)691 Status CheckWhileLayout(HloInstruction* while_inst,
692                         const ComputationLayout& condition_computation_layout,
693                         const ComputationLayout& body_computation_layout) {
694   auto init_shape = while_inst->operand(0)->shape();
695   TF_RET_CHECK(
696       condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
697           init_shape));
698   TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
699       init_shape));
700   TF_RET_CHECK(
701       body_computation_layout.result_layout().MatchesLayoutInShape(init_shape));
702   TF_RET_CHECK(
703       LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
704   return Status::OK();
705 }
706 
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)707 Status CheckConditionalLayout(
708     HloInstruction* instruction,
709     absl::Span<const ComputationLayout> branch_computation_layouts) {
710   for (int j = 0; j < instruction->branch_count(); ++j) {
711     const HloInstruction* branch_operand = instruction->operand(j + 1);
712     TF_RET_CHECK(branch_computation_layouts[0].result_layout() ==
713                  branch_computation_layouts[j].result_layout());
714     TF_RET_CHECK(
715         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
716             instruction->shape()));
717     TF_RET_CHECK(
718         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
719             instruction->branch_computation(j)->root_instruction()->shape()));
720     TF_RET_CHECK(
721         branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
722             branch_operand->shape()));
723   }
724   return Status::OK();
725 }
726 
727 // Fusion parameters must match the layout of the fusion instructions operands,
728 // and the root of the fusion expression must match the layout of the fusion
729 // instruction.
CheckFusionLayout(HloInstruction * fusion)730 Status CheckFusionLayout(HloInstruction* fusion) {
731   TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
732 
733   TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
734       fusion->shape(), fusion->fused_expression_root()->shape()));
735   for (int64 i = 0; i < fusion->operand_count(); ++i) {
736     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
737         fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape()));
738   }
739   return Status::OK();
740 }
741 
742 // The layout of a parameter must match the respective layout in the
743 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)744 Status CheckParameterLayout(HloInstruction* parameter,
745                             const ComputationLayout& computation_layout) {
746   const ShapeLayout& parameter_layout =
747       computation_layout.parameter_layout(parameter->parameter_number());
748   if (parameter_layout.LayoutIsSet() &&
749       !parameter_layout.MatchesLayoutInShape(parameter->shape())) {
750     return InternalError(
751         "parameter instruction %s does not match layout of computation "
752         "shape: %s",
753         parameter->ToString(), parameter_layout.ToString());
754   }
755   return Status::OK();
756 }
757 
758 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)759 Status CheckConstantLayout(HloInstruction* constant) {
760   if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(),
761                                         constant->shape())) {
762     return InternalError(
763         "constant instruction %s does not match the layout of its literal %s",
764         constant->ToString(),
765         ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
766   }
767   return Status::OK();
768 }
769 
770 }  // namespace
771 
CreateCopyWithNewLayout(const Shape & shape_with_layout,HloInstruction * instruction)772 StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
773     const Shape& shape_with_layout, HloInstruction* instruction) {
774   TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
775   DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
776       << ShapeUtil::HumanString(shape_with_layout) << " "
777       << ShapeUtil::HumanString(instruction->shape())
778       << " instruction: " << instruction->ToString();
779 
780   if (instruction->shape().IsTuple()) {
781     // Copy tuple elements which have differing layouts.
782     std::vector<HloInstruction*> element_copies;
783     for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
784          ++i) {
785       const Shape& target_shape =
786           ShapeUtil::GetSubshape(shape_with_layout, {i});
787       const Shape& instr_shape =
788           ShapeUtil::GetSubshape(instruction->shape(), {i});
789       HloInstruction* gte = instruction->parent()->AddInstruction(
790           HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));
791 
792       if (ShapeUtil::Equal(target_shape, instr_shape)) {
793         // Shapes and layouts are equal, no need to copy.
794         element_copies.push_back(gte);
795       } else {
796         SetupCopiedInstruction(*instruction, gte, {i});
797         // Recurse to copy each element.
798         TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
799                             CreateCopyWithNewLayout(target_shape, gte));
800         element_copies.push_back(element_copy);
801       }
802     }
803     // Gather element copies into a tuple with a new Tuple instruction.
804     HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
805         HloInstruction::CreateTuple(element_copies));
806     SetupCopiedInstruction(*instruction, tuple_copy, {});
807     LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
808     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
809         shape_with_layout, tuple_copy->mutable_shape()));
810     return tuple_copy;
811   } else if (instruction->shape().IsArray()) {
812     HloInstruction* copy =
813         instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
814             instruction->shape(), HloOpcode::kCopy, instruction));
815     RegisterAddedCopy(copy);
816     SetupCopiedInstruction(*instruction, copy, {});
817     LayoutUtil::ClearLayout(copy->mutable_shape());
818     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
819         shape_with_layout, copy->mutable_shape()));
820 
821     return copy;
822   } else {
823     return FailedPrecondition(
824         "Can only copy array and tuple shaped instructions");
825   }
826 }
827 
828 // Creates a copy of the given operand if the operand's layout does not match
829 // the given layout. This copy replaces the use in the given instruction. Tuple
830 // operands will be deep-copied.
CopyOperandIfLayoutsDiffer(const ShapeLayout & operand_layout,HloInstruction * instruction,int64 operand_no)831 Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
832     const ShapeLayout& operand_layout, HloInstruction* instruction,
833     int64 operand_no) {
834   HloInstruction* operand = instruction->mutable_operand(operand_no);
835   TF_RET_CHECK(operand_layout.LayoutIsSet());
836   TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
837 
838   if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
839     VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
840             << instruction->ToString();
841     // Operand layout already matches our constraint. Nothing to do.
842     return Status::OK();
843   }
844   VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
845           << operand_layout.ToString() << " in " << instruction->ToString();
846 
847   TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
848                       CreateCopyWithNewLayout(operand_layout.shape(), operand));
849 
850   VLOG(4) << "New copy of " << operand->ToString() << " is "
851           << operand_copy->ToString();
852   return instruction->ReplaceOperandWith(operand_no, operand_copy);
853 }
854 
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)855 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
856                                               HloInstruction* copy,
857                                               const ShapeIndex& index) {
858   if (instruction.has_sharding()) {
859     // If the index is empty, we want to copy the whole sharding, in case the
860     // sharding is a tuple sharding.
861     HloSharding sharding =
862         !index.empty() && instruction.sharding().IsTuple()
863             ? instruction.sharding().GetSubSharding(instruction.shape(), index)
864             : instruction.sharding();
865     // We propagate the sharding to the copied instruction only if it is a
866     // special sharding, like tiled ones.
867     // Otherwise it is preferable to leave the new instruction without device,
868     // and let the automatic device placer to choose the best location.
869     auto device = sharding.UniqueDevice();
870     if (!device || HloSharding::IsReservedDevice(*device)) {
871       copy->set_sharding(sharding);
872     }
873   }
874   copy->set_metadata(instruction.metadata());
875 }
876 
CheckLayouts(HloModule * module)877 Status LayoutAssignment::CheckLayouts(HloModule* module) {
878   TF_ASSIGN_OR_RETURN(auto points_to_analysis,
879                       TuplePointsToAnalysis::Run(module));
880   for (auto* computation : module->MakeNonfusionComputations()) {
881     for (auto* instruction : computation->instructions()) {
882       // Verify every instruction has a layout and the layout is valid for the
883       // shape.
884       TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
885       TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
886 
887       // Use points-to analysis to verify that every subshape element in the
888       // output of the instruction matches the layout of the logical buffer
889       // which could be the source of the subshape value.
890       const PointsToSet& points_to_set =
891           points_to_analysis->GetPointsToSet(instruction);
892       TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
893           [&instruction](ShapeIndex index,
894                          const PointsToSet::BufferList& buffers) -> Status {
895             if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
896               const Shape& instruction_subshape =
897                   ShapeUtil::GetSubshape(instruction->shape(), index);
898               for (const LogicalBuffer* buffer : buffers) {
899                 if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) {
900                   return InternalError(
901                       "Layout of instruction %s at index {%s} does not match "
902                       "source LogicalBuffer %s: %s vs %s",
903                       instruction->name(), absl::StrJoin(index, ","),
904                       buffer->ToString(),
905                       ShapeUtil::HumanStringWithLayout(instruction_subshape),
906                       ShapeUtil::HumanStringWithLayout(buffer->shape()));
907                 }
908               }
909             }
910             return Status::OK();
911           }));
912 
913       // Verify instructions that have special layout constraints.
914       switch (instruction->opcode()) {
915         case HloOpcode::kCall:
916           TF_RETURN_IF_ERROR(CheckCallLayout(
917               instruction,
918               FindOrDie(computation_layouts_, instruction->to_apply())));
919           break;
920         case HloOpcode::kCustomCall:
921           TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
922           break;
923         case HloOpcode::kFusion:
924           TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
925           break;
926         case HloOpcode::kParameter:
927           TF_RETURN_IF_ERROR(CheckParameterLayout(
928               instruction,
929               FindOrDie(computation_layouts_, instruction->parent())));
930           break;
931         case HloOpcode::kConstant:
932           TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
933           break;
934         case HloOpcode::kWhile:
935           TF_RETURN_IF_ERROR(CheckWhileLayout(
936               instruction,
937               FindOrDie(computation_layouts_, instruction->while_condition()),
938               FindOrDie(computation_layouts_, instruction->while_body())));
939           break;
940         case HloOpcode::kConditional: {
941           std::vector<ComputationLayout> branch_computation_layouts;
942           for (auto branch_computation : instruction->branch_computations()) {
943             branch_computation_layouts.emplace_back(
944                 FindOrDie(computation_layouts_, branch_computation));
945           }
946           TF_RETURN_IF_ERROR(CheckConditionalLayout(
947               instruction, absl::MakeSpan(branch_computation_layouts)));
948           break;
949         }
950         default:
951           break;
952       }
953     }
954   }
955   // Finally verify the result layout, if set, matches the layout of the entry
956   // computation root.
957   const ShapeLayout& result_layout =
958       FindOrDie(computation_layouts_, module->entry_computation())
959           .result_layout();
960   if (result_layout.LayoutIsSet()) {
961     TF_RET_CHECK(
962         ShapeUtil::Equal(module->result_shape(), result_layout.shape()));
963   }
964   return Status::OK();
965 }
966 
LayoutAssignment(ComputationLayout * entry_computation_layout,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func,ChannelLayoutConstraints * channel_constraints)967 LayoutAssignment::LayoutAssignment(
968     ComputationLayout* entry_computation_layout,
969     std::function<bool(const HloInstruction*)>
970         instruction_can_change_layout_func,
971     ChannelLayoutConstraints* channel_constraints)
972     : entry_computation_layout_(entry_computation_layout),
973 
974       saved_entry_computation_layout_(*entry_computation_layout),
975       channel_layout_constraints_(channel_constraints),
976       instruction_can_change_layout_func_(
977           std::move(instruction_can_change_layout_func)) {
978   if (channel_layout_constraints_ != nullptr) {
979     // Save a copy of the input ChannelLayoutConstraints so that we can reset it
980     // if we have to undo previous operations (ClearPreviousPassSideEffects()).
981     channel_constraints_ = *channel_layout_constraints_;
982   }
983   VLOG(1) << "Entry computation layout given to layout assignment: "
984           << entry_computation_layout_->ToString();
985 }
986 
ChooseOperandLayoutFromOutputLayout(const Layout & output_layout,const HloInstruction * instruction,int64 operand_no)987 std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
988     const Layout& output_layout, const HloInstruction* instruction,
989     int64 operand_no) {
990   const HloInstruction* operand = instruction->operand(operand_no);
991   CHECK(instruction->shape().IsArray());
992   CHECK(operand->shape().IsArray());
993   if (!ShapeUtil::IsScalar(operand->shape()) &&
994       operand->shape().rank() == instruction->shape().rank() &&
995       !instruction_can_change_layout_func_(instruction)) {
996     // Propagate the result layout to the operand layout if the instruction
997     // requires the same layout out for the result and the operand.
998     //
999     // For elementwise operations, using the same layout for the operands and
1000     // the result also has the following benefits:
1001     // 1) the elementwise operation can reuse its operand's buffer, and
1002     // 2) the input and output elements can reuse the same linear index.
1003     return absl::make_unique<Layout>(output_layout);
1004   }
1005 
1006   if (instruction->opcode() == HloOpcode::kReshape) {
1007     // Prefer the operand layout that makes the reshape an bitcast. If any
1008     // dimension bound is 1 in the operand shape, there may be several such
1009     // layouts. So if 'output_layout' is the default layout, try if the
1010     // reshape is a bitcast when using the same layout. This may avoid copy
1011     // operations. For similar reasons, if the operand and output have the same
1012     // rank, try to match the operand's layout to the output.
1013     if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
1014         ShapeUtil::TrueRank(instruction->shape()) == 1) {
1015       // Don't assign a layout in case of R1 -> effective R1 reshape.
1016       return nullptr;
1017     }
1018     const Shape& output_shape = instruction->shape();
1019     Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
1020         output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
1021         LayoutUtil::MinorToMajor(output_layout));
1022     Shape operand_shape = operand->shape();
1023     *operand_shape.mutable_layout() =
1024         LayoutUtil::GetDefaultLayoutForShape(operand_shape);
1025     auto aligned_operand_shape =
1026         ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
1027     if (aligned_operand_shape) {
1028       auto operand_layout = aligned_operand_shape.value().layout();
1029       TF_CHECK_OK(
1030           LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
1031       return absl::make_unique<Layout>(operand_layout);
1032     }
1033   }
1034 
1035   if (instruction->opcode() == HloOpcode::kTranspose) {
1036     // Pick the operand layout that makes the transpose a bitcast.
1037     int64 rank = instruction->shape().rank();
1038     std::vector<int64> new_minor_to_major(rank);
1039     for (int64 i = 0; i < rank; ++i) {
1040       int64 output_dim = LayoutUtil::Minor(output_layout, i);
1041       int64 operand_dim = instruction->dimensions(output_dim);
1042       new_minor_to_major[i] = operand_dim;
1043     }
1044     Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
1045     TF_CHECK_OK(
1046         LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
1047     return absl::make_unique<Layout>(operand_layout);
1048   }
1049 
1050   return nullptr;
1051 }
1052 
ChooseOutputLayoutFromOperandLayout(const Layout & operand_layout,const HloInstruction * user,int64 operand_no)1053 std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
1054     const Layout& operand_layout, const HloInstruction* user,
1055     int64 operand_no) {
1056   const HloInstruction* operand = user->operand(operand_no);
1057 
1058   CHECK(user->shape().IsArray() && operand->shape().IsArray());
1059 
1060   if (!ShapeUtil::IsScalar(operand->shape()) &&
1061       operand->shape().rank() == user->shape().rank() &&
1062       !instruction_can_change_layout_func_(user)) {
1063     // Assign users the same layout as the operand.
1064     return absl::make_unique<Layout>(operand_layout);
1065   }
1066 
1067   if (user->opcode() == HloOpcode::kReshape) {
1068     // Prefer the user layout that makes the reshape an bitcast. If any
1069     // dimension bound is 1 in the user shape, there may be several such
1070     // layouts. So if 'operand_layout' is the default layout, try if the
1071     // reshape is a bitcast when using the same layout. This may avoid copy
1072     // operations. For similar reasons, if the operand and output have the same
1073     // rank, try to match the outputs's layout to the operand.
1074     if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
1075         ShapeUtil::TrueRank(user->shape()) == 1) {
1076       // Don't assign a layout in case of R1 -> effective R1 reshape.
1077       return nullptr;
1078     }
1079     Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
1080         operand->shape().element_type(),
1081         AsInt64Slice(operand->shape().dimensions()),
1082         LayoutUtil::MinorToMajor(operand_layout));
1083     Shape output_shape = user->shape();
1084     *output_shape.mutable_layout() =
1085         LayoutUtil::GetDefaultLayoutForShape(output_shape);
1086     auto aligned_user_shape =
1087         ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
1088     if (aligned_user_shape) {
1089       auto user_layout = aligned_user_shape.value().layout();
1090       TF_CHECK_OK(
1091           LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
1092       return absl::make_unique<Layout>(user_layout);
1093     }
1094   }
1095 
1096   if (user->opcode() == HloOpcode::kTranspose) {
1097     // Pick the user layout that makes the transpose a bitcast.
1098     int64 rank = user->shape().rank();
1099     std::vector<int64> new_minor_to_major(rank);
1100     auto inverse_dimensions = InversePermutation(user->dimensions());
1101     for (int64 i = 0; i < rank; ++i) {
1102       int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
1103       int64 user_dim = inverse_dimensions[operand_dim];
1104       new_minor_to_major[i] = user_dim;
1105     }
1106     Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
1107     TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
1108     return absl::make_unique<Layout>(user_layout);
1109   }
1110 
1111   return nullptr;
1112 }
1113 
PropagateConstraints(LayoutConstraints * constraints)1114 Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
1115   // Gathers all initial constraints in a worklist and propagates them in
1116   // depth-first order. DFS order seems to be better than BFS because a
1117   // constraint is propagated as far as possible before propagating unrelated
1118   // constraints which makes it less likely that conflicting constraints will be
1119   // propagated to instructions. However, we should experiment with other orders
1120   // too.
1121   std::deque<const LayoutConstraint*> worklist;
1122 
1123   // Lambda for moving newly added constraints to the worklist.
1124   auto add_new_constraints_to_worklist = [constraints, &worklist]() {
1125     // Add constraints to the front of the deque for DFS ordering.
1126     for (auto* constraint : constraints->ConsumeAddedConstraints()) {
1127       if (constraint->dfs()) {
1128         worklist.push_front(constraint);
1129       } else {
1130         worklist.push_back(constraint);
1131       }
1132     }
1133   };
1134   add_new_constraints_to_worklist();
1135 
1136   while (!worklist.empty()) {
1137     const LayoutConstraint* layout_constraint = worklist.front();
1138     worklist.pop_front();
1139     VLOG(2) << "Propagating " << layout_constraint->ToString()
1140             << " to its neighbors.";
1141     if (auto* buffer_constraint =
1142             dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
1143       TF_RETURN_IF_ERROR(
1144           PropagateBufferConstraint(*buffer_constraint, constraints));
1145     } else if (auto* operand_constraint =
1146                    dynamic_cast<const OperandLayoutConstraint*>(
1147                        layout_constraint)) {
1148       TF_RETURN_IF_ERROR(
1149           PropagateOperandConstraint(*operand_constraint, constraints));
1150     } else if (auto* result_constraint =
1151                    dynamic_cast<const ResultLayoutConstraint*>(
1152                        layout_constraint)) {
1153       TF_RETURN_IF_ERROR(
1154           PropagateResultConstraint(*result_constraint, constraints));
1155     } else {
1156       LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
1157     }
1158 
1159     add_new_constraints_to_worklist();
1160   }
1161   return Status::OK();
1162 }
1163 
1164 namespace {
1165 
1166 // Returns a vector containing all array-shaped uses (instruction and operand
1167 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const LogicalBuffer & buffer,const TuplePointsToAnalysis & points_to_analysis)1168 std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
1169     const LogicalBuffer& buffer,
1170     const TuplePointsToAnalysis& points_to_analysis) {
1171   CHECK(buffer.IsArray());
1172   std::vector<std::pair<const HloInstruction*, int64>> uses;
1173   for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
1174     if (!buffer_alias.instruction()->shape().IsArray()) {
1175       continue;
1176     }
1177     // This alias must be the top-level (index == {}) of the instruction's
1178     // result because the instruction produces an array.
1179     CHECK(buffer_alias.index().empty());
1180 
1181     // Add all uses of the instruction's output.
1182     for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1183       for (int64 operand_no :
1184            user->OperandIndices(buffer_alias.instruction())) {
1185         uses.emplace_back(user, operand_no);
1186       }
1187     }
1188   }
1189   return uses;
1190 }
1191 
1192 }  // namespace
1193 
PropagateUseConstraintToDefs(const ShapeLayout & shape_layout,const HloInstruction * instruction,LayoutConstraints * constraints)1194 Status LayoutAssignment::PropagateUseConstraintToDefs(
1195     const ShapeLayout& shape_layout, const HloInstruction* instruction,
1196     LayoutConstraints* constraints) {
1197   // Try to set all logical buffers which may be sources of the given operand to
1198   // match the given layout.
1199   const PointsToSet& points_to_set =
1200       constraints->points_to_analysis().GetPointsToSet(instruction);
1201   return points_to_set.ForEachElementWithStatus(
1202       [&shape_layout, constraints](
1203           const ShapeIndex& index,
1204           const PointsToSet::BufferList& buffers) -> Status {
1205         if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
1206           for (const LogicalBuffer* buffer : buffers) {
1207             if (constraints->BufferLayout(*buffer) == nullptr &&
1208                 buffer->shape().IsArray()) {
1209               TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
1210                   ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
1211                   *buffer, /*mandatory=*/true));
1212             }
1213           }
1214         }
1215         return Status::OK();
1216       });
1217 }
1218 
1219 namespace {
1220 // A transpose or a reshape that only changes trivial dimensions have meaningful
1221 // layouts that are valuable to propagate in a depthfirst manner to avoid
1222 // unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo)1223 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
1224   switch (hlo.opcode()) {
1225     case HloOpcode::kReshape:
1226       return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
1227     case HloOpcode::kTranspose:
1228       return true;
1229     default:
1230       return false;
1231   }
1232 }
1233 
1234 }  // namespace
1235 
PropagateOperandConstraint(const OperandLayoutConstraint & operand_constraint,LayoutConstraints * constraints)1236 Status LayoutAssignment::PropagateOperandConstraint(
1237     const OperandLayoutConstraint& operand_constraint,
1238     LayoutConstraints* constraints) {
1239   // Try to set the layout of the logical buffers in the given operand to match
1240   // the constrained layout. This avoids copies.
1241   TF_RETURN_IF_ERROR(
1242       PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
1243                                    operand_constraint.operand(), constraints));
1244 
1245   // For array-shaped operands and user instructions try to pick a minimum cost
1246   // layout. For example, if the operand of an elementwise instruction is
1247   // constrained to a certain layout we want the output of the instruction to
1248   // have the same layout.
1249   //
1250   // If the user is not array-shaped, we still want to propagate the layout
1251   // to siblings if the instruction can't change layout. This is to represent
1252   // the information that non-layout-changing instructions should have the same
1253   // layout for the operands with the same ranks.
1254   const HloInstruction* operand = operand_constraint.operand();
1255   const HloInstruction* user = operand_constraint.instruction();
1256   if (!operand->shape().IsArray()) {
1257     return Status::OK();
1258   }
1259   if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
1260     return Status::OK();
1261   }
1262 
1263   // Only try to choose a low cost layout if the instruction 'user' defines its
1264   // output (ie, doesn't forward a buffer from elsewhere).
1265   if (constraints->OperandBufferForwarded(user,
1266                                           operand_constraint.operand_no())) {
1267     return Status::OK();
1268   }
1269 
1270   int64 operand_rank = operand->shape().rank();
1271   if (operand_rank <= 1) {
1272     return Status::OK();
1273   }
1274 
1275   // Propagate layouts between operands of the same instruction. This is a
1276   // constraint on non-layout-changing instructions.
1277   if (!instruction_can_change_layout_func_(user)) {
1278     // Make sure all siblings have the same layout as the operand.
1279     for (int64 operand_no = 0; operand_no < user->operand_count();
1280          ++operand_no) {
1281       if (user->operand(operand_no) == operand) {
1282         continue;
1283       }
1284       const HloInstruction* sibling = user->operand(operand_no);
1285       const int64 sibling_rank = sibling->shape().rank();
1286       if (sibling_rank <= 1) {
1287         continue;
1288       }
1289       if (operand_rank != sibling_rank) {
1290         continue;
1291       }
1292       const OperandLayoutConstraint* constraint =
1293           constraints->GetOperandLayoutConstraint(user, operand_no);
1294       if (constraint != nullptr) {
1295         // Due to the DFS of the propagation we can end up here when operand_no
1296         // has a layout set that hasn't been propagated yet (is still on the
1297         // stack of layouts to propagate).
1298         // We can continue here and leave the operands with different layouts,
1299         // as we will either:
1300         // - overwrite the current operand when the DFS gets back to propagating
1301         //   operand(operand_no) to its siblings
1302         // - overwrite operand(operand_no)'s layout with a mandatory layout if
1303         //   we continue to propagate our layout to the result, and then
1304         //   backwards into all operands (if the result is an array of rank > 1)
1305         continue;
1306       }
1307       TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
1308           operand_constraint.shape_layout().layout(), user, operand_no,
1309           /*mandatory=*/false));
1310     }
1311     TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1312         user->shape(),
1313         [&](const Shape& subshape, const ShapeIndex& shape_index) {
1314           if (subshape.IsTuple()) {
1315             return Status::OK();
1316           }
1317           if (subshape.rank() <= 1) {
1318             return Status::OK();
1319           }
1320 
1321           // Assign the right layout to input fusion of higher rank reduce
1322           // operations.
1323           if (subshape.rank() != operand->shape().rank()) {
1324             return Status::OK();
1325           }
1326           // TODO(b/67641796): Are there cases except fusion that use this code
1327           // path?
1328           TF_ASSIGN_OR_RETURN(
1329               const LogicalBuffer* buffer,
1330               constraints->points_to_analysis().GetBufferDefinedAt(
1331                   user, shape_index));
1332           // Make sure the output has the same layout as the operand.
1333           const BufferLayoutConstraint* constraint =
1334               constraints->GetBufferLayoutConstraint(*buffer);
1335           // If we already have a constraint for the buffer it was assigned but
1336           // hasn't propagated yet. This can happen with diamond-shaped graphs
1337           // where one path is first evaluated in depth-first order (we're here)
1338           // and the other path is propagated later. We don't set the layout
1339           // here as it will always be overwritten later.
1340           if (constraint == nullptr) {
1341             TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
1342                 operand_constraint.shape_layout().layout(), *buffer,
1343                 /*mandatory=*/false));
1344           }
1345           return Status::OK();
1346         }));
1347     return Status::OK();
1348   }
1349   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1350       user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
1351         if (subshape.IsTuple()) {
1352           return Status::OK();
1353         }
1354         if (subshape.rank() <= 1) {
1355           return Status::OK();
1356         }
1357         TF_ASSIGN_OR_RETURN(
1358             const LogicalBuffer* buffer,
1359             constraints->points_to_analysis().GetBufferDefinedAt(user,
1360                                                                  shape_index));
1361         if (constraints->BufferLayout(*buffer) == nullptr ||
1362             !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
1363           std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
1364               operand_constraint.shape_layout().layout(), user,
1365               operand_constraint.operand_no());
1366           if (layout != nullptr) {
1367             TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
1368                 *layout, *buffer,
1369                 /*mandatory=*/user->opcode() == HloOpcode::kReduce,
1370                 /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
1371           }
1372         }
1373         return Status::OK();
1374       }));
1375   return Status::OK();
1376 }
1377 
PropagateBufferConstraintToOperands(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1378 Status LayoutAssignment::PropagateBufferConstraintToOperands(
1379     const BufferLayoutConstraint& buffer_constraint,
1380     LayoutConstraints* constraints) {
1381   VLOG(5) << "PropagateBufferConstraintToOperands: "
1382           << buffer_constraint.ToString();
1383   const LogicalBuffer& buffer = buffer_constraint.buffer();
1384 
1385   const HloInstruction* instruction = buffer.instruction();
1386   if (IsAtMostRank1(instruction->shape())) {
1387     return Status::OK();
1388   }
1389 
1390   for (int64 operand_no = 0; operand_no < instruction->operand_count();
1391        ++operand_no) {
1392     const HloInstruction* operand = instruction->operand(operand_no);
1393     if (IsAtMostRank1(operand->shape())) {
1394       continue;
1395     }
1396     if (!instruction_can_change_layout_func_(instruction)) {
1397       // Copy the layout to the operand.
1398       if (buffer.IsArray() && operand->shape().IsArray() &&
1399           operand->shape().rank() ==
1400               LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
1401         TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
1402             buffer_constraint.layout(), instruction, operand_no,
1403             /*mandatory=*/true));
1404       }
1405     } else {
1406       if (!buffer.IsTopLevel() ||
1407           !instruction->operand(operand_no)->shape().IsArray()) {
1408         continue;  // Don't touch buffers that are internal to a tuple.
1409       }
1410       VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
1411               << instruction->ToShortString();
1412       // Assign a layout if there is no constraint already.
1413       const OperandLayoutConstraint* constraint =
1414           constraints->GetOperandLayoutConstraint(instruction, operand_no);
1415       if (constraint == nullptr || !constraint->mandatory()) {
1416         std::unique_ptr<Layout> operand_layout =
1417             ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
1418                                                 instruction, operand_no);
1419         if (operand_layout != nullptr) {
1420           TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
1421               *operand_layout, instruction, operand_no, /*mandatory=*/false,
1422               /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction)));
1423         }
1424       } else {
1425         VLOG(6) << "Operand already has a constraint "
1426                 << constraint->ToString();
1427       }
1428     }
1429   }
1430   return Status::OK();
1431 }
1432 
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1433 Status LayoutAssignment::PropagateBufferConstraint(
1434     const BufferLayoutConstraint& buffer_constraint,
1435     LayoutConstraints* constraints) {
1436   // Only propagate array layouts.
1437   const LogicalBuffer& buffer = buffer_constraint.buffer();
1438   if (!buffer.IsArray()) {
1439     return Status::OK();
1440   }
1441   TF_RETURN_IF_ERROR(
1442       PropagateBufferConstraintToUses(buffer_constraint, constraints));
1443   return PropagateBufferConstraintToOperands(buffer_constraint, constraints);
1444 }
1445 
PropagateBufferConstraintToUses(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1446 Status LayoutAssignment::PropagateBufferConstraintToUses(
1447     const BufferLayoutConstraint& buffer_constraint,
1448     LayoutConstraints* constraints) {
1449   const LogicalBuffer& buffer = buffer_constraint.buffer();
1450   TF_RET_CHECK(buffer.IsArray());
1451 
1452   // Propagate the layout to all array uses of the logical buffer. This skips
1453   // uses of the buffer where the buffer is the element of a tuple.
1454   for (const auto& user_operand_no :
1455        GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
1456     const HloInstruction* user = user_operand_no.first;
1457     int64 operand_no = user_operand_no.second;
1458     // Only add an operand constraint if the user does not forward the buffer
1459     // because this case is not handled is SetOperandLayout.
1460     if (constraints->OperandLayout(user, operand_no) == nullptr &&
1461         !constraints->OperandBufferForwarded(user, operand_no)) {
1462       TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
1463           buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
1464     }
1465   }
1466 
1467   return Status::OK();
1468 }
1469 
PropagateResultConstraint(const ResultLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1470 Status LayoutAssignment::PropagateResultConstraint(
1471     const ResultLayoutConstraint& layout_constraint,
1472     LayoutConstraints* constraints) {
1473   // Propagate the use constraint of the root instruction up to the logical
1474   // buffers which make up the result.
1475   return PropagateUseConstraintToDefs(
1476       layout_constraint.shape_layout(),
1477       constraints->computation()->root_instruction(), constraints);
1478 }
1479 
1480 namespace {
1481 
1482 // Infers the layout of the array at the given index in the given instruction's
1483 // output using points-to analysis. Precondition: The given instruction must
1484 // not produce this array value (that is, the array is forwarded from the
1485 // instruction's operands).
InferArrayLayout(const TuplePointsToAnalysis & points_to_analysis,HloInstruction * instruction,const ShapeIndex & index)1486 StatusOr<Layout> InferArrayLayout(
1487     const TuplePointsToAnalysis& points_to_analysis,
1488     HloInstruction* instruction, const ShapeIndex& index) {
1489   // This function should only be called for array shapes which don't yet have
1490   // layouts.
1491   const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
1492   TF_RET_CHECK(subshape.IsArray());
1493   TF_RET_CHECK(!subshape.has_layout());
1494 
1495   // The instruction should not define the buffer at this index.
1496   TF_RET_CHECK(
1497       !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
1498       << instruction->ToString();
1499 
1500   const auto& source_buffers =
1501       points_to_analysis.GetPointsToSet(instruction).element(index);
1502   TF_RET_CHECK(!source_buffers.empty());
1503 
1504   // Verify the layout is the same for every LogicalBuffer which this location
1505   // ('instruction' and 'index') points to.
1506   const Layout* first_buffer_layout = nullptr;
1507   for (const LogicalBuffer* source_buffer : source_buffers) {
1508     if (!source_buffer->shape().has_layout()) {
1509       // This should not happen because we've assigned layouts to all
1510       // instructions preceding this one.
1511       return InternalError("LogicalBuffer %s does not have a layout",
1512                            source_buffer->ToString());
1513     }
1514 
1515     if (first_buffer_layout == nullptr) {
1516       first_buffer_layout = &source_buffer->shape().layout();
1517     } else if (!LayoutUtil::Equal(source_buffer->shape().layout(),
1518                                   *first_buffer_layout)) {
1519       // The points-to set is ambiguous for this index and the different source
1520       // buffers have different layouts. This case is possible in valid XLA
1521       // computations because we do not propagate BufferLayoutConstraints to all
1522       // LogicalBuffers which may alias the constrained LogicalBuffer at some
1523       // point in the computation.
1524       return FailedPrecondition(
1525           "Array at index {%s} in instruction %s aliases buffers %s "
1526           "and %s which have different layouts",
1527           absl::StrJoin(index, ","), instruction->name(),
1528           source_buffers[0]->ToString(), source_buffer->ToString());
1529     }
1530   }
1531 
1532   return *first_buffer_layout;
1533 }
1534 
1535 // For fusion instructions, set the layout of each fused parameter instruction
1536 // to match the layout of its corresponding fusion instruction operand. Also,
1537 // set the layout of the fused root to match the layout of the fusion
1538 // instruction itself.
SetFusionLayouts(HloInstruction * fusion)1539 Status SetFusionLayouts(HloInstruction* fusion) {
1540   TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
1541   for (auto* fused_instruction :
1542        fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
1543     if (fused_instruction->opcode() == HloOpcode::kParameter) {
1544       const HloInstruction* fusion_operand =
1545           fusion->operand(fused_instruction->parameter_number());
1546       DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
1547                                    fused_instruction->shape()));
1548       TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
1549           fusion_operand->shape(), fused_instruction->mutable_shape()));
1550     } else if (fused_instruction == fusion->fused_expression_root()) {
1551       // The layout of the root of the fused expression must match the fusion
1552       // instruction layout.
1553       DCHECK(
1554           ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
1555       TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
1556           fusion->shape(), fused_instruction->mutable_shape()));
1557     } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
1558       // A GTE inherits its layout from its operand (which should ultimately be
1559       // a parameter).
1560       TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
1561           fused_instruction->operand(0)->shape().tuple_shapes(
1562               fused_instruction->tuple_index()),
1563           fused_instruction->mutable_shape()));
1564     } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
1565       // Give constants the layout of their literal.
1566       TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
1567           fused_instruction->literal().shape(),
1568           fused_instruction->mutable_shape()));
1569     } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
1570       // Nop; leave the infeed layout alone.
1571     } else if (fusion->fusion_kind() != HloInstruction::FusionKind::kCustom) {
1572       // Other instructions don't have layouts inside of fusion nodes.
1573       // But do not clear layouts for other instructions in custom fusion nodes.
1574       LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
1575     }
1576   }
1577 
1578   return Status::OK();
1579 }
1580 
1581 }  // namespace
1582 
AssignLayouts(const LayoutConstraints & constraints,HloComputation * computation)1583 Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
1584                                        HloComputation* computation) {
1585   VLOG(2) << "Assigning layouts to computation: " << computation->name();
1586   XLA_VLOG_LINES(2, computation->ToString());
1587   XLA_VLOG_LINES(2, constraints.ToString());
1588 
1589   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
1590     LayoutUtil::ClearLayout(instruction->mutable_shape());
1591 
1592     // Create a copy of an operand if the operand instruction's layout does not
1593     // match the use constraint (OperandLayoutConstraint).
1594     for (int64 operand_no = 0; operand_no < instruction->operand_count();
1595          ++operand_no) {
1596       const ShapeLayout* operand_layout =
1597           constraints.OperandLayout(instruction, operand_no);
1598       if (operand_layout != nullptr) {
1599         TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
1600                                                       instruction, operand_no));
1601       }
1602     }
1603 
1604     // Set the layouts of the array shapes this instruction defines as indicated
1605     // by the respective BufferLayoutConstraints. Any array shapes in the output
1606     // of the instruction which are not defined by the instruction (eg, array
1607     // elements in a Tuple instruction) will be assigned below via inference.
1608     for (const LogicalBuffer* buffer :
1609          constraints.points_to_analysis().GetBuffersDefinedByInstruction(
1610              instruction)) {
1611       if (!buffer->shape().IsArray()) {
1612         continue;
1613       }
1614 
1615       TF_RET_CHECK(buffer->instruction() == instruction);
1616       const Layout* buffer_layout = constraints.BufferLayout(*buffer);
1617       TF_RET_CHECK(buffer_layout != nullptr);
1618 
1619       if (instruction->opcode() == HloOpcode::kConstant) {
1620         // For constants, we also need to change the layout of the internal
1621         // literal.
1622         instruction->RelayoutConstant(*buffer_layout, buffer->index());
1623       } else {
1624         Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
1625             instruction->mutable_shape(), buffer->index());
1626         *buffer_subshape->mutable_layout() = *buffer_layout;
1627       }
1628     }
1629 
1630     // Any remaining layouts in the output of the instruction must be
1631     // inferrable using points-to analysis.
1632     TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
1633         instruction->mutable_shape(),
1634         [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
1635           if (subshape->has_layout() || !subshape->IsArray()) {
1636             return Status::OK();
1637           }
1638           // Set Layout of subshape to match layout of LogicalBuffer which
1639           // produces it.
1640           TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
1641                               InferArrayLayout(constraints.points_to_analysis(),
1642                                                instruction, index));
1643           return Status::OK();
1644         }));
1645 
1646     // Fusion instructions require some layouts to be set on fused instructions
1647     // inside the fusion instruction.
1648     if (instruction->opcode() == HloOpcode::kFusion) {
1649       TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
1650     }
1651 
1652     // Execute extra verification step once the layout has been finalized.
1653     TF_RETURN_IF_ERROR(Verify(instruction));
1654 
1655     // Shape must be valid.
1656     TF_RETURN_IF_ERROR(
1657         ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
1658 
1659     // Verify all layouts in the shape have been set.
1660     TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
1661   }
1662   return Status::OK();
1663 }
1664 
CalculateComputationLayout(HloComputation * computation)1665 Status LayoutAssignment::CalculateComputationLayout(
1666     HloComputation* computation) {
1667   ComputationLayout computation_layout(computation->ComputeProgramShape(),
1668                                        /*ignore_layouts=*/false);
1669   InsertOrDie(&computation_layouts_, computation, computation_layout);
1670   VLOG(2) << "  Calculated ComputationLayout = "
1671           << computation_layout.ToString();
1672   return Status::OK();
1673 }
1674 
ClearComputationLayouts(HloComputation * computation)1675 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
1676   // Clear existing layouts of the instructions.  All layouts must be assigned
1677   // by the LayoutAssignment pass, except for those on parameters, the
1678   // computation result, and a couple special cases. The former two are
1679   // specified in computation_layout.  Clearing the layouts here avoids hiding
1680   // potential bugs in the layout assignment pass that may accidentally use the
1681   // existing layout.
1682   for (HloInstruction* instruction : computation->instructions()) {
1683     if (instruction->opcode() == HloOpcode::kBitcast) {
1684       // bitcasts are inherently layout sensitive and so a bitcast instruction
1685       // present in the IR before layout assignment is a bug.
1686       return InternalError(
1687           "Unexpected bitcast operation seen during layout assignment: %s.",
1688           instruction->ToString());
1689     }
1690     // Some instructions carry mandatory layouts in their shape.
1691     if (instruction->opcode() != HloOpcode::kInfeed &&
1692         !IsLayoutConstrainedCustomCall(instruction)) {
1693       LayoutUtil::ClearLayout(instruction->mutable_shape());
1694     }
1695   }
1696   return Status::OK();
1697 }
1698 
RunOnComputation(ComputationLayout * computation_layout,const TuplePointsToAnalysis & points_to_analysis,HloComputation * computation,ChannelLayoutConstraints * channel_constraints)1699 Status LayoutAssignment::RunOnComputation(
1700     ComputationLayout* computation_layout,
1701     const TuplePointsToAnalysis& points_to_analysis,
1702     HloComputation* computation,
1703     ChannelLayoutConstraints* channel_constraints) {
1704   VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
1705           << ")";
1706 
1707   // Must be run before clearing layouts.
1708   TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));
1709 
1710   TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
1711   if (computation_layout != nullptr) {
1712     auto it = computation_layouts_.find(computation);
1713     if (it == computation_layouts_.end()) {
1714       VLOG(2) << "  New ComputationLayout = " << computation_layout->ToString();
1715       computation_layouts_.emplace(computation, *computation_layout);
1716     } else {
1717       TF_RET_CHECK(computation_layout == &it->second ||
1718                    computation_layout == entry_computation_layout_);
1719       VLOG(2) << "  Existing ComputationLayout = "
1720               << computation_layout->ToString();
1721     }
1722   } else {
1723     VLOG(2) << "  No ComputationLayout specified (will be calculated)";
1724   }
1725 
1726   // Construct LayoutConstraints with all layout constraints of the computation.
1727   LayoutConstraints constraints(points_to_analysis, computation);
1728 
1729   // Add constraints required for correctness on all backends (eg, entry
1730   // parameter layout constraints).
1731   TF_RETURN_IF_ERROR(AddMandatoryConstraints(
1732       computation_layout, channel_constraints, computation, &constraints));
1733 
1734   // Add any backend-specific constraints.
1735   TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
1736 
1737   // Propagates layouts from mandatory and backend constraints.
1738   TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
1739 
1740   // Prior to applying default layouts, we take note of all HLO instructions
1741   // which lack a layout constraint.
1742   for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
1743     unconstrained_layout_instructions_.insert(
1744         points_to_analysis.GetBuffer(buffer_id).instruction());
1745   }
1746 
1747   // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
1748   // layout and propagate the change.
1749   while (!constraints.unconstrained_buffer_ids().empty()) {
1750     int unconstrained_count = constraints.unconstrained_buffer_ids().size();
1751 
1752     // Arbitrarily pick the first unconstrained buffer and give it the default
1753     // layout (or the literal layout, in case of constants). By construction
1754     // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
1755     const LogicalBuffer& buffer = points_to_analysis.GetBuffer(
1756         *constraints.unconstrained_buffer_ids().begin());
1757     const HloInstruction* instruction = buffer.instruction();
1758     Layout new_layout =
1759         instruction->opcode() == HloOpcode::kConstant
1760             ? ShapeUtil::GetSubshape(instruction->literal().shape(),
1761                                      buffer.index())
1762                   .layout()
1763             : LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
1764     TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
1765                                                    /*mandatory=*/false));
1766 
1767     TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
1768 
1769     // To verify progress has been made, check that the number of unconstrained
1770     // buffers has been reduced.
1771     CHECK_LT(constraints.unconstrained_buffer_ids().size(),
1772              unconstrained_count);
1773   }
1774   // All logical buffers should have constraints at this point. All that
1775   // remains is assign the constraints to the buffers and infer layouts for
1776   // aliased buffers.
1777   TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));
1778 
1779   // If the computation layout wasn't specified, now it is the time to compute
1780   // it according to the parameters and root instruction layouts.
1781   // This allows the first pass through this API to record the best flowing
1782   // layout to parameters and root instruction.
1783   if (computation_layout == nullptr) {
1784     TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
1785   }
1786 
1787   // Record the layouts assigned for any communication ops in
1788   // channel_constraints so that they are constrained for future modules.
1789   if (channel_constraints != nullptr) {
1790     TF_RETURN_IF_ERROR(
1791         ConstrainChannelLayouts(computation, channel_constraints));
1792   }
1793 
1794   // Copy the root instruction's result if its layout does not match the result
1795   // layout constraint.
1796   if (constraints.ResultLayout() != nullptr &&
1797       !constraints.ResultLayout()->MatchesLayoutInShape(
1798           computation->root_instruction()->shape())) {
1799     TF_ASSIGN_OR_RETURN(
1800         HloInstruction * new_root,
1801         CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
1802                                 computation->root_instruction()));
1803     computation->set_root_instruction(new_root);
1804   }
1805   return Status::OK();
1806 }
1807 
ConstrainChannelLayouts(HloComputation * computation,ChannelLayoutConstraints * channel_constraints)1808 Status LayoutAssignment::ConstrainChannelLayouts(
1809     HloComputation* computation,
1810     ChannelLayoutConstraints* channel_constraints) {
1811   auto get_channel_constraints = [&](const HloInstruction* instruction) {
1812     return IsHostSendRecv(instruction) ? &host_channel_constraints_
1813                                        : channel_constraints;
1814   };
1815   // We go through the kRecvDone before. These must either impose their layout,
1816   // or find a matching one already existing (ConstrainChannel() returns
1817   // nullptr).
1818   for (HloInstruction* instruction : computation->instructions()) {
1819     if (instruction->opcode() == HloOpcode::kRecvDone) {
1820       const Layout* layout =
1821           get_channel_constraints(instruction)
1822               ->ConstrainChannel(
1823                   instruction->channel_id(),
1824                   ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
1825       TF_RET_CHECK(layout == nullptr)
1826           << instruction->ToString()
1827           << " cannot constrain layout as it was set to "
1828           << LayoutUtil::HumanString(*layout);
1829     }
1830   }
1831   // After that we go through the kSend. These are likely going to have a kCopy
1832   // as operand (otherwise we add it), so in case the constrained layout does
1833   // not match, we can change the kCopy layout (and the kSend one as well).
1834   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
1835     if (instruction->opcode() == HloOpcode::kSend) {
1836       HloInstruction* operand = instruction->mutable_operand(0);
1837       const Layout* layout = get_channel_constraints(instruction)
1838                                  ->ConstrainChannel(instruction->channel_id(),
1839                                                     operand->shape().layout());
1840       if (layout != nullptr) {
1841         // We found an already constrained layout which does not match the one
1842         // the kSend wants to impose. Either add a new kCopy, or use the
1843         // existing one to marshal the correct shape.
1844         Shape shape = operand->shape();
1845         *shape.mutable_layout() = *layout;
1846         if (operand->opcode() != HloOpcode::kCopy) {
1847           HloInstruction* copy = operand->parent()->AddInstruction(
1848               HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
1849           RegisterAddedCopy(copy);
1850           SetupCopiedInstruction(*operand, copy, {});
1851           TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
1852           operand = copy;
1853         } else {
1854           *operand->mutable_shape() = shape;
1855         }
1856         Shape* send_shape =
1857             ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
1858         *send_shape = shape;
1859       }
1860     } else if (instruction->IsCrossModuleAllReduce()) {
1861       const Layout* layout =
1862           get_channel_constraints(instruction)
1863               ->ConstrainChannel(instruction->all_reduce_id().value(),
1864                                  instruction->shape().layout());
1865       if (layout != nullptr) {
1866         // We found an already constrained layout which does not match the one
1867         // the channel wants to impose. Either add a new kCopy, or use the
1868         // existing one to marshal the correct shape.
1869         HloInstruction* operand = instruction->mutable_operand(0);
1870         Shape shape = operand->shape();
1871         *shape.mutable_layout() = *layout;
1872         if (operand->opcode() != HloOpcode::kCopy) {
1873           HloInstruction* copy = operand->parent()->AddInstruction(
1874               HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
1875           RegisterAddedCopy(copy);
1876           SetupCopiedInstruction(*operand, copy, {});
1877           TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
1878           operand = copy;
1879         } else {
1880           *operand->mutable_shape() = shape;
1881         }
1882         *instruction->mutable_shape() = shape;
1883       }
1884     }
1885   }
1886   return Status::OK();
1887 }
1888 
PropagateComputationLayouts(HloComputation * computation,ComputationLayout * computation_layout)1889 Status LayoutAssignment::PropagateComputationLayouts(
1890     HloComputation* computation, ComputationLayout* computation_layout) {
1891   ComputationLayout computed_computation_layout(
1892       computation->ComputeProgramShape(),
1893       /*ignore_layouts=*/false);
1894   for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
1895     ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
1896     if (!param_layout->LayoutIsSet()) {
1897       VLOG(4) << "Assigning layout to parameter " << i << " of computation "
1898               << computation->name() << ": "
1899               << computed_computation_layout.parameter_layout(i).ToString();
1900       *param_layout = computed_computation_layout.parameter_layout(i);
1901     } else {
1902       TF_RET_CHECK(computed_computation_layout.parameter_layout(i) ==
1903                    *param_layout);
1904     }
1905   }
1906   ShapeLayout* result_layout = computation_layout->mutable_result_layout();
1907   if (!result_layout->LayoutIsSet()) {
1908     VLOG(4) << "Assigning result layout of computation " << computation->name()
1909             << ": " << computed_computation_layout.result_layout().ToString();
1910     *result_layout = computed_computation_layout.result_layout();
1911   } else {
1912     TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout);
1913   }
1914   return Status::OK();
1915 }
1916 
Run(HloModule * module)1917 StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
1918   VLOG(2) << "Running layout assignment on module " << module->name();
1919   TF_RETURN_IF_ERROR(Init());
1920 
1921   // Verify computation layout is sane.
1922   const HloComputation* entry = module->entry_computation();
1923   TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
1924                entry->num_parameters());
1925   for (int64 i = 0; i < entry->num_parameters(); ++i) {
1926     TF_RET_CHECK(
1927         ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
1928                               entry->parameter_instruction(i)->shape()));
1929   }
1930   TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
1931                                      entry->root_instruction()->shape()));
1932 
1933   // We do two passes. The first one we pass a nullptr ComputationLayout to
1934   // the RunOnComputation() calls (for non entry computations), and we register
1935   // the ComputationLayout which are naturally flowing in DFS fashion to the
1936   // parameters and root instruction.
1937   // Walking in DFS mode though, means that we can end up with incorrect layouts
1938   // when seen from an outer instruction, which has across-computation
1939   // constraints to impose.
1940   // For example, the kWhile instruction needs to enforce the same layouts for
1941   // the parameters and root of the body, as well as the condition parameters.
1942   // Similarly, the kConditional instruction needs to enforce the same layouts
1943   // for the root of the true and false computations.
1944   // So in the first pass, while allowing the layouts to flow to parameters and
1945   // root, we also fix up the eventually inconsistent ComputationLayout, which
1946   // will be then made mandatory by the second pass.
1947   for (int64 i = 0; i < 2; ++i) {
1948     VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
1949     TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
1950     TF_ASSIGN_OR_RETURN(auto points_to_analysis,
1951                         TuplePointsToAnalysis::Run(module));
1952     for (auto* computation : module->MakeComputationPostOrder()) {
1953       if (computation->IsFusionComputation()) {
1954         continue;
1955       }
1956       if (computation == module->entry_computation()) {
1957         TF_RETURN_IF_ERROR(RunOnComputation(
1958             entry_computation_layout_, *points_to_analysis,
1959             module->entry_computation(), channel_layout_constraints_));
1960       } else {
1961         ComputationLayout* computation_layout =
1962             (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation);
1963         TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
1964                                             *points_to_analysis, computation,
1965                                             channel_layout_constraints_));
1966       }
1967     }
1968   }
1969   TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
1970                                                  entry_computation_layout_));
1971   TF_RETURN_IF_ERROR(CheckLayouts(module));
1972 
1973   // All layouts are reset then reassigned by this pass.
1974   return true;
1975 }
1976 
1977 /* static */
InstructionCanChangeLayout(const HloInstruction * instruction)1978 bool LayoutAssignment::InstructionCanChangeLayout(
1979     const HloInstruction* instruction) {
1980   switch (instruction->opcode()) {
1981     case HloOpcode::kAbs:
1982     case HloOpcode::kAdd:
1983     case HloOpcode::kAddDependency:
1984     case HloOpcode::kAnd:
1985     case HloOpcode::kAtan2:
1986     case HloOpcode::kBitcastConvert:
1987     case HloOpcode::kCeil:
1988     case HloOpcode::kClamp:
1989     case HloOpcode::kClz:
1990     case HloOpcode::kCompare:
1991     case HloOpcode::kComplex:
1992     case HloOpcode::kConcatenate:
1993     case HloOpcode::kConditional:
1994     case HloOpcode::kConvert:
1995     case HloOpcode::kCos:
1996     case HloOpcode::kAllReduce:
1997     case HloOpcode::kAllToAll:
1998     case HloOpcode::kCollectivePermute:
1999     case HloOpcode::kDivide:
2000     case HloOpcode::kDynamicSlice:
2001     case HloOpcode::kDynamicUpdateSlice:
2002     case HloOpcode::kExp:
2003     case HloOpcode::kExpm1:
2004     case HloOpcode::kFft:
2005     case HloOpcode::kFloor:
2006     case HloOpcode::kImag:
2007     case HloOpcode::kIsFinite:
2008     case HloOpcode::kLog:
2009     case HloOpcode::kLog1p:
2010     case HloOpcode::kMap:
2011     case HloOpcode::kMaximum:
2012     case HloOpcode::kMinimum:
2013     case HloOpcode::kMultiply:
2014     case HloOpcode::kNegate:
2015     case HloOpcode::kNot:
2016     case HloOpcode::kOr:
2017     case HloOpcode::kXor:
2018     case HloOpcode::kPad:
2019     case HloOpcode::kPower:
2020     case HloOpcode::kReal:
2021     case HloOpcode::kReducePrecision:
2022     case HloOpcode::kReduceWindow:
2023     case HloOpcode::kRemainder:
2024     case HloOpcode::kReverse:
2025     case HloOpcode::kRoundNearestAfz:
2026     case HloOpcode::kRsqrt:
2027     case HloOpcode::kScatter:
2028     case HloOpcode::kSelect:
2029     case HloOpcode::kSelectAndScatter:
2030     case HloOpcode::kShiftLeft:
2031     case HloOpcode::kShiftRightArithmetic:
2032     case HloOpcode::kShiftRightLogical:
2033     case HloOpcode::kSign:
2034     case HloOpcode::kSin:
2035     case HloOpcode::kSlice:
2036     case HloOpcode::kSort:
2037     case HloOpcode::kSqrt:
2038     case HloOpcode::kSubtract:
2039     case HloOpcode::kTanh:
2040     case HloOpcode::kTriangularSolve:
2041     case HloOpcode::kCholesky:
2042     case HloOpcode::kTupleSelect:
2043     case HloOpcode::kWhile:
2044       return false;
2045     case HloOpcode::kBatchNormGrad:
2046     case HloOpcode::kBatchNormInference:
2047     case HloOpcode::kBatchNormTraining:
2048     case HloOpcode::kBitcast:
2049     case HloOpcode::kBroadcast:
2050     case HloOpcode::kCall:
2051     case HloOpcode::kConstant:
2052     case HloOpcode::kConvolution:
2053     case HloOpcode::kCopy:
2054     case HloOpcode::kCustomCall:
2055     case HloOpcode::kDomain:
2056     case HloOpcode::kDot:
2057     case HloOpcode::kFusion:
2058     case HloOpcode::kGather:
2059     case HloOpcode::kGetTupleElement:
2060     case HloOpcode::kInfeed:
2061     case HloOpcode::kIota:
2062     case HloOpcode::kOutfeed:
2063     case HloOpcode::kParameter:
2064     case HloOpcode::kRecv:
2065     case HloOpcode::kRecvDone:
2066     case HloOpcode::kReduce:
2067     case HloOpcode::kReplicaId:
2068     case HloOpcode::kReshape:
2069     case HloOpcode::kRng:
2070     case HloOpcode::kSend:
2071     case HloOpcode::kSendDone:
2072     case HloOpcode::kAfterAll:
2073     case HloOpcode::kTrace:
2074     case HloOpcode::kTranspose:
2075     case HloOpcode::kTuple:
2076     case HloOpcode::kGetDimensionSize:
2077       return true;
2078   }
2079 }
2080 
2081 /* static */
IsAtMostRank1(const Shape & shape)2082 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2083   if (shape.IsArray()) {
2084     return shape.rank() <= 1;
2085   }
2086   return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2087     return IsAtMostRank1(subshape);
2088   });
2089 }
2090 
Init()2091 Status LayoutAssignment::Init() {
2092   computation_layouts_.clear();
2093   *entry_computation_layout_ = saved_entry_computation_layout_;
2094   return Status::OK();
2095 }
2096 
ClearPreviousPassSideEffects(HloModule * module)2097 Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
2098   VLOG(5) << "Clearing previous side effects";
2099   // Clear all the copies which have been added, and all the related
2100   // instructions (like GTE and tuples).
2101   int64 removed_copies = 0;
2102   for (HloComputation* computation : module->computations()) {
2103     for (HloInstruction* instruction :
2104          computation->MakeInstructionPostOrder()) {
2105       if (instruction->opcode() == HloOpcode::kCopy &&
2106           added_copies_.contains(instruction)) {
2107         VLOG(5) << "Removing added copy: " << instruction->ToString();
2108         TF_RETURN_IF_ERROR(
2109             instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2110         TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2111         ++removed_copies;
2112       }
2113     }
2114   }
2115   added_copies_.clear();
2116   unconstrained_layout_instructions_.clear();
2117   if (removed_copies > 0) {
2118     TupleSimplifier tuple_simplifier;
2119     HloDCE dce;
2120     TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2121     TF_RETURN_IF_ERROR(dce.Run(module).status());
2122   }
2123   ResetChannelConstraints();
2124   return Status::OK();
2125 }
2126 
AddCopyForOperand(HloInstruction * instruction,int64 operand_number)2127 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2128                                            int64 operand_number) {
2129   HloInstruction* operand = instruction->mutable_operand(operand_number);
2130   if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2131     HloInstruction* copy =
2132         instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2133             operand->shape(), HloOpcode::kCopy, operand));
2134     SetupCopiedInstruction(*operand, copy, {});
2135     LayoutUtil::ClearLayout(copy->mutable_shape());
2136     TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2137   }
2138   return Status::OK();
2139 }
2140 
2141 }  // namespace xla
2142