1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28 
29 #include "absl/algorithm/container.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/map_util.h"
37 #include "tensorflow/compiler/xla/permutation_util.h"
38 #include "tensorflow/compiler/xla/service/call_graph.h"
39 #include "tensorflow/compiler/xla/service/computation_layout.h"
40 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_dce.h"
44 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/logical_buffer.h"
48 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
49 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
50 #include "tensorflow/compiler/xla/shape_layout.h"
51 #include "tensorflow/compiler/xla/shape_util.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/compiler/xla/statusor.h"
54 #include "tensorflow/compiler/xla/types.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/logging.h"
60 #include "tensorflow/core/platform/protobuf.h"
61 
62 namespace xla {
63 
operator <<(std::ostream & out,const LayoutConstraint & constraint)64 std::ostream& operator<<(std::ostream& out,
65                          const LayoutConstraint& constraint) {
66   out << constraint.ToString();
67   return out;
68 }
69 
BufferLayoutConstraint(const Layout & layout,const LogicalBuffer & buffer,bool mandatory,bool dfs)70 BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
71                                                const LogicalBuffer& buffer,
72                                                bool mandatory, bool dfs)
73     : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
74   CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
75 }
76 
ToString() const77 string BufferLayoutConstraint::ToString() const {
78   return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
79                          LayoutUtil::HumanString(layout_));
80 }
81 
OperandLayoutConstraint(const ShapeLayout & shape_layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)82 OperandLayoutConstraint::OperandLayoutConstraint(
83     const ShapeLayout& shape_layout, const HloInstruction* instruction,
84     int64 operand_no, bool mandatory, bool dfs)
85     : LayoutConstraint(mandatory, dfs),
86       shape_layout_(shape_layout),
87       instruction_(instruction),
88       operand_no_(operand_no) {
89   CHECK(shape_layout_.LayoutIsSet());
90   CHECK(ShapeUtil::Compatible(shape_layout.shape(),
91                               instruction->operand(operand_no)->shape()))
92       << shape_layout.shape() << " is not compatible with "
93       << instruction->operand(operand_no)->shape() << " (for operand "
94       << operand_no << " of instruction " << instruction->ToString() << ")";
95 }
96 
ToString() const97 string OperandLayoutConstraint::ToString() const {
98   return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
99                          instruction_->name(), operand_no_,
100                          shape_layout_.ToString());
101 }
102 
ToString() const103 string ResultLayoutConstraint::ToString() const {
104   return absl::StrFormat("ResultLayoutConstraint: %s",
105                          shape_layout_.ToString());
106 }
107 
LayoutConstraints(const TuplePointsToAnalysis & points_to_analysis,HloComputation * computation)108 LayoutConstraints::LayoutConstraints(
109     const TuplePointsToAnalysis& points_to_analysis,
110     HloComputation* computation)
111     : points_to_analysis_(points_to_analysis), computation_(computation) {
112   // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
113   for (HloInstruction* inst : computation_->instructions()) {
114     points_to_analysis_.GetPointsToSet(inst).ForEachElement(
115         [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
116           for (const LogicalBuffer* buffer : buffers) {
117             // The points to analysis is computed per module, restrict
118             // constraints to array buffers in this computation.
119             if (buffer->IsArray() &&
120                 buffer->instruction()->parent() == computation) {
121               unconstrained_buffer_ids_.insert(buffer->id());
122             }
123           }
124         });
125   }
126 }
127 
GetBufferSet(const HloInstruction * instruction) const128 PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
129     const HloInstruction* instruction) const {
130   auto it = buffer_sets_cache_.find(instruction);
131   if (it != buffer_sets_cache_.end()) {
132     return it->second.get();
133   }
134   auto& buffer_set =
135       buffer_sets_cache_
136           .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
137           .first->second;
138   const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
139   points_to_set.ForEachElement(
140       [&buffer_set](const ShapeIndex& /*index*/,
141                     const PointsToSet::BufferList& buffers) {
142         buffer_set->insert(buffers.begin(), buffers.end());
143       });
144   return buffer_set.get();
145 }
146 
OperandBufferForwarded(const HloInstruction * instruction,int64 operand_no) const147 bool LayoutConstraints::OperandBufferForwarded(
148     const HloInstruction* instruction, int64 operand_no) const {
149   // The operand is potentially forwarded if the intersection of points-to sets
150   // of the operand and the instruction is non-empty.
151   PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
152   PointsToSet::BufferSet* operand_buffers =
153       GetBufferSet(instruction->operand(operand_no));
154   return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
155     return operand_buffers->count(b) > 0;
156   });
157 }
158 
SetBufferLayout(const Layout & layout,const LogicalBuffer & buffer,bool mandatory,bool dfs)159 Status LayoutConstraints::SetBufferLayout(const Layout& layout,
160                                           const LogicalBuffer& buffer,
161                                           bool mandatory, bool dfs) {
162   VLOG(3) << "SetBufferLayout : " << buffer << " : "
163           << LayoutUtil::HumanString(layout);
164 
165   TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
166   if (!buffer.IsArray()) {
167     return FailedPrecondition(
168         "Layout of buffer %s cannot be constrained because buffer is not "
169         "array-shaped, has shape: %s",
170         buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
171   }
172   TF_RETURN_IF_ERROR(
173       LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
174 
175   auto iter = buffer_constraints_.find(&buffer);
176   if (iter != buffer_constraints_.end()) {
177     const BufferLayoutConstraint& curr_constraint = iter->second;
178     if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
179       // New constraint matches existing constraint. Nothing to do.
180       return Status::OK();
181     }
182     if (curr_constraint.mandatory()) {
183       if (!mandatory) {
184         VLOG(3) << "Buffer" << buffer
185                 << " already has a mandatory layout constrain, skipping";
186         return Status::OK();
187       }
188       return FailedPrecondition(
189           "Buffer %s already has the layout constraint %s, cannot add "
190           "incompatible constraint %s",
191           buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
192           LayoutUtil::HumanString(layout));
193     }
194     iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
195   } else {
196     TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
197         << buffer.ToString();
198     iter = buffer_constraints_
199                .insert(std::make_pair(
200                    &buffer,
201                    BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
202                .first;
203   }
204   added_constraints_.push_back(&iter->second);
205   return Status::OK();
206 }
207 
SetOperandLayout(const Shape & shape_with_layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)208 Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
209                                            const HloInstruction* instruction,
210                                            int64 operand_no, bool mandatory,
211                                            bool dfs) {
212   VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
213           << operand_no << " : "
214           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
215 
216   const OperandLayoutConstraint* curr_shape_layout =
217       GetOperandLayoutConstraint(instruction, operand_no);
218   if (curr_shape_layout != nullptr) {
219     if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
220             shape_with_layout, /*minor_to_major_only=*/true)) {
221       // New constraint matches existing constraint. Nothing to do.
222       return Status::OK();
223     }
224     if (curr_shape_layout->mandatory()) {
225       return FailedPrecondition(
226           "Operand %d of instruction %s already has a layout constraint "
227           "%s, cannot add incompatible constraint %s",
228           operand_no, instruction->name(),
229           curr_shape_layout->shape_layout().ToString(),
230           ShapeUtil::HumanStringWithLayout(shape_with_layout));
231     }
232   }
233 
234   // If any buffers in the operand occur in the output of the instruction, then
235   // return an error. This case is not handled because such a constraint changes
236   // layouts beyond this immediate use and is complicated to handle.
237   if (OperandBufferForwarded(instruction, operand_no)) {
238     return FailedPrecondition(
239         "Cannot constraint layout of operand %d of instruction %s "
240         "because instruction forwards operand's LogicalBuffer(s)",
241         operand_no, instruction->name());
242   }
243 
244   auto key = std::make_pair(instruction, operand_no);
245   auto iter = operand_constraints_.find(key);
246   if (iter == operand_constraints_.end()) {
247     auto pair = std::make_pair(
248         key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
249                                      instruction, operand_no, mandatory, dfs));
250     iter = operand_constraints_.insert(pair).first;
251   } else {
252     iter->second =
253         OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
254                                 operand_no, mandatory, dfs);
255   }
256   added_constraints_.push_back(&iter->second);
257 
258   return Status::OK();
259 }
260 
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)261 Status LayoutConstraints::SetArrayOperandLayout(
262     const Layout& layout, const HloInstruction* instruction, int64 operand_no,
263     bool mandatory, bool dfs) {
264   const HloInstruction* operand = instruction->operand(operand_no);
265   TF_RET_CHECK(operand->shape().IsArray());
266   Shape shape(operand->shape());
267   *shape.mutable_layout() = layout;
268   TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
269   return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
270 }
271 
SetResultLayout(const Shape & shape_with_layout,bool dfs)272 Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
273                                           bool dfs) {
274   VLOG(3) << "SetResultLayout : "
275           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
276 
277   const ShapeLayout* curr_shape_layout = ResultLayout();
278   if (curr_shape_layout != nullptr) {
279     if (!curr_shape_layout->MatchesLayoutInShape(
280             shape_with_layout, /*minor_to_major_only=*/true)) {
281       return FailedPrecondition(
282           "Result of computation %s already has the layout constraint %s, "
283           "cannot add incompatible constraint %s",
284           computation_->name(), curr_shape_layout->ToString(),
285           ShapeUtil::HumanStringWithLayout(shape_with_layout));
286     }
287     // New constraint matches existing constraint. Nothing to do.
288     return Status::OK();
289   }
290   result_constraint_.reset(
291       new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
292   added_constraints_.push_back(result_constraint_.get());
293 
294   return Status::OK();
295 }
296 
SetInstructionLayout(const Shape & shape_with_layout,const HloInstruction * instruction,bool mandatory,bool dfs)297 Status LayoutConstraints::SetInstructionLayout(
298     const Shape& shape_with_layout, const HloInstruction* instruction,
299     bool mandatory, bool dfs) {
300   VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
301           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
302 
303   if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
304     return FailedPrecondition(
305         "Instruction %s of shape %s cannot be assigned incompatible layout %s",
306         instruction->name(), ShapeUtil::HumanString(instruction->shape()),
307         ShapeUtil::HumanStringWithLayout(shape_with_layout));
308   }
309 
310   // Create a BufferLayoutConstraint for each array shape in the output of the
311   // instruction.
312   return ShapeUtil::ForEachSubshapeWithStatus(
313       shape_with_layout,
314       [this, instruction, mandatory](const Shape& subshape,
315                                      const ShapeIndex& index) -> Status {
316         // The precondition for this method is that the instruction defines all
317         // buffers in its output.
318         auto buffers =
319             points_to_analysis_.GetPointsToSet(instruction).element(index);
320         CHECK_EQ(1, buffers.size());
321         CHECK_EQ(buffers[0]->instruction(), instruction);
322 
323         if (subshape.IsArray() && subshape.has_layout()) {
324           return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
325         } else {
326           return Status::OK();
327         }
328       });
329 }
330 
BufferLayout(const LogicalBuffer & buffer) const331 const Layout* LayoutConstraints::BufferLayout(
332     const LogicalBuffer& buffer) const {
333   if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
334     return &constraint->layout();
335   }
336   return nullptr;
337 }
338 
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const339 const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
340     const LogicalBuffer& buffer) const {
341   auto it = buffer_constraints_.find(&buffer);
342   return it == buffer_constraints_.end() ? nullptr : &it->second;
343 }
344 
OperandLayout(const HloInstruction * instruction,int64 operand_no) const345 const ShapeLayout* LayoutConstraints::OperandLayout(
346     const HloInstruction* instruction, int64 operand_no) const {
347   if (const auto* constraint =
348           GetOperandLayoutConstraint(instruction, operand_no)) {
349     return &constraint->shape_layout();
350   }
351   return nullptr;
352 }
353 
GetOperandLayoutConstraint(const HloInstruction * instruction,int64 operand_no) const354 const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
355     const HloInstruction* instruction, int64 operand_no) const {
356   auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
357   return it == operand_constraints_.end() ? nullptr : &it->second;
358 }
359 
ResultLayout() const360 const ShapeLayout* LayoutConstraints::ResultLayout() const {
361   return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
362 }
363 
ToString() const364 string LayoutConstraints::ToString() const {
365   string output;
366   absl::StrAppend(&output, "LayoutConstraints for computation ",
367                   computation_->name(), ":\n");
368   for (auto* instruction : computation_->MakeInstructionPostOrder()) {
369     absl::StrAppend(&output, "  ", instruction->ToShortString(), "\n");
370     for (int64 i = 0; i < instruction->operand_count(); ++i) {
371       if (OperandLayout(instruction, i) != nullptr) {
372         absl::StrAppend(&output, "    operand (", i,
373                         "): ", OperandLayout(instruction, i)->ToString(), "\n");
374       }
375     }
376     for (const LogicalBuffer* buffer :
377          points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
378       if (BufferLayout(*buffer) != nullptr) {
379         absl::StrAppend(&output, "    ", buffer->ToString(), " : ",
380                         LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
381       }
382     }
383   }
384 
385   if (ResultLayout() != nullptr) {
386     absl::StrAppend(&output, "  => ", ResultLayout()->ToString(), "\n");
387   }
388   return output;
389 }
390 
391 namespace {
392 
IsHostSendRecv(const HloInstruction * instruction)393 bool IsHostSendRecv(const HloInstruction* instruction) {
394   const HloSendRecvInstruction* send_recv_instr =
395       DynCast<HloSendRecvInstruction>(instruction);
396   return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
397 }
398 
399 }  // namespace
400 
BuildHostChannelConstraints(HloComputation * computation)401 Status LayoutAssignment::BuildHostChannelConstraints(
402     HloComputation* computation) {
403   for (auto* instruction : computation->instructions()) {
404     const HloSendRecvInstruction* send_recv_instr =
405         DynCast<HloSendRecvInstruction>(instruction);
406     if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
407       continue;
408     }
409 
410     // For host transfers the Send and Recv instruction carry the layout.
411     if (instruction->opcode() == HloOpcode::kSend ||
412         instruction->opcode() == HloOpcode::kRecv) {
413       const Shape& data_shape =
414           ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
415       TF_RET_CHECK(data_shape.IsArray());
416       TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
417       const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
418           *send_recv_instr->channel_id(), data_shape.layout());
419       TF_RET_CHECK(prev_layout == nullptr)
420           << "Cannot constrain host transfer layout as it was set to "
421           << LayoutUtil::HumanString(*prev_layout) << ": "
422           << send_recv_instr->ToString();
423     }
424   }
425   return Status::OK();
426 }
427 
428 namespace {
429 
IsLayoutConstrainedCustomCall(HloInstruction * instruction)430 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
431   const HloCustomCallInstruction* custom_call =
432       DynCast<HloCustomCallInstruction>(instruction);
433   return custom_call != nullptr && custom_call->layout_constrained();
434 }
435 
IsLayoutConstrainedCollective(const HloInstruction * instruction)436 bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
437   const HloCollectiveInstruction* collective =
438       DynCast<HloCollectiveInstruction>(instruction);
439   return collective != nullptr && collective->constrain_layout();
440 }
441 
442 }  // namespace
443 
// Adds the layout constraints that the graph itself forces on `computation`:
// infeed/outfeed shapes, parameter layouts from the ComputationLayout,
// layout-constrained custom calls and collectives, constrained channels
// (send/recv, cross-module all-reduce), and the layouts of already-assigned
// callee computations (call/while/conditional). Finally pins the result
// layout when one is known.
AddMandatoryConstraints(const ComputationLayout * computation_layout,ChannelLayoutConstraints * channel_constraints,HloComputation * computation,LayoutConstraints * constraints)444 Status LayoutAssignment::AddMandatoryConstraints(
445     const ComputationLayout* computation_layout,
446     ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
447     LayoutConstraints* constraints) {
448   VLOG(3) << "Adding mandatory layout constraints to computation "
449           << computation->name();
450 
  // Host transfers use the module-wide host channel constraints; every other
  // channel instruction uses the caller-provided `channel_constraints`.
451   auto get_channel_constraints = [&](const HloInstruction* instruction) {
452     return IsHostSendRecv(instruction) ? &host_channel_constraints_
453                                        : channel_constraints;
454   };
455 
456   // Constrain layouts of instructions which define values with pre-existing
457   // layouts.
458   for (auto* instruction : computation->instructions()) {
459     if (instruction->opcode() == HloOpcode::kInfeed) {
460       // Infeed layouts must match the layout of the original inserted
461       // instruction.
462       // TODO(b/31425034): Change infeeds to be more like parameters, with
463       // shapes in the ComputationLayout.
464       TF_RETURN_IF_ERROR(
465           constraints->SetInstructionLayout(instruction->shape(), instruction));
466     } else if (instruction->opcode() == HloOpcode::kOutfeed) {
467       // Constrain the input to the Outfeed instruction to be the expected
468       // layout of the Outfeed.
469       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
470           instruction->outfeed_shape(), instruction, 0));
471     } else if (instruction->opcode() == HloOpcode::kParameter) {
472       if (computation_layout != nullptr) {
473         const ShapeLayout& parameter_layout =
474             computation_layout->parameter_layout(
475                 instruction->parameter_number());
476         // Parameter layouts must match the respective layout in
477         // ComputationLayout, if there is one.
478         TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
479             parameter_layout.shape(), instruction));
480       }
481     } else if (IsLayoutConstrainedCustomCall(instruction)) {
      // Layout-constrained custom calls fix both their result layout and the
      // layout of every operand.
482       const HloCustomCallInstruction* custom_call =
483           DynCast<HloCustomCallInstruction>(instruction);
484       TF_RETURN_IF_ERROR(
485           constraints->SetInstructionLayout(custom_call->shape(), custom_call));
486       for (int64 i = 0; i < custom_call->operand_count(); ++i) {
487         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
488             custom_call->operand_shapes_with_layout()[i], custom_call, i));
489       }
490     } else if (instruction->opcode() == HloOpcode::kSend ||
491                instruction->opcode() == HloOpcode::kRecv) {
492       CHECK(get_channel_constraints(instruction))
493           << "Multi-module layout assignment requires ChannelLayoutConstraints";
494       int64 channel_id = *instruction->channel_id();
495       if (!get_channel_constraints(instruction)
496                ->IsChannelConstrained(channel_id)) {
497         continue;
498       }
499       if (instruction->opcode() == HloOpcode::kSend) {
500         // TODO(b/68493863): Change to use SetOperandLayout().
501         const Shape send_buffer_shape = instruction->operand(0)->shape();
502         TF_RET_CHECK(send_buffer_shape.IsArray());
503         Shape new_buffer_shape =
504             get_channel_constraints(instruction)
505                 ->LayoutShapeForChannel(send_buffer_shape,
506                                         *instruction->channel_id());
507         TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
508             new_buffer_shape, instruction->operand(0)));
509       } else {
        // Recv: constrain the buffer defined at tuple index {0} (the data
        // element of the Recv's result tuple) to the channel's layout.
510         const Shape recv_buffer_shape =
511             ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
512         TF_RET_CHECK(recv_buffer_shape.IsArray());
513         TF_ASSIGN_OR_RETURN(
514             const LogicalBuffer* buffer,
515             constraints->points_to_analysis().GetBufferDefinedAt(instruction,
516                                                                  {0}));
517         Shape new_shape =
518             get_channel_constraints(instruction)
519                 ->LayoutShapeForChannel(recv_buffer_shape,
520                                         *instruction->channel_id());
521         TF_RETURN_IF_ERROR(
522             constraints->SetBufferLayout(new_shape.layout(), *buffer));
523       }
524     } else if (IsLayoutConstrainedCollective(instruction)) {
525       TF_RETURN_IF_ERROR(
526           constraints->SetInstructionLayout(instruction->shape(), instruction));
527     } else if (instruction->IsCrossModuleAllReduce()) {
528       CHECK(get_channel_constraints(instruction))
529           << "Multi-module layout assignment requires ChannelLayoutConstraints";
530       int64 channel_id = instruction->channel_id().value();
531       if (!get_channel_constraints(instruction)
532                ->IsChannelConstrained(channel_id)) {
533         continue;
534       }
535       // TODO(b/68493863): Change to use SetOperandLayout().
536       const Shape& buffer_shape = instruction->operand(0)->shape();
537       TF_RET_CHECK(buffer_shape.IsArray());
538       Shape new_buffer_shape =
539           get_channel_constraints(instruction)
540               ->LayoutShapeForChannel(buffer_shape, channel_id);
541       TF_RETURN_IF_ERROR(
542           constraints->SetInstructionLayout(new_buffer_shape, instruction));
543     }
544   }
545 
546   // Constrain layouts of instructions which call computations which have
547   // already been assigned layouts. Instructions which call computations in a
548   // parallel element-wise context (eg, map or reduce) do not need layout
549   // constraints because they operate on scalars.
550   for (auto* instruction : computation->instructions()) {
551     if (instruction->opcode() == HloOpcode::kCall) {
552       // kCall instruction operands and output must match the ComputationLayout
553       // of the called computation.
554       const ComputationLayout& called_computation_layout =
555           FindOrDie(computation_layouts_, instruction->to_apply());
556       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
557           called_computation_layout.result_layout().shape(), instruction));
558       TF_RET_CHECK(instruction->operand_count() ==
559                    called_computation_layout.parameter_count());
560       for (int64 i = 0; i < instruction->operand_count(); ++i) {
561         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
562             called_computation_layout.parameter_layout(i).shape(), instruction,
563             i));
564       }
565     } else if (instruction->opcode() == HloOpcode::kWhile) {
566       // Layout of input and output of kWhile instruction must be equal and must
567       // match both input and output of body computation. Also, the input of
568       // condition computation must match kWhile layout.
569       HloComputation* body = instruction->while_body();
570       HloComputation* condition = instruction->while_condition();
571       const HloInstruction* init = instruction->operand(0);
572       ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
573       ComputationLayout& condition_layout =
574           FindOrDie(computation_layouts_, condition);
575 
576       // Check a few invariants irrespective of layout.
577       CHECK_EQ(1, instruction->operand_count());
578       CHECK_EQ(1, body->num_parameters());
579       CHECK_EQ(1, condition->num_parameters());
580       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
581                                    body_layout.parameter_shape(0)));
582       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
583                                    condition_layout.parameter_shape(0)));
584       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
585 
      // The body's result layout wins: force the body parameter (and, below,
      // the condition parameter) to agree with it.
