1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
17 
18 #include <vector>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/types/optional.h"
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace xla {
InstructionMatchesPattern(HloInstruction * instruction)37 bool AllToAllDecomposer::InstructionMatchesPattern(
38     HloInstruction* instruction) {
39   auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
40   if (all_to_all == nullptr) {
41     return false;
42   }
43   // Do not attempt to change layout constrained collectives.
44   if (all_to_all->constrain_layout()) {
45     return false;
46   }
47   if (all_to_all->shape().IsTuple()) {
48     return false;
49   }
50   if (decompose_to_tuple_) {
51     return true;
52   }
53   return all_to_all->shape().rank() < min_array_rank_;
54 }
ExpandInstruction(HloInstruction * instruction)55 StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
56     HloInstruction* instruction) {
57   auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
58   int64 split_dim = *all_to_all->split_dimension();
59   int64 all_to_all_group_size =
60       all_to_all->replica_groups().empty()
61           ? instruction->parent()->parent()->config().replica_count()
62           : all_to_all->replica_groups()[0].replica_ids_size();
63   int64 split_size =
64       all_to_all->shape().dimensions(split_dim) / all_to_all_group_size;
65   if (!decompose_to_tuple_) {
66     Shape new_all_to_all_shape;
67     new_all_to_all_shape.set_element_type(
68         instruction->operand(0)->shape().element_type());
69     for (int64 i = 0; i < instruction->shape().rank(); ++i) {
70       if (i != split_dim) {
71         new_all_to_all_shape.add_dimensions(all_to_all->shape().dimensions(i));
72         continue;
73       }
74       new_all_to_all_shape.add_dimensions(all_to_all_group_size);
75       new_all_to_all_shape.add_dimensions(split_size);
76       for (int64 j = all_to_all->shape().rank() + 1; j < min_array_rank_; ++j) {
77         new_all_to_all_shape.add_dimensions(1);
78       }
79     }
80     *(new_all_to_all_shape.mutable_layout()) =
81         LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
82     HloInstruction* operand_reshape =
83         instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
84             new_all_to_all_shape, instruction->mutable_operand(0)));
85     instruction->SetupDerivedInstruction(operand_reshape);
86     HloInstruction* all_to_all =
87         instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
88             new_all_to_all_shape, {operand_reshape}));
89     HloInstruction* output_reshape = instruction->parent()->AddInstruction(
90         HloInstruction::CreateReshape(instruction->shape(), all_to_all));
91     instruction->SetupDerivedInstruction(output_reshape);
92     return output_reshape;
93   }
94   DimensionVector slice_starts(all_to_all->shape().rank(), 0);
95   DimensionVector slice_strides(all_to_all->shape().rank(), 1);
96   DimensionVector slice_limits(all_to_all->shape().dimensions().begin(),
97                                all_to_all->shape().dimensions().end());
98   slice_limits[split_dim] = split_size;
99   Shape slice_shape = all_to_all->shape();
100   slice_shape.set_dimensions(split_dim, split_size);
101   std::vector<HloInstruction*> slices;
102   slices.reserve(all_to_all_group_size);
103   HloInstruction* operand = all_to_all->mutable_operand(0);
104   for (int64 i = 0; i < all_to_all_group_size; ++i) {
105     slices.push_back(
106         all_to_all->parent()->AddInstruction(HloInstruction::CreateSlice(
107             slice_shape, operand, slice_starts, slice_limits, slice_strides)));
108     all_to_all->SetupDerivedInstruction(slices.back());
109     slice_starts[split_dim] = slice_limits[split_dim];
110     slice_limits[split_dim] += split_size;
111   }
112   Shape all_to_all_shape = ShapeUtil::MakeTupleShape(
113       std::vector<Shape>(all_to_all_group_size, slice_shape));
114   HloInstruction* new_all_to_all =
115       all_to_all->parent()->AddInstruction(HloInstruction::CreateAllToAll(
116           all_to_all_shape, slices, all_to_all->replica_groups(), false,
117           all_to_all->channel_id(), absl::nullopt));
118   std::vector<HloInstruction*> gtes;
119   gtes.reserve(all_to_all_group_size);
120   for (int64 i = 0; i < all_to_all_group_size; ++i) {
121     gtes.push_back(all_to_all->parent()->AddInstruction(
122         HloInstruction::CreateGetTupleElement(slice_shape, new_all_to_all, i)));
123     all_to_all->SetupDerivedInstruction(new_all_to_all);
124   }
125   HloInstruction* concat = all_to_all->parent()->AddInstruction(
126       HloInstruction::CreateConcatenate(all_to_all->shape(), gtes, split_dim));
127   all_to_all->SetupDerivedInstruction(concat);
128   return concat;
129 }
130 
131 }  // namespace xla
132