1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26
27 namespace xla {
28 namespace {
29
30 // If the sort instruction has a tuple shape then looks for unused output
31 // values and removes them from the sort instruction. Returns true if the
32 // graph has been modified.
// If the sort instruction has a tuple shape then looks for unused output
// values and removes them from the sort instruction. Returns true if the
// graph has been modified.
//
// An operand of a multi-operand sort is "unused" when (a) no user of the
// sort reads the corresponding tuple element and (b) neither of its two
// comparator parameters is referenced. Such operands are dropped by
// rebuilding the sort and its comparator without them.
StatusOr<bool> RemoveUnusedOperandFromSort(HloInstruction* sort) {
  // Single-operand sorts have a non-tuple shape; nothing to prune.
  if (!sort->shape().IsTuple()) {
    return false;
  }

  HloComputation* computation = sort->parent();

  if (computation->root_instruction() == sort) {
    // Can't analyse users of the root instruction.
    return false;
  }

  // Collect the tuple indices actually read by the sort's users. If any user
  // is not a get-tuple-element we cannot tell which outputs it needs, so we
  // conservatively bail out.
  absl::flat_hash_set<int64> used_indices;
  for (const HloInstruction* user : sort->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      // Can't analyse users other than get-tuple-element.
      return false;
    }
    used_indices.insert(user->tuple_index());
  }

  // Also note which parameters are used by the comparator computation.
  auto comparator = sort->to_apply();
  for (int64 i = 0; i < sort->operand_count() * 2; ++i) {
    if (comparator->parameter_instruction(i)->user_count() > 0) {
      // operand i corresponds to parameters 2 * i and 2 * i + 1 of the
      // computation.
      used_indices.insert(i / 2);
    }
  }

  if (used_indices.size() == sort->operand_count()) {
    // All operands are used.
    return false;
  }

  // Build the reduced operand list and the matching per-operand shapes for
  // the new sort, preserving the original operand order.
  std::vector<HloInstruction*> operands;
  std::vector<Shape> new_shapes;
  for (int64 i = 0; i < sort->operand_count(); ++i) {
    if (used_indices.contains(i)) {
      operands.push_back(sort->mutable_operand(i));
      new_shapes.push_back(sort->operand(i)->shape());
    }
  }

  // A one-operand sort produces the operand's shape directly, not a
  // single-element tuple.
  Shape new_sort_shape = new_shapes.size() == 1
                             ? new_shapes[0]
                             : ShapeUtil::MakeTupleShape(new_shapes);
  HloInstruction* new_sort = computation->AddInstruction(
      sort->CloneWithNewOperands(new_sort_shape, operands));

  // Clone the comparator with fresh, densely renumbered parameters for the
  // kept operands; parameters of dropped operands map to nullptr, which
  // removes them from the clone.
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacements;
  int64 parameter_number = 0;
  for (int64 i = 0; i < sort->operand_count(); ++i) {
    auto* old_lhs_parameter = comparator->parameter_instruction(i * 2);
    auto* old_rhs_parameter = comparator->parameter_instruction(i * 2 + 1);
    if (used_indices.contains(i)) {
      // Comparator parameters are scalars of the operand's element type.
      Shape scalar_shape =
          ShapeUtil::MakeShape(sort->operand(i)->shape().element_type(), {});
      replacements[old_lhs_parameter] = HloInstruction::CreateParameter(
          parameter_number, scalar_shape,
          absl::StrCat("p.", parameter_number / 2, ".lhs"));
      ++parameter_number;
      replacements[old_rhs_parameter] = HloInstruction::CreateParameter(
          parameter_number, scalar_shape,
          absl::StrCat("p.", parameter_number / 2, ".rhs"));
      ++parameter_number;
    } else {
      replacements[old_lhs_parameter] = nullptr;
      replacements[old_rhs_parameter] = nullptr;
    }
  }
  HloModule* module = sort->GetModule();
  HloComputation* new_compare = module->AddEmbeddedComputation(
      comparator->CloneWithReplacements(std::move(replacements)));
  new_sort->set_to_apply(new_compare);

  // Map from original get-tuple-element tuple index to new HLO instruction
  absl::flat_hash_map<int64, HloInstruction*> result_map;
  if (new_sort->shape().IsTuple()) {
    // Old sort key maps to new sort key.
    int64 new_index = 0;
    for (int64 i = 0; i < sort->operand_count(); ++i) {
      if (used_indices.count(i)) {
        result_map[i] =
            computation->AddInstruction(HloInstruction::CreateGetTupleElement(
                new_shapes[new_index], new_sort, new_index));
        ++new_index;
      }
    }
  } else {
    // Non-tuple result: exactly one operand survived, so the sort itself is
    // the replacement for that single used index.
    CHECK_EQ(used_indices.size(), 1);
    result_map[*used_indices.begin()] = new_sort;
  }
  // Snapshot the user list first: ReplaceAllUsesWith/RemoveInstruction...
  // mutate sort->users() while we iterate.
  std::vector<HloInstruction*> users(sort->users().begin(),
                                     sort->users().end());
  for (HloInstruction* user : users) {
    TF_RETURN_IF_ERROR(
        user->ReplaceAllUsesWith(result_map.at(user->tuple_index())));
    TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(user));
  }
  return true;
}
136 } // namespace
137
Run(HloModule * module)138 StatusOr<bool> SortSimplifier::Run(HloModule* module) {
139 VLOG(2) << "HLO module before SortSimplifier:";
140 XLA_VLOG_LINES(2, module->ToString());
141
142 bool changed = false;
143 std::vector<HloInstruction*> sort_instrs;
144 for (auto* comp : module->MakeNonfusionComputations()) {
145 absl::c_copy_if(comp->instructions(), std::back_inserter(sort_instrs),
146 [](const HloInstruction* instr) {
147 return instr->opcode() == HloOpcode::kSort;
148 });
149 }
150
151 for (HloInstruction* sort_instr : sort_instrs) {
152 TF_ASSIGN_OR_RETURN(bool result, RemoveUnusedOperandFromSort(sort_instr));
153 changed |= result;
154 }
155
156 if (changed) {
157 VLOG(2) << "HLO module after SortSimplifier:";
158 XLA_VLOG_LINES(2, module->ToString());
159 } else {
160 VLOG(2) << "HLO module unchanged after SortSimplifier";
161 }
162
163 return changed;
164 }
165 } // namespace xla
166