586       if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
587         VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
588                 << " while=" << instruction->name()
589                 << " shape=" << body_layout.result_layout().ToString();
590         *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
591       }
592       if (condition_layout.parameter_layout(0) !=
593           body_layout.parameter_layout(0)) {
594         VLOG(2) << "Reset %while condition parameter layout: cond="
595                 << condition->name() << " while=" << instruction->name()
596                 << " shape=" << body_layout.parameter_layout(0).ToString();
597         *condition_layout.mutable_parameter_layout(0) =
598             body_layout.parameter_layout(0);
599       }
600 
601       // Constrain the output and the operand of the while instruction to match
602       // the computations.
603       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
604           body_layout.result_shape(), instruction, 0));
605       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
606           body_layout.result_shape(), instruction));
607     } else if (instruction->opcode() == HloOpcode::kConditional) {
608       // Find the conditional branch with the most instructions and force all
609       // other computations to match that layout. A potentially better decision
610       // could count the number FLOPs or how constrained the layouts are.
611       int64 largest_branch = 0;
612       int64 largest_instruction_count =
613           instruction->branch_computation(0)->instruction_count();
614       for (int j = 1; j < instruction->branch_count(); ++j) {
615         const int64 instruction_count =
616             instruction->branch_computation(j)->instruction_count();
617         if (instruction_count > largest_instruction_count) {
618           largest_branch = j;
619           largest_instruction_count = instruction_count;
620         }
621       }
622       ComputationLayout& best_branch_computation_layout =
623           FindOrDie(computation_layouts_,
624                     instruction->branch_computation(largest_branch));
625       for (int k = 0; k < instruction->branch_count(); ++k) {
626         // Visit the best branch first.
627         int j = (k + largest_branch) % instruction->branch_count();
628         TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
        // NOTE(review): `j` is the rotated visit index but the lookups and
        // SetOperandLayout calls below use the raw loop variable `k` — so the
        // "visit best branch first" rotation only affects the TF_RET_CHECK
        // above. Confirm this j/k mix is intentional.
629         ComputationLayout& branch_computation_layout =
630             FindOrDie(computation_layouts_, instruction->branch_computation(k));
631         if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
632                 best_branch_computation_layout.result_layout().shape(),
633                 /*minor_to_major_only=*/true)) {
          // Branch result layout disagrees with the best branch: defer it by
          // recording the mismatch and dropping its computed layout.
634           computation_layouts_.erase(instruction->branch_computation(k));
635           InsertOrDie(&conditional_mismatch_,
636                       instruction->branch_computation(k),
637                       best_branch_computation_layout);
638         } else {
          // Operand k+1 feeds branch k (operand 0 is the branch selector).
639           TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
640               branch_computation_layout.parameter_shape(0), instruction, k + 1,
641               /*mandatory=*/true));
642         }
643       }
644       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
645           best_branch_computation_layout.parameter_shape(0), instruction,
646           largest_branch + 1,
647           /*mandatory=*/true));
648       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
649           best_branch_computation_layout.result_shape(), instruction));
650     }
651   }
652   // Finally set the result layout to match ComputationLayout, if there is one.
653   if (conditional_mismatch_.count(computation) > 0) {
654     TF_RETURN_IF_ERROR(constraints->SetResultLayout(
655         FindOrDie(conditional_mismatch_, computation).result_layout().shape()));
656   } else if (computation_layout != nullptr) {
657     const ShapeLayout& result_layout = computation_layout->result_layout();
658     if (result_layout.LayoutIsSet()) {
659       TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
660     }
661   }
662   return Status::OK();
663 }
664 
665 namespace {
666 
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)667 bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
668   return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
669 }
670 
671 // The operands of a call must match the layouts of parameters in the
672 // ComputationLayout, and the call instruction itself must match the result
673 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)674 Status CheckCallLayout(HloInstruction* call,
675                        const ComputationLayout& computation_layout) {
676   HloComputation* computation = call->to_apply();
677   TF_RET_CHECK(computation->num_parameters() == call->operand_count());
678   for (int64 i = 0; i < computation->num_parameters(); ++i) {
679     TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
680         call->operand(i)->shape(), /*minor_to_major_only=*/true));
681   }
682   TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
683       call->shape(), /*minor_to_major_only=*/true));
684   return Status::OK();
685 }
686 
687 // Operands of layout-constrained custom calls must match the expected
688 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)689 Status CheckCustomCallLayout(HloInstruction* instruction) {
690   if (IsLayoutConstrainedCustomCall(instruction)) {
691     const HloCustomCallInstruction* custom_call =
692         DynCast<HloCustomCallInstruction>(instruction);
693     for (int64 i = 0; i < custom_call->operand_count(); ++i) {
694       TF_RET_CHECK(
695           LayoutsInShapesEqual(custom_call->operand(i)->shape(),
696                                custom_call->operand_shapes_with_layout()[i]));
697     }
698   }
699   return Status::OK();
700 }
701 
702 // For a while instruction, all the following layouts must be the same:
703 //   (1) init operand
704 //   (2) condition computation parameter
705 //   (3) body computation parameter
706 //   (4) body computation result
707 //   (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)708 Status CheckWhileLayout(HloInstruction* while_inst,
709                         const ComputationLayout& condition_computation_layout,
710                         const ComputationLayout& body_computation_layout) {
711   auto init_shape = while_inst->operand(0)->shape();
712   TF_RET_CHECK(
713       condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
714           init_shape, /*minor_to_major_only=*/true));
715   TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
716       init_shape, /*minor_to_major_only=*/true));
717   TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
718       init_shape, /*minor_to_major_only=*/true));
719   TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
720   return Status::OK();
721 }
722 
// Checks the layout invariants of a kConditional instruction: every branch
// computation's result layout must match branch 0's result layout, the
// conditional's own shape, and the branch's root instruction shape; and every
// branch's parameter layout must match the operand fed to that branch.
Status CheckConditionalLayout(
    HloInstruction* instruction,
    absl::Span<const ComputationLayout> branch_computation_layouts) {
  for (int j = 0; j < instruction->branch_count(); ++j) {
    // Operand 0 of a conditional is the branch selector; branch j consumes
    // operand j + 1.
    const HloInstruction* branch_operand = instruction->operand(j + 1);
    // All branches agree with branch 0's result layout.
    TF_RET_CHECK(
        branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
            branch_computation_layouts[j].result_layout().shape(),
            /*minor_to_major_only=*/true));
    // The branch result layout matches the conditional's output shape.
    TF_RET_CHECK(
        branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
            instruction->shape(), /*minor_to_major_only=*/true));
    // ... and the shape of the branch computation's root instruction.
    TF_RET_CHECK(
        branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
            instruction->branch_computation(j)->root_instruction()->shape(),
            /*minor_to_major_only=*/true));
    // The branch parameter layout matches the operand passed into the branch.
    TF_RET_CHECK(
        branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
            branch_operand->shape(), /*minor_to_major_only=*/true));
  }
  return Status::OK();
}
745 
746 // Fusion parameters must match the layout of the fusion instructions operands,
747 // and the root of the fusion expression must match the layout of the fusion
748 // instruction.
CheckFusionLayout(HloInstruction * fusion)749 Status CheckFusionLayout(HloInstruction* fusion) {
750   TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
751 
752   TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
753                                     fusion->fused_expression_root()->shape()));
754   for (int64 i = 0; i < fusion->operand_count(); ++i) {
755     TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
756                                       fusion->operand(i)->shape()));
757   }
758   return Status::OK();
759 }
760 
761 // The layout of a parameter must match the respective layout in the
762 // computation's ComputationLayout.
// Verifies that each array leaf of the parameter's shape matches the layout
// recorded for it in the computation's ComputationLayout. Leaves whose layout
// is unset in the ComputationLayout are not constrained.
Status CheckParameterLayout(HloInstruction* parameter,
                            const ComputationLayout& computation_layout) {
  const ShapeLayout& parameter_layout =
      computation_layout.parameter_layout(parameter->parameter_number());
  return ShapeUtil::ForEachSubshapeWithStatus(
      parameter_layout.shape(),
      [&](const Shape& subshape, const ShapeIndex& shape_index) {
        // Only leaf subshapes carry layouts; skip interior tuple nodes and
        // leaves for which no layout was assigned in the ComputationLayout.
        if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
            !subshape.has_layout()) {
          return Status::OK();
        }
        // Compare minor-to-major ordering only, ignoring dynamic dimensions.
        if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
                subshape,
                ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
          return InternalError(
              "parameter instruction %s does not match layout of computation "
              "shape: %s",
              parameter->ToString(), parameter_layout.ToString());
        }
        return Status::OK();
      });
}
785 
786 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)787 Status CheckConstantLayout(HloInstruction* constant) {
788   if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
789     return InternalError(
790         "constant instruction %s does not match the layout of its literal %s",
791         constant->ToString(),
792         ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
793   }
794   return Status::OK();
795 }
796 
797 }  // namespace
798 
// Creates a copy of `instruction` whose shape carries the layout of
// `shape_with_layout` (which must be compatible, i.e. same shape tree).
// Tuples are deep-copied element-wise, reusing elements whose layout already
// matches; arrays become a kCopy instruction. Fails for any other shape kind.
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      // Extract element i so it can be (possibly) copied individually.
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
                                                    instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // Clear and re-copy layouts so the new tuple carries exactly the layouts
    // of shape_with_layout rather than whatever CreateTuple inferred.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    // NOTE: only the array kCopy is registered as an added copy; the tuple
    // wrapper above is not registered.
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
855 
856 // Creates a copy of the given operand if the operand's layout does not match
857 // the given layout. This copy replaces the use in the given instruction. Tuple
858 // operands will be deep-copied.
// Inserts a layout-converting copy for `instruction`'s operand `operand_no`
// when its current layout differs from `operand_layout`. For a conditional
// whose branch operand has a single user, the copy is instead placed inside
// the branch computation (rewriting the branch parameter) so that other
// branches do not pay for it.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64 operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
                                                operand->shape())) {
    VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }
  VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // If the operand is only used by a conditional, do the copy inside the branch
  // to avoid overhead for other branches.
  if (instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
      instruction->operand(operand_no)->user_count() == 1) {
    // Operand operand_no feeds branch computation operand_no - 1 (operand 0
    // is the branch selector).
    auto branch_comp = instruction->branch_computation(operand_no - 1);
    auto param = branch_comp->parameter_instruction(0);
    // Let the parameter take the operand's existing layout, then convert to
    // the required layout inside the branch.
    *param->mutable_shape() = operand->shape();
    // Snapshot the users before creating the copy, since the copy itself
    // becomes a user of the parameter.
    auto param_users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
                        CreateCopyWithNewLayout(operand_layout.shape(), param));
    for (auto user : param_users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
    }
    VLOG(4) << "New copy of " << operand->ToString() << " is "
            << param_copy->ToString();
    if (param == branch_comp->root_instruction()) {
      branch_comp->set_root_instruction(param_copy,
                                        /*accept_different_shape=*/true);
    }
    // Record the branch computation's new parameter layout.
    *FindOrDie(computation_layouts_, branch_comp).mutable_parameter_layout(0) =
        ShapeLayout(operand->shape());
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
907 
// Propagates metadata (and, selectively, sharding) from `instruction` to a
// freshly created `copy` of its value at `index`.
void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
                                              HloInstruction* copy,
                                              const ShapeIndex& index) {
  if (instruction.has_sharding()) {
    // If the index is empty, we want to copy the whole sharding, in case the
    // sharding is a tuple sharding.
    HloSharding sharding =
        !index.empty() && instruction.sharding().IsTuple()
            ? instruction.sharding().GetSubSharding(instruction.shape(), index)
            : instruction.sharding();
    // We propagate the sharding to the copied instruction only if it is a
    // special sharding, like tiled ones.
    // Otherwise it is preferable to leave the new instruction without device,
    // and let the automatic device placer to choose the best location.
    auto device = sharding.UniqueDevice();
    if (!device || HloSharding::IsReservedDevice(*device)) {
      copy->set_sharding(sharding);
    }
  }
  // Metadata (op name/source info) is always carried over to the copy.
  copy->set_metadata(instruction.metadata());
}
929 
// Post-pass verifier: checks that every instruction in every non-fusion
// computation has a complete, valid layout; that each output element's layout
// agrees with the logical buffers that may produce it; and that instructions
// with special layout contracts (call, custom-call, fusion, parameter,
// constant, while, conditional) satisfy them. Also checks the entry result
// layout when one is set.
Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              for (const LogicalBuffer* buffer : buffers) {
                // Only minor-to-major ordering is compared; dynamic
                // dimensions are ignored.
                if (!Shape::Equal()
                         .IgnoreDynamicDimension()
                         .MinorToMajorOnlyInLayout()(instruction_subshape,
                                                     buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional: {
          // Collect per-branch ComputationLayouts so the conditional check
          // can compare them against each other and the instruction.
          std::vector<ComputationLayout> branch_computation_layouts;
          for (auto branch_computation : instruction->branch_computations()) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation));
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            module->result_shape(), result_layout.shape()));
  }
  return Status::OK();
}
1023 
// Constructs the pass. `entry_computation_layout` holds the (possibly
// partially unset) parameter/result layouts of the entry computation; a copy
// is saved so it can be restored later. `instruction_can_change_layout_func`
// is consulted during propagation to decide whether an instruction's output
// layout may differ from its operands'. `channel_constraints` may be null.
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    std::function<bool(const HloInstruction*)>
        instruction_can_change_layout_func,
    ChannelLayoutConstraints* channel_constraints)
    : entry_computation_layout_(entry_computation_layout),

      saved_entry_computation_layout_(*entry_computation_layout),
      channel_layout_constraints_(channel_constraints),
      instruction_can_change_layout_func_(
          std::move(instruction_can_change_layout_func)) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
