1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/stable_sort_expander.h"
17 
18 #include <limits>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/service/op_expander_pass.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 
31 namespace xla {
32 
33 // Looks for a iota operand that can be used as tie breaker in the computation.
34 // If no matching iota operand is found, a iota operand is added to Sort. The
35 // comparison computation is adjusted to break ties using the values from the
36 // iota operand.
ExpandInstruction(HloInstruction * instruction)37 StatusOr<HloInstruction*> StableSortExpander::ExpandInstruction(
38     HloInstruction* instruction) {
39   auto* sort = Cast<HloSortInstruction>(instruction);
40   HloComputation* computation = sort->parent();
41 
42   HloInstruction* expanded_sort = nullptr;
43   absl::flat_hash_set<int64> used_indices;
44   int64 iota_index = -1;
45   for (const HloInstruction* operand : sort->operands()) {
46     // We can only use the iota operand if it has an iota dimension which is the
47     // same as the dimension to sort. Also it should have an integral type that
48     // is large enough for the number of elements in the sort dimension. For
49     // now, we only allow S32, because we expect to find a S32 iota operand for
50     // all Sort ops which are created by TopK.
51     // TODO(b/122298745): Also support other types.
52     if (operand->opcode() == HloOpcode::kIota &&
53         Cast<HloIotaInstruction>(operand)->iota_dimension() ==
54             sort->sort_dimension() &&
55         operand->shape().element_type() == S32) {
56       iota_index = sort->operand_index(operand);
57       break;
58     }
59   }
60 
61   // If there is currently no iota operand which we could use for making the
62   // sort stable, we will have to add a new such operand.
63   if (iota_index == -1) {
64     Shape iota_shape = sort->operand(0)->shape();
65     // We might need to use S64 if the number of elements in the sort dimension
66     // is bigger than 2^31 - 1.
67     // TODO(b/122298745): Handle Sort ops where S32 is too small for the number
68     // of elements in the sort dimension.
69     if (iota_shape.dimensions(sort->sort_dimension()) >
70         std::numeric_limits<int32>::max()) {
71       return Unimplemented(
72           "Stable sorting of more than 2^31-1 elements is not implemented");
73     }
74     iota_shape.set_element_type(S32);
75     auto iota = computation->AddInstruction(
76         HloInstruction::CreateIota(iota_shape, sort->sort_dimension()));
77 
78     // Create a new comparator.
79     auto comparator = sort->to_apply();
80     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
81         replacements;
82     std::vector<std::unique_ptr<HloInstruction>> extra_parameters;
83     std::vector<HloInstruction*> extra_parameter_ptrs;
84     Shape scalar_shape = ShapeUtil::MakeShape(S32, {});
85     extra_parameters.push_back(HloInstruction::CreateParameter(
86         sort->operand_count() * 2, scalar_shape,
87         absl::StrCat("p.", sort->operand_count(), ".lhs")));
88     extra_parameter_ptrs.push_back(extra_parameters.back().get());
89     extra_parameters.push_back(HloInstruction::CreateParameter(
90         sort->operand_count() * 2 + 1, scalar_shape,
91         absl::StrCat("p.", sort->operand_count(), ".rhs")));
92     extra_parameter_ptrs.push_back(extra_parameters.back().get());
93     sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation(
94         comparator->CloneWithReplacements(std::move(replacements),
95                                           extra_parameter_ptrs)));
96 
97     // Replace the original sort op.
98     std::vector<HloInstruction*> new_operands(sort->operands().begin(),
99                                               sort->operands().end());
100     new_operands.push_back(iota);
101     std::vector<Shape> new_shapes = sort->operand_count() == 1
102                                         ? std::vector<Shape>{sort->shape()}
103                                         : sort->shape().tuple_shapes();
104     new_shapes.push_back(iota_shape);
105     Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes);
106     HloInstruction* new_sort = computation->AddInstruction(
107         sort->CloneWithNewOperands(new_sort_shape, new_operands));
108 
109     // Add a "wrapper" around the new sort op to make sure we have the same
110     // shape as before. For the rank 1 case, we only need a GetTupleElement,
111     // otherwise we create a Tuple consisting of GetTupleElements of the new
112     // sort.
113     std::vector<HloInstruction*> tuple_elements;
114     tuple_elements.reserve(sort->operand_count());
115     for (int64 i = 0; i < sort->operand_count(); ++i) {
116       tuple_elements.push_back(
117           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
118               sort->operand(i)->shape(), new_sort, i)));
119     }
120     expanded_sort = tuple_elements[0];
121     if (tuple_elements.size() > 1) {
122       expanded_sort = computation->AddInstruction(
123           HloInstruction::CreateTuple(tuple_elements));
124     }
125     sort = Cast<HloSortInstruction>(new_sort);
126     iota_index = sort->operand_count() - 1;
127   }
128 
129   // Modify the computation to break ties using the iota operand.
130   auto comparator = sort->to_apply();
131   std::vector<HloInstruction*> instructions_postorder =
132       comparator->MakeInstructionPostOrder();
133   absl::flat_hash_map<HloInstruction*, HloInstruction*> replacements;
134   // Look up instr in the replacements map, and return either the replacement,
135   // or instr, if the replacement isn't present.
136   auto replace = [&](HloInstruction* instr) {
137     auto it = replacements.find(instr);
138     if (it == replacements.end()) {
139       return instr;
140     }
141     return it->second;
142   };
143   HloInstruction* old_root = comparator->root_instruction();
144   // The comparison computation gets 2 * n parameters (n being the number of
145   // operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two
146   // different scalars of operand i of Sort which are to be compared. The
147   // comparison computation should induce a strict weak order, so if
148   // to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to
149   // to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the
150   // values to be compared are equivalent, and perform a tie-breaker comparison.
151   //
152   // We clone each instruction with at least one operand, but use as new
153   // operands of the instruction the replacements of the original operands.
154   // Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This
155   // should make sure that the cloned root instruction gives the result of the
156   // comparison computation when being called with each scalar pair reversed.
157   // parameters corresponding to the iota operand.
158   for (int64 i = 0; i < comparator->num_parameters(); ++i) {
159     replacements[comparator->parameter_instruction(i)] =
160         comparator->parameter_instruction(i ^ 1);
161   }
162   HloInstruction* cloned_root = nullptr;
163   for (HloInstruction* inst : instructions_postorder) {
164     if (inst->operand_count() == 0) {
165       continue;
166     }
167     std::vector<HloInstruction*> new_operands;
168     new_operands.reserve(inst->operand_count());
169     for (HloInstruction* operand : inst->operands()) {
170       new_operands.push_back(replace(operand));
171     }
172     auto new_instruction =
173         inst->CloneWithNewOperands(inst->shape(), new_operands);
174     replacements[inst] = new_instruction.get();
175     if (inst == old_root) {
176       cloned_root = new_instruction.get();
177     }
178     comparator->AddInstruction(std::move(new_instruction));
179   }
180   CHECK_NE(cloned_root, nullptr);
181   Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
182   HloInstruction* same =
183       comparator->AddInstruction(HloInstruction::CreateCompare(
184           scalar_pred, old_root, cloned_root, ComparisonDirection::kEq));
185   HloInstruction* tie_breaker =
186       comparator->AddInstruction(HloInstruction::CreateCompare(
187           scalar_pred, comparator->parameter_instruction(2 * iota_index),
188           comparator->parameter_instruction(2 * iota_index + 1),
189           ComparisonDirection::kLt));
190   HloInstruction* new_root =
191       comparator->AddInstruction(HloInstruction::CreateTernary(
192           ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker,
193           old_root));
194   comparator->set_root_instruction(new_root);
195 
196   return expanded_sort;
197 }
198 
InstructionMatchesPattern(HloInstruction * instruction)199 bool StableSortExpander::InstructionMatchesPattern(
200     HloInstruction* instruction) {
201   return instruction->opcode() == HloOpcode::kSort &&
202          Cast<HloSortInstruction>(instruction)->is_stable();
203 }
204 }  // namespace xla
205