1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
17
18 #include <limits>
19 #include <memory>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/service/op_expander_pass.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30
31 namespace xla {
32
33 // Looks for a iota operand that can be used as tie breaker in the computation.
34 // If no matching iota operand is found, a iota operand is added to Sort. The
35 // comparison computation is adjusted to break ties using the values from the
36 // iota operand.
ExpandInstruction(HloInstruction * instruction)37 StatusOr<HloInstruction*> StableSortExpander::ExpandInstruction(
38 HloInstruction* instruction) {
39 auto* sort = Cast<HloSortInstruction>(instruction);
40 HloComputation* computation = sort->parent();
41
42 HloInstruction* expanded_sort = nullptr;
43 absl::flat_hash_set<int64> used_indices;
44 int64 iota_index = -1;
45 for (const HloInstruction* operand : sort->operands()) {
46 // We can only use the iota operand if it has an iota dimension which is the
47 // same as the dimension to sort. Also it should have an integral type that
48 // is large enough for the number of elements in the sort dimension. For
49 // now, we only allow S32, because we expect to find a S32 iota operand for
50 // all Sort ops which are created by TopK.
51 // TODO(b/122298745): Also support other types.
52 if (operand->opcode() == HloOpcode::kIota &&
53 Cast<HloIotaInstruction>(operand)->iota_dimension() ==
54 sort->sort_dimension() &&
55 operand->shape().element_type() == S32) {
56 iota_index = sort->operand_index(operand);
57 break;
58 }
59 }
60
61 // If there is currently no iota operand which we could use for making the
62 // sort stable, we will have to add a new such operand.
63 if (iota_index == -1) {
64 Shape iota_shape = sort->operand(0)->shape();
65 // We might need to use S64 if the number of elements in the sort dimension
66 // is bigger than 2^31 - 1.
67 // TODO(b/122298745): Handle Sort ops where S32 is too small for the number
68 // of elements in the sort dimension.
69 if (iota_shape.dimensions(sort->sort_dimension()) >
70 std::numeric_limits<int32>::max()) {
71 return Unimplemented(
72 "Stable sorting of more than 2^31-1 elements is not implemented");
73 }
74 iota_shape.set_element_type(S32);
75 auto iota = computation->AddInstruction(
76 HloInstruction::CreateIota(iota_shape, sort->sort_dimension()));
77
78 // Create a new comparator.
79 auto comparator = sort->to_apply();
80 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
81 replacements;
82 std::vector<std::unique_ptr<HloInstruction>> extra_parameters;
83 std::vector<HloInstruction*> extra_parameter_ptrs;
84 Shape scalar_shape = ShapeUtil::MakeShape(S32, {});
85 extra_parameters.push_back(HloInstruction::CreateParameter(
86 sort->operand_count() * 2, scalar_shape,
87 absl::StrCat("p.", sort->operand_count(), ".lhs")));
88 extra_parameter_ptrs.push_back(extra_parameters.back().get());
89 extra_parameters.push_back(HloInstruction::CreateParameter(
90 sort->operand_count() * 2 + 1, scalar_shape,
91 absl::StrCat("p.", sort->operand_count(), ".rhs")));
92 extra_parameter_ptrs.push_back(extra_parameters.back().get());
93 sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation(
94 comparator->CloneWithReplacements(std::move(replacements),
95 extra_parameter_ptrs)));
96
97 // Replace the original sort op.
98 std::vector<HloInstruction*> new_operands(sort->operands().begin(),
99 sort->operands().end());
100 new_operands.push_back(iota);
101 std::vector<Shape> new_shapes = sort->operand_count() == 1
102 ? std::vector<Shape>{sort->shape()}
103 : sort->shape().tuple_shapes();
104 new_shapes.push_back(iota_shape);
105 Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes);
106 HloInstruction* new_sort = computation->AddInstruction(
107 sort->CloneWithNewOperands(new_sort_shape, new_operands));
108
109 // Add a "wrapper" around the new sort op to make sure we have the same
110 // shape as before. For the rank 1 case, we only need a GetTupleElement,
111 // otherwise we create a Tuple consisting of GetTupleElements of the new
112 // sort.
113 std::vector<HloInstruction*> tuple_elements;
114 tuple_elements.reserve(sort->operand_count());
115 for (int64 i = 0; i < sort->operand_count(); ++i) {
116 tuple_elements.push_back(
117 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
118 sort->operand(i)->shape(), new_sort, i)));
119 }
120 expanded_sort = tuple_elements[0];
121 if (tuple_elements.size() > 1) {
122 expanded_sort = computation->AddInstruction(
123 HloInstruction::CreateTuple(tuple_elements));
124 }
125 sort = Cast<HloSortInstruction>(new_sort);
126 iota_index = sort->operand_count() - 1;
127 }
128
129 // Modify the computation to break ties using the iota operand.
130 auto comparator = sort->to_apply();
131 std::vector<HloInstruction*> instructions_postorder =
132 comparator->MakeInstructionPostOrder();
133 absl::flat_hash_map<HloInstruction*, HloInstruction*> replacements;
134 // Look up instr in the replacements map, and return either the replacement,
135 // or instr, if the replacement isn't present.
136 auto replace = [&](HloInstruction* instr) {
137 auto it = replacements.find(instr);
138 if (it == replacements.end()) {
139 return instr;
140 }
141 return it->second;
142 };
143 HloInstruction* old_root = comparator->root_instruction();
144 // The comparison computation gets 2 * n parameters (n being the number of
145 // operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two
146 // different scalars of operand i of Sort which are to be compared. The
147 // comparison computation should induce a strict weak order, so if
148 // to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to
149 // to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the
150 // values to be compared are equivalent, and perform a tie-breaker comparison.
151 //
152 // We clone each instruction with at least one operand, but use as new
153 // operands of the instruction the replacements of the original operands.
154 // Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This
155 // should make sure that the cloned root instruction gives the result of the
156 // comparison computation when being called with each scalar pair reversed.
157 // parameters corresponding to the iota operand.
158 for (int64 i = 0; i < comparator->num_parameters(); ++i) {
159 replacements[comparator->parameter_instruction(i)] =
160 comparator->parameter_instruction(i ^ 1);
161 }
162 HloInstruction* cloned_root = nullptr;
163 for (HloInstruction* inst : instructions_postorder) {
164 if (inst->operand_count() == 0) {
165 continue;
166 }
167 std::vector<HloInstruction*> new_operands;
168 new_operands.reserve(inst->operand_count());
169 for (HloInstruction* operand : inst->operands()) {
170 new_operands.push_back(replace(operand));
171 }
172 auto new_instruction =
173 inst->CloneWithNewOperands(inst->shape(), new_operands);
174 replacements[inst] = new_instruction.get();
175 if (inst == old_root) {
176 cloned_root = new_instruction.get();
177 }
178 comparator->AddInstruction(std::move(new_instruction));
179 }
180 CHECK_NE(cloned_root, nullptr);
181 Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
182 HloInstruction* same =
183 comparator->AddInstruction(HloInstruction::CreateCompare(
184 scalar_pred, old_root, cloned_root, ComparisonDirection::kEq));
185 HloInstruction* tie_breaker =
186 comparator->AddInstruction(HloInstruction::CreateCompare(
187 scalar_pred, comparator->parameter_instruction(2 * iota_index),
188 comparator->parameter_instruction(2 * iota_index + 1),
189 ComparisonDirection::kLt));
190 HloInstruction* new_root =
191 comparator->AddInstruction(HloInstruction::CreateTernary(
192 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker,
193 old_root));
194 comparator->set_root_instruction(new_root);
195
196 return expanded_sort;
197 }
198
InstructionMatchesPattern(HloInstruction * instruction)199 bool StableSortExpander::InstructionMatchesPattern(
200 HloInstruction* instruction) {
201 return instruction->opcode() == HloOpcode::kSort &&
202 Cast<HloSortInstruction>(instruction)->is_stable();
203 }
204 } // namespace xla
205