1043 
// Given the layout chosen for `instruction`'s (array) output, suggests a
// layout for array operand `operand_no`, or returns nullptr when no
// preference exists. Used during backward constraint propagation.
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
    const Layout& output_layout, const HloInstruction* instruction,
    int64 operand_no) {
  const HloInstruction* operand = instruction->operand(operand_no);
  CHECK(instruction->shape().IsArray());
  CHECK(operand->shape().IsArray());
  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == instruction->shape().rank() &&
      !instruction_can_change_layout_func_(instruction)) {
    // Propagate the result layout to the operand layout if the instruction
    // requires the same layout out for the result and the operand.
    //
    // For elementwise operations, using the same layout for the operands and
    // the result also has the following benefits:
    // 1) the elementwise operation can reuse its operand's buffer, and
    // 2) the input and output elements can reuse the same linear index.
    return absl::make_unique<Layout>(output_layout);
  }

  if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape an bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the operand's layout to the output.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(instruction->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }

    const Shape& output_shape = instruction->shape();
    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
        LayoutUtil::MinorToMajor(output_layout));
    // Start the operand from its default layout, then try to align it with
    // the output so the reshape becomes a bitcast.
    Shape operand_shape = operand->shape();
    *operand_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(operand_shape);
    auto aligned_operand_shape =
        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
    if (aligned_operand_shape) {
      auto operand_layout = aligned_operand_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
      return absl::make_unique<Layout>(operand_layout);
    }
  }

  if (instruction->opcode() == HloOpcode::kTranspose) {
    // Pick the operand layout that makes the transpose a bitcast.
    int64 rank = instruction->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    for (int64 i = 0; i < rank; ++i) {
      // Map the output's i-th minor dimension through the transpose's
      // dimension permutation to find the corresponding operand dimension.
      int64 output_dim = LayoutUtil::Minor(output_layout, i);
      int64 operand_dim = instruction->dimensions(output_dim);
      new_minor_to_major[i] = operand_dim;
    }
    Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(
        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
    return absl::make_unique<Layout>(operand_layout);
  }

  return nullptr;
}
1110 
// Mirror of ChooseOperandLayoutFromOutputLayout for forward propagation:
// given the layout chosen for `user`'s operand `operand_no`, suggests a
// layout for `user`'s (array) output, or returns nullptr when no preference
// exists.
std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
    const Layout& operand_layout, const HloInstruction* user,
    int64 operand_no) {
  const HloInstruction* operand = user->operand(operand_no);

  CHECK(user->shape().IsArray() && operand->shape().IsArray());

  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == user->shape().rank() &&
      !instruction_can_change_layout_func_(user)) {
    // Assign users the same layout as the operand.
    return absl::make_unique<Layout>(operand_layout);
  }

  if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape an bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the outputs's layout to the operand.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(user->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        operand->shape().element_type(),
        AsInt64Slice(operand->shape().dimensions()),
        LayoutUtil::MinorToMajor(operand_layout));
    // Start the output from its default layout, then try to align it with
    // the operand so the reshape becomes a bitcast.
    Shape output_shape = user->shape();
    *output_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(output_shape);
    auto aligned_user_shape =
        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
    if (aligned_user_shape) {
      auto user_layout = aligned_user_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
      return absl::make_unique<Layout>(user_layout);
    }
  }

  if (user->opcode() == HloOpcode::kTranspose) {
    // Pick the user layout that makes the transpose a bitcast.
    int64 rank = user->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    // Invert the transpose permutation to map operand dimensions back to the
    // user's output dimensions.
    auto inverse_dimensions = InversePermutation(user->dimensions());
    for (int64 i = 0; i < rank; ++i) {
      int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
      int64 user_dim = inverse_dimensions[operand_dim];
      new_minor_to_major[i] = user_dim;
    }
    Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
    return absl::make_unique<Layout>(user_layout);
  }

  return nullptr;
}
1171 
// Drains the constraint worklist, propagating each buffer/operand/result
// layout constraint to its neighbors until a fixed point is reached.
Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
  // Gathers all initial constraints in a worklist and propagates them in
  // depth-first order. DFS order seems to be better than BFS because a
  // constraint is propagated as far as possible before propagating unrelated
  // constraints which makes it less likely that conflicting constraints will be
  // propagated to instructions. However, we should experiment with other orders
  // too.
  std::deque<const LayoutConstraint*> worklist;

  // Lambda for moving newly added constraints to the worklist.
  auto add_new_constraints_to_worklist = [constraints, &worklist]() {
    // Add constraints to the front of the deque for DFS ordering.
    for (auto* constraint : constraints->ConsumeAddedConstraints()) {
      if (constraint->dfs()) {
        worklist.push_front(constraint);
      } else {
        // Non-DFS constraints are processed breadth-first, after all
        // DFS-ordered ones currently in the queue.
        worklist.push_back(constraint);
      }
    }
  };
  add_new_constraints_to_worklist();

  while (!worklist.empty()) {
    const LayoutConstraint* layout_constraint = worklist.front();
    worklist.pop_front();
    VLOG(2) << "Propagating " << layout_constraint->ToString()
            << " to its neighbors.";
    // Dispatch on the concrete constraint type (buffer, operand, or result).
    if (auto* buffer_constraint =
            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateBufferConstraint(*buffer_constraint, constraints));
    } else if (auto* operand_constraint =
                   dynamic_cast<const OperandLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateOperandConstraint(*operand_constraint, constraints));
    } else if (auto* result_constraint =
                   dynamic_cast<const ResultLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateResultConstraint(*result_constraint, constraints));
    } else {
      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
    }

    // Propagation may have registered new constraints; pick them up before
    // the next iteration.
    add_new_constraints_to_worklist();
  }
  return Status::OK();
}
1221 
1222 namespace {
1223 
1224 // Returns a vector containing all array-shaped uses (instruction and operand
1225 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const LogicalBuffer & buffer,const TuplePointsToAnalysis & points_to_analysis)1226 std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
1227     const LogicalBuffer& buffer,
1228     const TuplePointsToAnalysis& points_to_analysis) {
1229   CHECK(buffer.IsArray());
1230   std::vector<std::pair<const HloInstruction*, int64>> uses;
1231   for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
1232     if (!buffer_alias.instruction()->shape().IsArray()) {
1233       continue;
1234     }
1235     // This alias must be the top-level (index == {}) of the instruction's
1236     // result because the instruction produces an array.
1237     CHECK(buffer_alias.index().empty());
1238 
1239     // Add all uses of the instruction's output.
1240     for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1241       for (int64 operand_no :
1242            user->OperandIndices(buffer_alias.instruction())) {
1243         uses.emplace_back(user, operand_no);
1244       }
1245     }
1246   }
1247   return uses;
1248 }
1249 
1250 }  // namespace
1251 
// Propagates a layout required at a use site back to the logical buffers that
// may define the used value, constraining any buffer whose layout is still
// unassigned.
Status LayoutAssignment::PropagateUseConstraintToDefs(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    LayoutConstraints* constraints) {
  // Try to set all logical buffers which may be sources of the given operand to
  // match the given layout.
  const PointsToSet& points_to_set =
      constraints->points_to_analysis().GetPointsToSet(instruction);
  return points_to_set.ForEachElementWithStatus(
      [&shape_layout, constraints](
          const ShapeIndex& index,
          const PointsToSet::BufferList& buffers) -> Status {
        if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
          for (const LogicalBuffer* buffer : buffers) {
            // Only constrain array buffers that have no layout assigned yet;
            // already-constrained buffers are left untouched.
            if (constraints->BufferLayout(*buffer) == nullptr &&
                buffer->shape().IsArray()) {
              TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                  ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
                  *buffer, /*mandatory=*/true));
            }
          }
        }
        return Status::OK();
      });
}
1276 
1277 namespace {
1278 // A transpose or a reshape that only changes trivial dimensions have meaningful
1279 // layouts that are valuable to propagate in a depthfirst manner to avoid
1280 // unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo,bool forward_propagation=true)1281 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
1282                                           bool forward_propagation = true) {
1283   switch (hlo.opcode()) {
1284     case HloOpcode::kFusion:
1285       return hlo.IsCustomFusion();
1286     case HloOpcode::kGather:
1287       return true;
1288     case HloOpcode::kReshape:
1289       return hlo.operand(0)->shape().rank() == 1 ||
1290              (forward_propagation &&
1291               std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
1292     case HloOpcode::kScatter:
1293     case HloOpcode::kTranspose:
1294       return true;
1295     default:
1296       return false;
1297   }
1298 }
1299 
1300 }  // namespace
1301 
// Propagates a layout constraint placed on one operand of an instruction to
// related parts of the graph: the buffers that define the operand's value,
// sibling operands of the same user (for non-layout-changing instructions),
// and the user's own output buffers.
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    // Tuple/token operands carry no array layout to propagate further.
    return Status::OK();
  }

  if (user->opcode() == HloOpcode::kAllReduce) {
    // A single-operand all-reduce produces an array (top-level buffer); a
    // multi-operand all-reduce produces a tuple, where operand i maps to
    // output tuple element i.
    const auto shape_index =
        user->operand_count() == 1
            ? ShapeIndex()
            : ShapeIndex({operand_constraint.operand_no()});
    TF_ASSIGN_OR_RETURN(const LogicalBuffer* buffer,
                        constraints->points_to_analysis().GetBufferDefinedAt(
                            user, shape_index));
    const BufferLayoutConstraint* constraint =
        constraints->GetBufferLayoutConstraint(*buffer);
    if (constraint == nullptr) {
      // Prefer matching the all-reduce output layout to its (constrained)
      // input layout, but not mandatorily.
      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          operand_constraint.shape_layout().layout(), *buffer,
          /*mandatory=*/false));
    }
  }
  if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
    // A layout-changing user with a non-array result gives us nothing further
    // to propagate.
    return Status::OK();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->OperandBufferForwarded(user,
                                          operand_constraint.operand_no())) {
    return Status::OK();
  }

  int64 operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    // Rank-0/1 shapes have a unique (trivial) layout; nothing to propagate.
    return Status::OK();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!instruction_can_change_layout_func_(user)) {
    // Only propagate the layout of the largest concatenate operand.
    if (user->opcode() == HloOpcode::kConcatenate) {
      for (int64 operand_no = 0; operand_no < user->operand_count();
           ++operand_no) {
        const HloInstruction* sibling = user->operand(operand_no);
        if (sibling == operand) {
          continue;
        }
        if (sibling->shape().dimensions(user->concatenate_dimension()) >
            operand->shape().dimensions(user->concatenate_dimension())) {
          // A larger sibling exists; let its layout win instead of ours.
          return Status::OK();
        }
      }
    }
    // Make sure all siblings have the same layout as the operand.
    for (int64 operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      const int64 sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        //   operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        //   we continue to propagate our layout to the result, and then
        //   backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false));
    }
    // Also propagate the operand's layout to the user's output buffers of the
    // same rank that the user itself defines.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return Status::OK();
          }
          if (subshape.rank() <= 1) {
            return Status::OK();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return Status::OK();
          }
          if (!constraints->points_to_analysis()
                   .InstructionDefinesBufferAtIndex(user, shape_index)) {
            return Status::OK();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(
                  user, shape_index));
          // Make sure the output has the same layout as the operand.
          const BufferLayoutConstraint* constraint =
              constraints->GetBufferLayoutConstraint(*buffer);
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          if (constraint == nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                operand_constraint.shape_layout().layout(), *buffer,
                /*mandatory=*/false));
          }
          return Status::OK();
        }));
    return Status::OK();
  }
  // The user can change layout: ask the backend hook for a low-cost output
  // layout given the constrained operand layout, and apply it where no
  // mandatory constraint exists.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return Status::OK();
        }
        if (subshape.rank() <= 1) {
          return Status::OK();
        }
        if (!constraints->points_to_analysis().InstructionDefinesBufferAtIndex(
                user, shape_index)) {
          return Status::OK();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(user,
                                                                 shape_index));
        if (constraints->BufferLayout(*buffer) == nullptr ||
            !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}
