1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 
27 namespace xla {
28 namespace {
29 
30 // If the sort instruction has a tuple shape then looks for unused output
31 // values and removes them from the sort instruction. Returns true if the
32 // graph has been modified.
RemoveUnusedOperandFromSort(HloInstruction * sort)33 StatusOr<bool> RemoveUnusedOperandFromSort(HloInstruction* sort) {
34   if (!sort->shape().IsTuple()) {
35     return false;
36   }
37 
38   HloComputation* computation = sort->parent();
39 
40   if (computation->root_instruction() == sort) {
41     // Can't analyse users of the root instruction.
42     return false;
43   }
44 
45   absl::flat_hash_set<int64> used_indices;
46   for (const HloInstruction* user : sort->users()) {
47     if (user->opcode() != HloOpcode::kGetTupleElement) {
48       // Can't analyse users other then get-tuple-element.
49       return false;
50     }
51     used_indices.insert(user->tuple_index());
52   }
53 
54   // Also note which parameters are used by the comparator computation.
55   auto comparator = sort->to_apply();
56   for (int64 i = 0; i < sort->operand_count() * 2; ++i) {
57     if (comparator->parameter_instruction(i)->user_count() > 0) {
58       // operand i corresponds to parameters 2 * i and 2 * i + 1 of the
59       // computation.
60       used_indices.insert(i / 2);
61     }
62   }
63 
64   if (used_indices.size() == sort->operand_count()) {
65     // All operands are used.
66     return false;
67   }
68 
69   std::vector<HloInstruction*> operands;
70   std::vector<Shape> new_shapes;
71   for (int64 i = 0; i < sort->operand_count(); ++i) {
72     if (used_indices.contains(i)) {
73       operands.push_back(sort->mutable_operand(i));
74       new_shapes.push_back(sort->operand(i)->shape());
75     }
76   }
77 
78   Shape new_sort_shape = new_shapes.size() == 1
79                              ? new_shapes[0]
80                              : ShapeUtil::MakeTupleShape(new_shapes);
81   HloInstruction* new_sort = computation->AddInstruction(
82       sort->CloneWithNewOperands(new_sort_shape, operands));
83   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
84       replacements;
85   int64 parameter_number = 0;
86   for (int64 i = 0; i < sort->operand_count(); ++i) {
87     auto* old_lhs_parameter = comparator->parameter_instruction(i * 2);
88     auto* old_rhs_parameter = comparator->parameter_instruction(i * 2 + 1);
89     if (used_indices.contains(i)) {
90       Shape scalar_shape =
91           ShapeUtil::MakeShape(sort->operand(i)->shape().element_type(), {});
92       replacements[old_lhs_parameter] = HloInstruction::CreateParameter(
93           parameter_number, scalar_shape,
94           absl::StrCat("p.", parameter_number / 2, ".lhs"));
95       ++parameter_number;
96       replacements[old_rhs_parameter] = HloInstruction::CreateParameter(
97           parameter_number, scalar_shape,
98           absl::StrCat("p.", parameter_number / 2, ".rhs"));
99       ++parameter_number;
100     } else {
101       replacements[old_lhs_parameter] = nullptr;
102       replacements[old_rhs_parameter] = nullptr;
103     }
104   }
105   HloModule* module = sort->GetModule();
106   HloComputation* new_compare = module->AddEmbeddedComputation(
107       comparator->CloneWithReplacements(std::move(replacements)));
108   new_sort->set_to_apply(new_compare);
109 
110   // Map from original get-tuple-element tuple index to new HLO instruction
111   absl::flat_hash_map<int64, HloInstruction*> result_map;
112   if (new_sort->shape().IsTuple()) {
113     // Old sort key maps to new sort key.
114     int64 new_index = 0;
115     for (int64 i = 0; i < sort->operand_count(); ++i) {
116       if (used_indices.count(i)) {
117         result_map[i] =
118             computation->AddInstruction(HloInstruction::CreateGetTupleElement(
119                 new_shapes[new_index], new_sort, new_index));
120         ++new_index;
121       }
122     }
123   } else {
124     CHECK_EQ(used_indices.size(), 1);
125     result_map[*used_indices.begin()] = new_sort;
126   }
127   std::vector<HloInstruction*> users(sort->users().begin(),
128                                      sort->users().end());
129   for (HloInstruction* user : users) {
130     TF_RETURN_IF_ERROR(
131         user->ReplaceAllUsesWith(result_map.at(user->tuple_index())));
132     TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(user));
133   }
134   return true;
135 }
136 }  // namespace
137 
Run(HloModule * module)138 StatusOr<bool> SortSimplifier::Run(HloModule* module) {
139   VLOG(2) << "HLO module before SortSimplifier:";
140   XLA_VLOG_LINES(2, module->ToString());
141 
142   bool changed = false;
143   std::vector<HloInstruction*> sort_instrs;
144   for (auto* comp : module->MakeNonfusionComputations()) {
145     absl::c_copy_if(comp->instructions(), std::back_inserter(sort_instrs),
146                     [](const HloInstruction* instr) {
147                       return instr->opcode() == HloOpcode::kSort;
148                     });
149   }
150 
151   for (HloInstruction* sort_instr : sort_instrs) {
152     TF_ASSIGN_OR_RETURN(bool result, RemoveUnusedOperandFromSort(sort_instr));
153     changed |= result;
154   }
155 
156   if (changed) {
157     VLOG(2) << "HLO module after SortSimplifier:";
158     XLA_VLOG_LINES(2, module->ToString());
159   } else {
160     VLOG(2) << "HLO module unchanged after SortSimplifier";
161   }
162 
163   return changed;
164 }
165 }  // namespace xla
166