1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implementation note:
17 //
18 // The general idea behind this pass is that we're converting from this:
19 // %param.A = OldShape
20 // %param.B = OldShape
21 // %reshape.A = NewShape reshape(%param.A)
22 // %reshape.B = NewShape reshape(%param.B)
23 // %instruction = NewShape instruction(%reshape.A, %reshape.B)
24 // To this:
25 // %param.A = OldShape
26 // %param.B = OldShape
27 // %instruction = OldShape instruction(%param.A, %param.B)
28 // %reshape = NewShape reshape(%instruction)
29 //
30 // Where the instruction must be elementwise, and both reshapes and transposes
31 // are moved.
32
33 #include "tensorflow/compiler/xla/service/reshape_mover.h"
34
35 #include <algorithm>
36
37 #include "absl/algorithm/container.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/lib/core/errors.h"
43
44 namespace xla {
45
46 namespace {
47
IsReshapeOrTranspose(const HloInstruction * instruction)48 bool IsReshapeOrTranspose(const HloInstruction* instruction) {
49 return instruction->opcode() == HloOpcode::kReshape ||
50 instruction->opcode() == HloOpcode::kTranspose;
51 }
52
// Returns true if `instruction` can change its shape simply by adjusting
// metadata (kConstant, single-user kRng) or if `instruction` is a broadcast
// of a scalar value.
bool CanTriviallyChangeShape(const HloInstruction* instruction) {
  // NOTE: Technically a sequence of reshape(reshape(constant)) is also
  // trivially reshapable, so we might be tempted to simply recurse if
  // IsReshapeOrTranspose(instruction)==true.
  //
  // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially
  // reshapable if *all* instructions in the chain have user_count == 1. And
  // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar.
  // In addition, these cases make it harder to maintain correctness of the
  // UpdateOperand logic below.
  //
  // So don't handle these chains, unless you update the tests and code to deal
  // with these properly. One idea is to add a pass immediately beforehand that
  // collapses trivial runs of reshapes / transposes.

  // A constant can trivially reshape the literal it holds.
  if (instruction->opcode() == HloOpcode::kConstant) {
    return true;
  }

  // An Rng instruction can be any shape as long as it has one user. Two copies
  // of the same Rng would be problematic if an Rng of a different shape would
  // produce random numbers in a different order.
  if (instruction->opcode() == HloOpcode::kRng &&
      instruction->user_count() == 1) {
    return true;
  }

  // A broadcast of a scalar can trivially change its shape: the new shape just
  // broadcasts the same scalar into different dimensions.
  if (instruction->opcode() == HloOpcode::kBroadcast &&
      ShapeUtil::IsScalar(instruction->operand(0)->shape())) {
    return true;
  }

  return false;
}
91
92 // Returns true iff `instruction` is a reshape/transpose instruction for which
93 // a shape change is nontrivial.
IsNontrivialReshape(const HloInstruction * instruction)94 bool IsNontrivialReshape(const HloInstruction* instruction) {
95 return !ShapeUtil::IsScalar(instruction->shape()) &&
96 IsReshapeOrTranspose(instruction) &&
97 !CanTriviallyChangeShape(instruction->operand(0));
98 }
99
100 // Finds the first operand of an instruction that is a non-trivial reshape or
101 // transpose. Returns such an operand or nullptr if not found.
FirstNonScalarAndNonTrivialReshapeOperand(const HloInstruction * hlo)102 HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand(
103 const HloInstruction* hlo) {
104 for (HloInstruction* operand : hlo->operands()) {
105 if (IsNontrivialReshape(operand)) {
106 VLOG(5) << "Found first non-trivial reshape operand of "
107 << hlo->ToString(HloPrintOptions().set_print_metadata(false))
108 << ":\n\t"
109 << operand->ToString(HloPrintOptions().set_print_metadata(false));
110 return operand;
111 }
112 }
113 return nullptr;
114 }
115
116 // Returns whether `a` and `b` are equivalent reshapes/transposes.
AreEquivalentReshapes(const HloInstruction * a,const HloInstruction * b)117 bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
118 if (a->opcode() != b->opcode() ||
119 !ShapeUtil::SameDimensions(a->shape(), b->shape())) {
120 return false;
121 }
122 switch (a->opcode()) {
123 case HloOpcode::kTranspose:
124 return a->dimensions() == b->dimensions();
125 case HloOpcode::kReshape:
126 return ShapeUtil::SameDimensions(a->operand(0)->shape(),
127 b->operand(0)->shape());
128 default:
129 return false;
130 }
131 }
132
133 // This function is called once we've decided to sink reshape/transpose operands
134 // across an instruction. It returns an updated `operand` with a shape that
135 // plays nicely with `new_operand_shape`; it has the same shape (of the
136 // correct type).
UpdateOperand(const HloInstruction * first_reshape_operand,const Shape & new_operand_shape,HloInstruction * operand)137 HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand,
138 const Shape& new_operand_shape,
139 HloInstruction* operand) {
140 HloComputation* computation = operand->parent();
141 const PrimitiveType element_type = operand->shape().element_type();
142 const Shape new_shape =
143 ShapeUtil::ChangeElementType(new_operand_shape, element_type);
144
145 switch (operand->opcode()) {
146 case HloOpcode::kConstant: {
147 if (first_reshape_operand->opcode() == HloOpcode::kReshape) {
148 VLOG(5) << "Adding reshape to kConstant operand";
149 return computation->AddInstruction(
150 HloInstruction::CreateReshape(new_shape, operand));
151 } else {
152 CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose);
153 VLOG(5) << "Adding transpose to kConstant operand";
154 std::vector<int64> inverse_permutation =
155 InversePermutation(first_reshape_operand->dimensions());
156 return computation->AddInstruction(HloInstruction::CreateTranspose(
157 new_shape, operand, inverse_permutation));
158 }
159 }
160 case HloOpcode::kRng: {
161 CHECK_EQ(operand->user_count(), 1);
162 VLOG(5) << "Cloning kRng operand with new shape";
163 return computation->AddInstruction(
164 operand->CloneWithNewOperands(new_shape, operand->operands()));
165 }
166 case HloOpcode::kReshape:
167 case HloOpcode::kTranspose: {
168 VLOG(5) << "Using existing operand of kReshape or kTranspose";
169 return operand->mutable_operand(0);
170 }
171 case HloOpcode::kBroadcast: {
172 CHECK(ShapeUtil::IsScalar(operand->operand(0)->shape()));
173 HloInstruction* inst = computation->AddInstruction(
174 operand->CloneWithNewOperands(new_shape, operand->operands()));
175 VLOG(5) << "Changing broadcast from " << operand->ToString() << " to "
176 << inst->ToString();
177 return inst;
178 }
179
180 default:
181 LOG(FATAL) << "Unexpected operand opcode during update: " << operand;
182 }
183 }
184
// Actually performs the reshape-move transformation -- that is, sinks the
// reshape or transpose operands of `instruction` across it: each operand is
// rewritten to the pre-reshape shape, `instruction` is cloned to operate on
// those shapes, and a single reshape/transpose is placed on the result.
// Always returns true (a change was made) unless an error status occurs.
StatusOr<bool> PerformSinkReshapeOrTranspose(
    HloInstruction* instruction, const HloInstruction* first_reshape_operand) {
  auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
  // At this point we've decided to sink reshape/transpose operands.
  // The shape every operand will be rewritten to: the input shape of the
  // representative reshape/transpose.
  const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape();
  VLOG(3) << "** Sinking reshape or transpose: "
          << instruction->ToString(print_no_metadata)
          << "\n\tfirst reshape operand: "
          << first_reshape_operand->ToString(print_no_metadata)
          << "\n\tnew operand shape: "
          << ShapeUtil::HumanString(new_operand_shape);

  auto operands = instruction->operands();
  for (size_t i = 0; i < operands.size(); ++i) {
    // All scalar operands remain as-is, even if they're reshape or transpose,
    // to simplify handling wrt special scalar broadcast rules for ops like
    // Select. Scalar reshapes should be cheap anyways.
    if (ShapeUtil::IsScalar(operands[i]->shape())) {
      continue;
    }
    VLOG(3) << "Updating operand #" << i << ": "
            << operands[i]->ToString(print_no_metadata);
    operands[i] =
        UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]);
  }
  if (HloOpcode::kFusion == instruction->opcode()) {
    // Here we already know `instruction` is elementwise, and all the fused
    // instructions have the same dimensions. Rewrite every fused instruction's
    // shape in place to the new (pre-reshape) dimensions and layout, keeping
    // each one's element type.
    for (const auto& fused_instruction : instruction->fused_instructions()) {
      Shape* shape = fused_instruction->mutable_shape();
      shape->clear_dimensions();
      for (int64 i : new_operand_shape.dimensions()) {
        shape->add_dimensions(i);
      }
      *shape->mutable_layout() = new_operand_shape.layout();
    }
  }
  HloComputation* computation = instruction->parent();
  HloInstruction* new_elementwise =
      computation->AddInstruction(instruction->CloneWithNewOperands(
          // `instruction` may change the element type, e.g., from
          //   operands[0] -> reshape -> convert (`instruction`)
          // to
          //   operands[0] -> convert' -> reshape'
          //
          // In this case, convert' should have the same element type as
          // `convert` and the same dimensions as operands[0].
          ShapeUtil::ChangeElementType(new_operand_shape,
                                       instruction->shape().element_type()),
          operands));

  // Re-apply the moved shape change on top of the new elementwise op, so the
  // overall result shape is unchanged for downstream users.
  std::unique_ptr<HloInstruction> new_reshape;
  switch (first_reshape_operand->opcode()) {
    case HloOpcode::kReshape:
      VLOG(3) << "Creating new reshape for new elementwise op: "
              << new_elementwise->ToString(print_no_metadata);
      new_reshape =
          HloInstruction::CreateReshape(instruction->shape(), new_elementwise);
      break;
    case HloOpcode::kTranspose:
      new_reshape =
          HloInstruction::CreateTranspose(instruction->shape(), new_elementwise,
                                          first_reshape_operand->dimensions());
      break;
    default:
      LOG(FATAL) << "Bad opcode";
  }
  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
      instruction, std::move(new_reshape)));
  return true;
}
258
259 // Returns true if the instruction is a reshape-move candidate.
260 //
261 // An instruction is a reshape-move candidate if the instruction is elementwise,
262 // has at least one nontrivial reshape/transpose operand, and its operands are
263 // either trivially reshapable or are equivalent nontrivial reshapes/transposes.
IsReshapeMoveCandidate(HloInstruction * instruction)264 bool IsReshapeMoveCandidate(HloInstruction* instruction) {
265 auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
266 VLOG(5) << "** Checking instruction: "
267 << instruction->ToString(print_no_metadata);
268
269 // Only perform reshape-move for live elementwise instructions with operands.
270 const bool is_dead = instruction->user_count() == 0 &&
271 instruction != instruction->parent()->root_instruction();
272 if (!instruction->IsElementwise() || instruction->operands().empty() ||
273 is_dead) {
274 return false;
275 }
276
277 // Check whether all operands:
278 // 0. Have the same dimensions as the output.
279 //
280 // And one of the following:
281 // 1. Are reshapes or transposes that have the same input and
282 // output shapes as all other reshaped or transposed operands.
283 // or
284 // 2. Are one of kConstant, kRng, broadcast of a scalar value.
285 const HloInstruction* first_reshape_operand = nullptr;
286 for (const HloInstruction* operand : instruction->operands()) {
287 if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
288 VLOG(5) << "Operand shape differs from output shape; so preventing "
289 "movement\n\toperand: "
290 << operand->ToString(print_no_metadata) << "\n\tinstruction: "
291 << instruction->ToString(print_no_metadata);
292 return false;
293 }
294
295 if (CanTriviallyChangeShape(operand)) {
296 VLOG(5) << "Operand can trivially change shape: "
297 << operand->ToString(print_no_metadata);
298 continue;
299 }
300
301 if (!IsNontrivialReshape(operand)) {
302 VLOG(5) << "Operand can't trivially change shape: "
303 << operand->ToString(print_no_metadata);
304 return false;
305 }
306
307 if (first_reshape_operand == nullptr) {
308 first_reshape_operand = operand;
309 VLOG(5) << "First reshape operand "
310 << operand->ToString(print_no_metadata);
311 } else if (AreEquivalentReshapes(first_reshape_operand, operand)) {
312 VLOG(5)
313 << "Operand is an equivalent reshape of the first reshape operand "
314 << operand->ToString(print_no_metadata);
315 } else {
316 // TODO(someone): Look into supporting general ops for the operands as
317 // well.
318 VLOG(5) << "Operand is a reshape but is not equivalent to the first "
319 "Reshape operand"
320 << operand->ToString(print_no_metadata);
321 return false;
322 }
323 }
324
325 if (first_reshape_operand) {
326 VLOG(5) << "All operands have easy shape changes: "
327 << instruction->ToString(print_no_metadata);
328 }
329
330 return first_reshape_operand != nullptr;
331 }
332
// Reshape-moves all qualifying instructions in reshape_candidates. Returns
// true if it makes changes.
//
// `reshape_candidates` is a set of HloInstructions with nontrivial reshape
// operands, and a instruction in the set can be reshape-moved iff all the users
// of its nontrivial reshape operands can also be reshaped-moved.
//
// The algorithm here iteratively finds the nontrivial operands with users that
// are outside the set of `reshape_candidates`, and removes their users from
// `reshape_candidates`, until either `reshape_candidates` becomes empty or none
// of the remaining nontrivial operands have users outside `reshape_candidates`.
// In the later case, all the remaining instructions in `reshape_candidates`
// are reshape-moved and the routine returns true.
StatusOr<bool> TryReshapeMoveOnCandidates(
    HloInstructionSet* reshape_candidates) {
  // Fixed-point loop: keep pruning until no candidate was removed in a pass.
  bool removed = true;
  while (!reshape_candidates->empty() && removed) {
    if (VLOG_IS_ON(5)) {
      for (const HloInstruction* instruction : *reshape_candidates) {
        VLOG(5) << "candidate " << instruction->ToString();
      }
    }
    // Collect the nontrivial reshape/transpose operands of all remaining
    // candidates; these are the reshapes that would be sunk.
    ConstHloInstructionSet nontrivial_operands;
    for (const HloInstruction* instruction : *reshape_candidates) {
      for (const auto* operand : instruction->operands()) {
        if (IsNontrivialReshape(operand)) {
          nontrivial_operands.insert(operand);
        }
      }
    }

    // If any nontrivial operand has a user outside the candidate set, sinking
    // it would not remove the reshape (that outside user still needs it), so
    // drop all of that operand's users from the candidate set.
    removed = false;
    for (auto operand : nontrivial_operands) {
      if (absl::c_any_of(operand->users(), [&](HloInstruction* user) {
            return !reshape_candidates->count(user);
          })) {
        for (auto* user : operand->users()) {
          removed |= reshape_candidates->erase(user) > 0;
        }
      }
    }
  }

  if (reshape_candidates->empty()) {
    return false;
  }
  // All survivors are safe to reshape-move; each must succeed.
  for (HloInstruction* instruction : *reshape_candidates) {
    const HloInstruction* first_reshape_operand =
        FirstNonScalarAndNonTrivialReshapeOperand(instruction);
    TF_ASSIGN_OR_RETURN(
        bool did_change,
        PerformSinkReshapeOrTranspose(instruction, first_reshape_operand));
    CHECK(did_change);
  }
  return true;
}
389
390 } // namespace
391
Run(HloModule * module)392 StatusOr<bool> ReshapeMover::Run(HloModule* module) {
393 bool changed = false;
394 VLOG(2) << "Pre ReshapeMover HLO:";
395 XLA_VLOG_LINES(2, module->ToString());
396 for (auto* comp : module->MakeNonfusionComputations()) {
397 HloInstructionSet reshape_candidates;
398 for (HloInstruction* instruction : comp->instructions()) {
399 if (IsReshapeMoveCandidate(instruction)) {
400 reshape_candidates.insert(instruction);
401 }
402 }
403 TF_ASSIGN_OR_RETURN(bool did_change,
404 TryReshapeMoveOnCandidates(&reshape_candidates));
405 changed |= did_change;
406 }
407 VLOG(2) << "Post ReshapeMover HLO:";
408 XLA_VLOG_LINES(2, module->ToString());
409 return changed;
410 }
411
412 } // namespace xla
413