1482 
// Propagates a buffer's layout constraint backward to the operands of the
// instruction that defines the buffer.
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  if (IsAtMostRank1(instruction->shape())) {
    // Rank <= 1 shapes have a unique layout; nothing to propagate.
    return Status::OK();
  }

  if (instruction->opcode() == HloOpcode::kAllReduce) {
    // All-reduce: operand i must match the layout of output tuple element i
    // (or the sole operand must match the array output). This is mandatory.
    TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
        buffer_constraint.layout(), instruction,
        instruction->operand_count() == 1 ? 0 : buffer.index()[0],
        /*mandatory=*/true));
    return Status::OK();
  }
  for (int64 operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!instruction_can_change_layout_func_(instruction)) {
      // Copy the layout to the operand.
      // Only applicable when the operand rank matches the constrained
      // layout's rank (same-rank arrays share the layout directly).
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true));
      }
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no constraint already.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        // Ask the backend hook for a low-cost operand layout given the
        // constrained output layout.
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/
              InstructionShouldPropagateDepthFirst(
                  *instruction, /*forward_propagation=*/false)));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return Status::OK();
}
1546 
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1547 Status LayoutAssignment::PropagateBufferConstraint(
1548     const BufferLayoutConstraint& buffer_constraint,
1549     LayoutConstraints* constraints) {
1550   // Only propagate array layouts.
1551   const LogicalBuffer& buffer = buffer_constraint.buffer();
1552   if (!buffer.IsArray()) {
1553     return Status::OK();
1554   }
1555   TF_RETURN_IF_ERROR(
1556       PropagateBufferConstraintToUses(buffer_constraint, constraints));
1557   return PropagateBufferConstraintToOperands(buffer_constraint, constraints);
1558 }
1559 
PropagateBufferConstraintToUses(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1560 Status LayoutAssignment::PropagateBufferConstraintToUses(
1561     const BufferLayoutConstraint& buffer_constraint,
1562     LayoutConstraints* constraints) {
1563   const LogicalBuffer& buffer = buffer_constraint.buffer();
1564   TF_RET_CHECK(buffer.IsArray());
1565 
1566   // Propagate the layout to all array uses of the logical buffer. This skips
1567   // uses of the buffer where the buffer is the element of a tuple.
1568   for (const auto& user_operand_no :
1569        GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
1570     const HloInstruction* user = user_operand_no.first;
1571     int64 operand_no = user_operand_no.second;
1572     // Only add an operand constraint if the user does not forward the buffer
1573     // because this case is not handled is SetOperandLayout.
1574     if (constraints->OperandLayout(user, operand_no) == nullptr &&
1575         !constraints->OperandBufferForwarded(user, operand_no)) {
1576       TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
1577           buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
1578     }
1579   }
1580 
1581   // Propagate to backedges of kWhile.
1582   CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
1583   if (node.caller_callsites().size() != 1) {
1584     return Status::OK();
1585   }
1586   const HloInstruction* parent = node.caller_callsites()[0].instruction();
1587   if (parent->opcode() != HloOpcode::kWhile) {
1588     return Status::OK();
1589   }
1590 
1591   for (HloInstruction* user : buffer.instruction()->users()) {
1592     if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
1593       continue;
1594     }
1595     if (user->parent()->root_instruction() == user) {
1596       VLOG(3) << "Propagating layout through backedge"
1597               << buffer_constraint.layout().ToString();
1598       int64 index = user->operand_index(buffer.instruction());
1599       TF_ASSIGN_OR_RETURN(
1600           auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
1601                            user->parent()->parameter_instruction(0), {index}));
1602 
1603       TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
1604           buffer_constraint.layout(), *buffer, /*mandatory=*/false));
1605     }
1606   }
1607 
1608   return Status::OK();
1609 }
1610 
PropagateResultConstraint(const ResultLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1611 Status LayoutAssignment::PropagateResultConstraint(
1612     const ResultLayoutConstraint& layout_constraint,
1613     LayoutConstraints* constraints) {
1614   // Propagate the use constraint of the root instruction up to the logical
1615   // buffers which make up the result.
1616   return PropagateUseConstraintToDefs(
1617       layout_constraint.shape_layout(),
1618       constraints->computation()->root_instruction(), constraints);
1619 }
1620 
1621 namespace {
1622 
1623 // Infers the layout of the array at the given index in the given instruction's
1624 // output using points-to analysis. Precondition: The given instruction must
1625 // not produce this array value (that is, the array is forwarded from the
1626 // instruction's operands).
InferArrayLayout(const TuplePointsToAnalysis & points_to_analysis,HloInstruction * instruction,const ShapeIndex & index)1627 StatusOr<Layout> InferArrayLayout(
1628     const TuplePointsToAnalysis& points_to_analysis,
1629     HloInstruction* instruction, const ShapeIndex& index) {
1630   // This function should only be called for array shapes which don't yet have
1631   // layouts.
1632   const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
1633   TF_RET_CHECK(subshape.IsArray());
1634   TF_RET_CHECK(!subshape.has_layout());
1635 
1636   // The instruction should not define the buffer at this index.
1637   TF_RET_CHECK(
1638       !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
1639       << instruction->ToString();
1640 
1641   const auto& source_buffers =
1642       points_to_analysis.GetPointsToSet(instruction).element(index);
1643   TF_RET_CHECK(!source_buffers.empty());
1644 
1645   // Verify the layout is the same for every LogicalBuffer which this location
1646   // ('instruction' and 'index') points to.
1647   const Layout* first_buffer_layout = nullptr;
1648   for (const LogicalBuffer* source_buffer : source_buffers) {
1649     if (!source_buffer->shape().has_layout()) {
1650       // This should not happen because we've assigned layouts to all
1651       // instructions preceding this one.
1652       return InternalError("LogicalBuffer %s does not have a layout",
1653                            source_buffer->ToString());
1654     }
1655 
1656     if (first_buffer_layout == nullptr) {
1657       first_buffer_layout = &source_buffer->shape().layout();
1658     } else if (!Layout::Equal().MinorToMajorOnly()(
1659                    source_buffer->shape().layout(), *first_buffer_layout)) {
1660       // The points-to set is ambiguous for this index and the different source
1661       // buffers have different layouts. This case is possible in valid XLA
1662       // computations because we do not propagate BufferLayoutConstraints to all
1663       // LogicalBuffers which may alias the constrained LogicalBuffer at some
1664       // point in the computation.
1665       return FailedPrecondition(
1666           "Array at index {%s} in instruction %s aliases buffers %s "
1667           "and %s which have different layouts",
1668           absl::StrJoin(index, ","), instruction->name(),
1669           source_buffers[0]->ToString(), source_buffer->ToString());
1670     }
1671   }
1672 
1673   return *first_buffer_layout;
1674 }
1675 
1676 // For fusion instructions, set the layout of each fused parameter instruction
1677 // to match the layout of its corresponding fusion instruction operand. Also,
1678 // set the layout of the fused root to match the layout of the fusion
1679 // instruction itself.
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  // Walk the fused computation in post order; the branch order below matters
  // (e.g. the fused root is handled before the generic GTE/constant cases).
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Fused parameter i mirrors the layout of fusion operand i.
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes.
      // But do not clear layouts for other instructions in custom fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return Status::OK();
}
1721 
1722 }  // namespace
1723 
// Applies the computed layout constraints to every instruction of the
// computation: sets layouts on defined buffers, infers layouts for forwarded
// buffers, inserts copies where operand constraints disagree with producers,
// fixes up fusion internals, and verifies the result.
Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
                                       HloComputation* computation) {
  VLOG(2) << "Assigning layouts to computation: " << computation->name();
  XLA_VLOG_LINES(2, computation->ToString());
  XLA_VLOG_LINES(2, constraints.ToString());

  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    // Start from a clean slate so stale layouts can never leak through.
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Set the layouts of the array shapes this instruction defines as indicated
    // by the respective BufferLayoutConstraints. Any array shapes in the output
    // of the instruction which are not defined by the instruction (eg, array
    // elements in a Tuple instruction) will be assigned below via inference.
    for (const LogicalBuffer* buffer :
         constraints.points_to_analysis().GetBuffersDefinedByInstruction(
             instruction)) {
      if (!buffer->shape().IsArray()) {
        continue;
      }

      TF_RET_CHECK(buffer->instruction() == instruction);
      const Layout* buffer_layout = constraints.BufferLayout(*buffer);
      // Every buffer must be constrained by the time we get here.
      TF_RET_CHECK(buffer_layout != nullptr);

      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(*buffer_layout, buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = *buffer_layout;
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !subshape->IsArray()) {
            return Status::OK();
          }
          // Set Layout of subshape to match layout of LogicalBuffer which
          // produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(constraints.points_to_analysis(),
                                               instruction, index));
          return Status::OK();
        }));

    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      }
    }

    // Fusion instructions require some layouts to be set on fused instructions
    // inside the fusion instruction.
    if (instruction->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    // Execute extra verification step once the layout has been finalized.
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Shape must be valid.
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }
  return Status::OK();
}
1805 
CalculateComputationLayout(HloComputation * computation)1806 Status LayoutAssignment::CalculateComputationLayout(
1807     HloComputation* computation) {
1808   ComputationLayout computation_layout(computation->ComputeProgramShape(),
1809                                        /*ignore_layouts=*/false);
1810   InsertOrDie(&computation_layouts_, computation, computation_layout);
1811   VLOG(2) << "  Calculated ComputationLayout = "
1812           << computation_layout.ToString();
1813   return Status::OK();
1814 }
1815 
ClearComputationLayouts(HloComputation * computation)1816 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
1817   // Clear existing layouts of the instructions.  All layouts must be assigned
1818   // by the LayoutAssignment pass, except for those on parameters, the
1819   // computation result, and a couple special cases. The former two are
1820   // specified in computation_layout.  Clearing the layouts here avoids hiding
1821   // potential bugs in the layout assignment pass that may accidentally use the
1822   // existing layout.
1823   for (HloInstruction* instruction : computation->instructions()) {
1824     if (instruction->opcode() == HloOpcode::kBitcast) {
1825       // bitcasts are inherently layout sensitive and so a bitcast instruction
1826       // present in the IR before layout assignment is a bug.
1827       return InternalError(
1828           "Unexpected bitcast operation seen during layout assignment: %s.",
1829           instruction->ToString());
1830     }
1831     // Some instructions carry mandatory layouts in their shape.
1832     if (instruction->opcode() != HloOpcode::kInfeed &&
1833         !IsLayoutConstrainedCustomCall(instruction) &&
1834         !IsLayoutConstrainedCollective(instruction)) {
1835       LayoutUtil::ClearLayout(instruction->mutable_shape());
1836     }
1837   }
1838   return Status::OK();
1839 }
1840 
RunOnComputation(ComputationLayout * computation_layout,HloComputation * computation,ChannelLayoutConstraints * channel_constraints)1841 Status LayoutAssignment::RunOnComputation(
1842     ComputationLayout* computation_layout, HloComputation* computation,
1843     ChannelLayoutConstraints* channel_constraints) {
1844   VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
1845           << ")";
1846 
1847   // Must be run before clearing layouts.
1848   TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));
1849 
1850   TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
1851   if (computation_layout != nullptr) {
1852     auto it = computation_layouts_.find(computation);
1853     if (it == computation_layouts_.end()) {
1854       VLOG(2) << "  New ComputationLayout = " << computation_layout->ToString();
1855       computation_layouts_.emplace(computation, *computation_layout);
1856     } else {
1857       TF_RET_CHECK(computation_layout == &it->second ||
1858                    computation_layout == entry_computation_layout_);
1859       VLOG(2) << "  Existing ComputationLayout = "
1860               << computation_layout->ToString();
1861     }
1862   } else {
1863     VLOG(2) << "  No ComputationLayout specified (will be calculated)";
1864   }
1865 
1866   // Construct LayoutConstraints with all layout constraints of the computation.
1867   LayoutConstraints constraints(*points_to_analysis_, computation);
1868 
1869   // Add constraints required for correctness on all backends (eg, entry
1870   // parameter layout constraints).
1871   TF_RETURN_IF_ERROR(AddMandatoryConstraints(
1872       computation_layout, channel_constraints, computation, &constraints));
1873 
1874   // Add any backend-specific constraints.
1875   TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
1876 
1877   // Propagates layouts from mandatory and backend constraints.
1878   TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
1879 
1880   // Prior to applying default layouts, we take note of all HLO instructions
1881   // which lack a layout constraint.
1882   for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
1883     unconstrained_layout_instructions_.insert(
1884         points_to_analysis_->GetBuffer(buffer_id).instruction());
1885   }
1886 
1887   // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
1888   // layout and propagate the change.
1889   while (!constraints.unconstrained_buffer_ids().empty()) {
1890     int unconstrained_count = constraints.unconstrained_buffer_ids().size();
1891 
1892     // Arbitrarily pick the first unconstrained buffer and give it the default
1893     // layout (or the literal layout, in case of constants). By construction
1894     // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
1895     const LogicalBuffer& buffer = points_to_analysis_->GetBuffer(
1896         *constraints.unconstrained_buffer_ids().begin());
1897     const HloInstruction* instruction = buffer.instruction();
1898     Layout new_layout =
1899         instruction->opcode() == HloOpcode::kConstant
1900             ? ShapeUtil::GetSubshape(instruction->literal().shape(),
1901                                      buffer.index())
1902                   .layout()
1903             : GetUnconstrainedLayout(buffer);
1904     TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
1905                                                    /*mandatory=*/false));
1906 
1907     TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
1908 
1909     // To verify progress has been made, check that the number of unconstrained
1910     // buffers has been reduced.
1911     CHECK_LT(constraints.unconstrained_buffer_ids().size(),
1912              unconstrained_count);
1913   }
1914   // All logical buffers should have constraints at this point. All that
1915   // remains is assign the constraints to the buffers and infer layouts for
1916   // aliased buffers.
1917   TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));
1918 
1919   // If the computation layout wasn't specified, now it is the time to compute
1920   // it according to the parameters and root instruction layouts.
1921   // This allows the first pass through this API to record the best flowing
1922   // layout to parameters and root instruction.
1923   if (computation_layout == nullptr) {
1924     TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
1925   }
1926 
1927   // Record the layouts assigned for any communication ops in
1928   // channel_constraints so that they are constrained for future modules.
1929   if (channel_constraints != nullptr) {
1930     TF_RETURN_IF_ERROR(
1931         ConstrainChannelLayouts(computation, channel_constraints));
1932   }
1933 
1934   // Copy the root instruction's result if its layout does not match the result
1935   // layout constraint.
1936   if (constraints.ResultLayout() != nullptr) {
1937     // Layout assignment at this point only does minor-to-major assignment so
1938     // tiling info should be ignored here for comparison.
1939     if (!constraints.ResultLayout()->MatchesLayoutInShape(
1940             computation->root_instruction()->shape(),
1941             /*minor_to_major_only=*/true)) {
1942       if (conditional_mismatch_.count(computation) > 0) {
1943         *FindOrDie(computation_layouts_, computation).mutable_result_layout() =
1944             FindOrDie(conditional_mismatch_, computation).result_layout();
1945       }
1946       TF_ASSIGN_OR_RETURN(
1947           HloInstruction * new_root,
1948           CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
1949                                   computation->root_instruction()));
1950       computation->set_root_instruction(new_root);
1951     } else {
1952       // Copy the specified tiling info.
1953       auto assign_tiling = [&constraints](xla::Shape* subshape,
1954                                           const xla::ShapeIndex& index) {
1955         if (subshape->IsArray()) {
1956           const Shape& result_shape = ShapeUtil::GetSubshape(
1957               constraints.ResultLayout()->shape(), index);
1958           subshape->mutable_layout()->mutable_tiles()->assign(
1959               result_shape.layout().tiles().begin(),
1960               result_shape.layout().tiles().end());
1961         }
1962       };
1963       xla::ShapeUtil::ForEachMutableSubshape(
1964           computation->root_instruction()->mutable_shape(), assign_tiling);
1965     }
1966   }
1967   return Status::OK();
1968 }
1969 
ConstrainChannelLayouts(HloComputation * computation,ChannelLayoutConstraints * channel_constraints)1970 Status LayoutAssignment::ConstrainChannelLayouts(
1971     HloComputation* computation,
1972     ChannelLayoutConstraints* channel_constraints) {
1973   auto get_channel_constraints = [&](const HloInstruction* instruction) {
1974     return IsHostSendRecv(instruction) ? &host_channel_constraints_
1975                                        : channel_constraints;
1976   };
1977   // We go through the kRecvDone before. These must either impose their layout,
1978   // or find a matching one already existing (ConstrainChannel() returns
1979   // nullptr).
1980   for (HloInstruction* instruction : computation->instructions()) {
1981     if (instruction->opcode() == HloOpcode::kRecvDone) {
1982       const Layout* layout =
1983           get_channel_constraints(instruction)
1984               ->ConstrainChannel(
1985                   *instruction->channel_id(),
1986                   ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
1987       TF_RET_CHECK(layout == nullptr)
1988           << instruction->ToString()
1989           << " cannot constrain layout as it was set to "
1990           << LayoutUtil::HumanString(*layout);
1991     }
1992   }
1993   // After that we go through the kSend. These are likely going to have a kCopy
1994   // as operand (otherwise we add it), so in case the constrained layout does
1995   // not match, we can change the kCopy layout (and the kSend one as well).
1996   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
1997     if (instruction->opcode() == HloOpcode::kSend) {
1998       HloInstruction* operand = instruction->mutable_operand(0);
1999       get_channel_constraints(instruction)
2000           ->ConstrainChannel(*instruction->channel_id(),
2001                              operand->shape().layout());
2002     } else if (instruction->IsCrossModuleAllReduce()) {
2003       get_channel_constraints(instruction)
2004           ->ConstrainChannel(instruction->channel_id().value(),
2005                              instruction->shape().layout());
2006     }
2007   }
2008   return Status::OK();
2009 }
2010 
PropagateMemorySpace(HloModule * module)2011 Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
2012   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
2013   for (const auto& buffer : alias_analysis->buffers()) {
2014     // First go through values to collect the memory spaces.
2015     int64 buffer_memory_space = Layout::kDefaultMemorySpace;
2016     for (auto value : buffer.values()) {
2017       const Shape& defining_shape = value->defining_position().shape();
2018       int64 memory_space = defining_shape.layout().memory_space();
2019       if (memory_space != Layout::kDefaultMemorySpace) {
2020         if (buffer_memory_space != Layout::kDefaultMemorySpace &&
2021             memory_space != buffer_memory_space) {
2022           return InternalError(
2023               "Buffer %d (%s) has conflicting memory spaces: %d and %d.",
2024               buffer.id(), value->ToShortString(), buffer_memory_space,
2025               memory_space);
2026         }
2027         buffer_memory_space = memory_space;
2028       }
2029     }
2030 
2031     // If we encounter a memory space other than the default, then propagate all
2032     // the positions with the buffer's memory space.
2033     if (buffer_memory_space != Layout::kDefaultMemorySpace) {
2034       for (auto value : buffer.values()) {
2035         for (auto& position : value->positions()) {
2036           Shape* shape = ShapeUtil::GetMutableSubshape(
2037               position.instruction->mutable_shape(), position.index);
2038           shape->mutable_layout()->set_memory_space(buffer_memory_space);
2039         }
2040       }
2041     }
2042   }
2043   return Status::OK();
2044 }
2045 
// Reconciles `computation_layout` with the layouts actually materialized on
// the computation's instructions: unset parameter/result layouts are filled in
// from the computed program shape; layouts that are set must match what was
// assigned, or an error is returned.
Status LayoutAssignment::PropagateComputationLayouts(
    HloComputation* computation, ComputationLayout* computation_layout) {
  // Layout as actually present on the instructions (ignore_layouts=false keeps
  // the assigned layouts rather than resetting to defaults).
  ComputationLayout computed_computation_layout(
      computation->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
    ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
    bool needs_assign = false;
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        param_layout->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          // Only leaf (array) subshapes carry layouts worth checking.
          if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
            return Status::OK();
          }
          // A leaf with no layout means this parameter was left unconstrained;
          // mark it so the computed layout is adopted below. (Visiting
          // continues; any single unset leaf triggers whole-layout adoption.)
          if (!subshape.has_layout()) {
            needs_assign = true;
            return Status::OK();
          }
          // A set layout must agree with what was assigned on the parameter
          // instruction.
          const auto& computed_subshape = ShapeUtil::GetSubshape(
              computed_computation_layout.parameter_shape(i), shape_index);
          if (subshape.layout() != computed_subshape.layout()) {
            return InternalError(
                "Assigned parameter shape %s does not match layout of "
                "computation shape: %s",
                computed_computation_layout.ToString(),
                computation_layout->ToString());
          }
          return Status::OK();
        }));
    if (needs_assign) {
      VLOG(4) << "Assigning layout to parameter " << i << " of computation "
              << computation->name() << ": "
              << computed_computation_layout.parameter_layout(i).ToString();
      *param_layout = computed_computation_layout.parameter_layout(i);
    }
  }
  ShapeLayout* result_layout = computation_layout->mutable_result_layout();
  if (!result_layout->LayoutIsSet()) {
    VLOG(4) << "Assigning result layout of computation " << computation->name()
            << ": " << computed_computation_layout.result_layout().ToString();
    *result_layout = computed_computation_layout.result_layout();
  } else {
    // The result layout was pre-set: require shape equality modulo dynamic
    // dimensions and anything beyond minor-to-major ordering (e.g. tiling).
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            computed_computation_layout.result_layout().shape(),
            result_layout->shape()));
  }
  return Status::OK();
}
2095 
Run(HloModule * module)2096 StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
2097   VLOG(2) << "Running layout assignment on module " << module->name();
2098   TF_RETURN_IF_ERROR(Init());
2099   call_graph_ = CallGraph::Build(module);
2100   auto computations = module->computations();
2101 
2102   // Add copy to the operand of Send instructions, since we cannot call
2103   // SetOperandLayout on Send instructions as it aliases its input to the
2104   // output.
2105   //
2106   // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
2107   // operand buffers that aliases with the output.
2108   for (HloComputation* computation : module->computations()) {
2109     for (HloInstruction* instruction :
2110          computation->MakeInstructionPostOrder()) {
2111       if (instruction->opcode() == HloOpcode::kSend) {
2112         TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
2113       }
2114     }
2115   }
2116 
2117   // Clone Conditional computations with multiple callsites.
2118   for (HloComputation* computation : computations) {
2119     CallGraphNode& node = call_graph_->GetNode(computation);
2120     if (node.caller_callsites().size() == 1) {
2121       continue;
2122     }
2123     if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
2124           return caller.instruction()->opcode() == HloOpcode::kConditional;
2125         })) {
2126       continue;
2127     }
2128     for (int64 i = 0; i < node.caller_callsites().size() - 1; ++i) {
2129       HloInstruction* caller = node.caller_callsites()[i].instruction();
2130       if (caller->opcode() == HloOpcode::kConditional) {
2131         for (int64 k = 0; k < caller->branch_count(); ++k) {
2132           if (computation == caller->branch_computation(k)) {
2133             caller->set_branch_computation(
2134                 k, module->AddEmbeddedComputation(computation->Clone()));
2135             break;
2136           }
2137         }
2138       }
2139     }
2140   }
2141 
2142   // Verify computation layout is sane.
2143   const HloComputation* entry = module->entry_computation();
2144   TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
2145                entry->num_parameters());
2146   for (int64 i = 0; i < entry->num_parameters(); ++i) {
2147     TF_RET_CHECK(
2148         ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
2149                               entry->parameter_instruction(i)->shape()));
2150   }
2151   TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
2152                                      entry->root_instruction()->shape()));
2153 
2154   // We do two passes. The first one we pass a nullptr ComputationLayout to
2155   // the RunOnComputation() calls (for non entry computations), and we register
2156   // the ComputationLayout which are naturally flowing in DFS fashion to the
2157   // parameters and root instruction.
2158   // Walking in DFS mode though, means that we can end up with incorrect layouts
2159   // when seen from an outer instruction, which has across-computation
2160   // constraints to impose.
2161   // For example, the kWhile instruction needs to enforce the same layouts for
2162   // the parameters and root of the body, as well as the condition parameters.
2163   // Similarly, the kConditional instruction needs to enforce the same layouts
2164   // for the root of the true and false computations.
2165   // So in the first pass, while allowing the layouts to flow to parameters and
2166   // root, we also fix up the eventually inconsistent ComputationLayout, which
2167   // will be then made mandatory by the second pass.
2168   for (int64 i = 0; i < 2; ++i) {
2169     VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
2170     TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
2171     TF_ASSIGN_OR_RETURN(auto points_to_analysis,
2172                         TuplePointsToAnalysis::Run(module));
2173     points_to_analysis_ = std::move(points_to_analysis);
2174     for (auto* computation : module->MakeComputationPostOrder()) {
2175       if (computation->IsFusionComputation()) {
2176         continue;
2177       }
2178       if (computation == module->entry_computation()) {
2179         TF_RETURN_IF_ERROR(RunOnComputation(entry_computation_layout_,
2180                                             module->entry_computation(),
2181                                             channel_layout_constraints_));
2182       } else {
2183         ComputationLayout* computation_layout =
2184             (i == 0 || conditional_mismatch_.count(computation) > 0)
2185                 ? nullptr
2186                 : &FindOrDie(computation_layouts_, computation);
2187         TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation,
2188                                             channel_layout_constraints_));
2189       }
2190     }
2191   }
2192   TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
2193                                                  entry_computation_layout_));
2194 
2195   TF_RETURN_IF_ERROR(PropagateMemorySpace(module));
2196 
2197   TF_RETURN_IF_ERROR(CheckLayouts(module));
2198 
2199   // All layouts are reset then reassigned by this pass.
2200   return true;
2201 }
2202 
/* static */
// Returns true if the instruction's output layout may be chosen independently
// of its operands' layouts; returns false for instructions whose operand and
// result layouts must agree, so constraints propagate through them unchanged.
bool LayoutAssignment::InstructionCanChangeLayout(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    // Layout-coupled instructions: the result layout is tied to the operand
    // layout(s), so this function returns false for them.
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kWhile:
    case HloOpcode::kSetDimensionSize:
    // AllReduce is variadic so it needs to be careful to assign the same layout
    // to the corresponding input argument and Tuple index.
    case HloOpcode::kAllReduce:
      return false;
    // Layout-free instructions: the result layout may differ from the
    // operands', so this function returns true for them.
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kAfterAll:
    case HloOpcode::kTrace:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
      return true;
  }
}
2321 
2322 /* static */
IsAtMostRank1(const Shape & shape)2323 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2324   if (shape.IsArray()) {
2325     return shape.rank() <= 1;
2326   }
2327   return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2328     return IsAtMostRank1(subshape);
2329   });
2330 }
2331 
Init()2332 Status LayoutAssignment::Init() {
2333   computation_layouts_.clear();
2334   conditional_mismatch_.clear();
2335   *entry_computation_layout_ = saved_entry_computation_layout_;
2336   return Status::OK();
2337 }
2338 
ClearPreviousPassSideEffects(HloModule * module)2339 Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
2340   VLOG(5) << "Clearing previous side effects";
2341   // Clear all the copies which have been added, and all the related
2342   // instructions (like GTE and tuples).
2343   int64 removed_copies = 0;
2344   for (HloComputation* computation : module->computations()) {
2345     for (HloInstruction* instruction :
2346          computation->MakeInstructionPostOrder()) {
2347       if (instruction->opcode() == HloOpcode::kCopy &&
2348           added_copies_.contains(instruction)) {
2349         VLOG(5) << "Removing added copy: " << instruction->ToString();
2350         TF_RETURN_IF_ERROR(
2351             instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2352         TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2353         ++removed_copies;
2354       }
2355     }
2356   }
2357   added_copies_.clear();
2358   unconstrained_layout_instructions_.clear();
2359   if (removed_copies > 0) {
2360     TupleSimplifier tuple_simplifier;
2361     HloDCE dce;
2362     TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2363     TF_RETURN_IF_ERROR(dce.Run(module).status());
2364     call_graph_ = CallGraph::Build(module);
2365   }
2366   return Status::OK();
2367 }
2368 
AddCopyForOperand(HloInstruction * instruction,int64 operand_number)2369 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2370                                            int64 operand_number) {
2371   HloInstruction* operand = instruction->mutable_operand(operand_number);
2372   if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2373     HloInstruction* copy =
2374         instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2375             operand->shape(), HloOpcode::kCopy, operand));
2376     SetupCopiedInstruction(*operand, copy, {});
2377     LayoutUtil::ClearLayout(copy->mutable_shape());
2378     TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2379   }
2380   return Status::OK();
2381 }
2382 
2383 }  // namespace xla
2384