/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"

#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
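
// Matches array all-to-alls whose layout is not constrained. Depending on
// decompose_to_tuple_, the pass either rewrites every such instruction into
// a tuple all-to-all, or only reshapes arrays whose rank is below
// min_array_rank_.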
bool AllToAllDecomposer::InstructionMatchesPattern(
    HloInstruction* instruction) {
  auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
  if (all_to_all == nullptr) {
    return false;
  }
  // Do not attempt to change layout constrained collectives.
  if (all_to_all->constrain_layout()) {
    return false;
  }
  // Tuple-shaped all-to-alls are already in the decomposed form.
  if (all_to_all->shape().IsTuple()) {
    return false;
  }
  if (decompose_to_tuple_) {
    return true;
  }
  // Otherwise only rewrite arrays that are below the minimum supported rank.
  return all_to_all->shape().rank() < min_array_rank_;
}
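
// Expands a matched all-to-all in one of two ways. When decompose_to_tuple_
// is false, the operand is reshaped so the split dimension becomes a
// (group_size, split_size) pair padded up to min_array_rank_, the all-to-all
// is cloned on the higher-rank shape, and the result is reshaped back. When
// decompose_to_tuple_ is true, the array all-to-all is rewritten as slices
// feeding a tuple all-to-all followed by a concatenate.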
StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
    HloInstruction* instruction) {
  auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
  int64 split_dim = *all_to_all->split_dimension();
  // Number of participants in each all-to-all group; with no explicit
  // replica groups, every replica participates.
  int64 all_to_all_group_size =
      all_to_all->replica_groups().empty()
          ? instruction->parent()->parent()->config().replica_count()
          : all_to_all->replica_groups()[0].replica_ids_size();
  int64 split_size =
      all_to_all->shape().dimensions(split_dim) / all_to_all_group_size;
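  // Illustrative example (not from the original source): for an all-to-all
  // of shape f32[8,16] with split_dim = 0 over 4 participants, split_size is
  // 8 / 4 = 2, so each participant contributes and receives f32[2,16] pieces.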
  if (!decompose_to_tuple_) {
    // Build the higher-rank shape: keep every dimension except split_dim,
    // which is replaced by (group_size, split_size) and padded with size-1
    // dimensions up to min_array_rank_. E.g. f32[8,16] with split_dim = 1,
    // 4 participants, and min_array_rank_ = 5 becomes f32[8,4,4,1,1].
    Shape new_all_to_all_shape;
    new_all_to_all_shape.set_element_type(
        instruction->operand(0)->shape().element_type());
    for (int64 i = 0; i < instruction->shape().rank(); ++i) {
      if (i != split_dim) {
        new_all_to_all_shape.add_dimensions(all_to_all->shape().dimensions(i));
        continue;
      }
      new_all_to_all_shape.add_dimensions(all_to_all_group_size);
      new_all_to_all_shape.add_dimensions(split_size);
      for (int64 j = all_to_all->shape().rank() + 1; j < min_array_rank_; ++j) {
        new_all_to_all_shape.add_dimensions(1);
      }
    }
    *(new_all_to_all_shape.mutable_layout()) =
        LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
    // Reshape the operand, run the all-to-all on the higher-rank shape, and
    // reshape the result back to the original shape.
    HloInstruction* operand_reshape =
        instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
            new_all_to_all_shape, instruction->mutable_operand(0)));
    instruction->SetupDerivedInstruction(operand_reshape);
    HloInstruction* new_all_to_all =
        instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
            new_all_to_all_shape, {operand_reshape}));
    HloInstruction* output_reshape = instruction->parent()->AddInstruction(
        HloInstruction::CreateReshape(instruction->shape(), new_all_to_all));
    instruction->SetupDerivedInstruction(output_reshape);
    return output_reshape;
  }
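  // Tuple decomposition path: slice the operand along split_dim into one
  // piece per participant, exchange the pieces with a tuple-shaped
  // all-to-all, then concatenate the received pieces back along split_dim.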
  DimensionVector slice_starts(all_to_all->shape().rank(), 0);
  DimensionVector slice_strides(all_to_all->shape().rank(), 1);
  DimensionVector slice_limits(all_to_all->shape().dimensions().begin(),
                               all_to_all->shape().dimensions().end());
  slice_limits[split_dim] = split_size;
  Shape slice_shape = all_to_all->shape();
  slice_shape.set_dimensions(split_dim, split_size);
  std::vector<HloInstruction*> slices;
  slices.reserve(all_to_all_group_size);
  HloInstruction* operand = all_to_all->mutable_operand(0);
  // Slice the operand into all_to_all_group_size contiguous pieces along
  // split_dim, each of size split_size.
  for (int64 i = 0; i < all_to_all_group_size; ++i) {
    slices.push_back(
        all_to_all->parent()->AddInstruction(HloInstruction::CreateSlice(
            slice_shape, operand, slice_starts, slice_limits, slice_strides)));
    all_to_all->SetupDerivedInstruction(slices.back());
    slice_starts[split_dim] = slice_limits[split_dim];
    slice_limits[split_dim] += split_size;
  }
  // Exchange the slices with a tuple-shaped all-to-all that carries no split
  // dimension of its own.
  Shape all_to_all_shape = ShapeUtil::MakeTupleShape(
      std::vector<Shape>(all_to_all_group_size, slice_shape));
  HloInstruction* new_all_to_all =
      all_to_all->parent()->AddInstruction(HloInstruction::CreateAllToAll(
          all_to_all_shape, slices, all_to_all->replica_groups(), false,
          all_to_all->channel_id(), absl::nullopt));
  all_to_all->SetupDerivedInstruction(new_all_to_all);
  // Extract each tuple element and concatenate the pieces along split_dim to
  // recover the original array shape.
  std::vector<HloInstruction*> gtes;
  gtes.reserve(all_to_all_group_size);
  for (int64 i = 0; i < all_to_all_group_size; ++i) {
    gtes.push_back(all_to_all->parent()->AddInstruction(
        HloInstruction::CreateGetTupleElement(slice_shape, new_all_to_all, i)));
    all_to_all->SetupDerivedInstruction(gtes.back());
  }
  HloInstruction* concat = all_to_all->parent()->AddInstruction(
      HloInstruction::CreateConcatenate(all_to_all->shape(), gtes, split_dim));
  all_to_all->SetupDerivedInstruction(concat);
  return concat;
}

}  // namespace xla