1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Implementation note:
//
// The general idea behind this pass is that we're converting from this:
//   %param.A = OldShape
//   %param.B = OldShape
//   %reshape.A = NewShape reshape(%param.A)
//   %reshape.B = NewShape reshape(%param.B)
//   %instruction = NewShape instruction(%reshape.A, %reshape.B)
// To this:
//   %param.A = OldShape
//   %param.B = OldShape
//   %instruction = OldShape instruction(%param.A, %param.B)
//   %reshape = NewShape reshape(%instruction)
//
// The instruction must be elementwise. Transposes are moved in exactly the
// same way as reshapes. Operands that can trivially change shape (constants,
// single-use rngs and broadcasts of scalars) may appear alongside the
// reshape/transpose operands; they are rewritten to the old shape directly.
32 
33 #include "tensorflow/compiler/xla/service/reshape_mover.h"
34 
35 #include <algorithm>
36 
37 #include "absl/algorithm/container.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 
44 namespace xla {
45 
46 namespace {
47 
IsReshapeOrTranspose(const HloInstruction * instruction)48 bool IsReshapeOrTranspose(const HloInstruction* instruction) {
49   return instruction->opcode() == HloOpcode::kReshape ||
50          instruction->opcode() == HloOpcode::kTranspose;
51 }
52 
53 // Returns true if `instruction` can change its shape simply by adjusting
54 // metadata or if `instruction` is a broadcast of a scalar value.
CanTriviallyChangeShape(const HloInstruction * instruction)55 bool CanTriviallyChangeShape(const HloInstruction* instruction) {
56   // NOTE: Technically a sequence of reshape(reshape(constant)) is also
57   // trivially reshapable, so we might be tempted to simply recurse if
58   // IsReshapeOrTranspose(instruction)==true.
59   //
60   // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially
61   // reshapable if *all* instructions in the chain have user_count == 1. And
62   // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar.
63   // In addition, these cases make it harder to maintain correctness of the
64   // UpdateOperand logic below.
65   //
66   // So don't handle these chains, unless you update the tests and code to deal
67   // with these properly. One idea is to add a pass immediately beforehand that
68   // collapses trivial runs of reshapes / transposes.
69 
70   // A constant can trivially reshape the literal it holds.
71   if (instruction->opcode() == HloOpcode::kConstant) {
72     return true;
73   }
74 
75   // An Rng instruction can be any shape as long as it has one user. Two copies
76   // of the same Rng would be problematic if an Rng of a different shape would
77   // produce random numbers in a different order.
78   if (instruction->opcode() == HloOpcode::kRng &&
79       instruction->user_count() == 1) {
80     return true;
81   }
82 
83   // A broadcase of scalar can trivially change its shape.
84   if (instruction->opcode() == HloOpcode::kBroadcast &&
85       ShapeUtil::IsScalar(instruction->operand(0)->shape())) {
86     return true;
87   }
88 
89   return false;
90 }
91 
92 // Returns true iff `instruction` is a reshape/transpose instruction for which
93 // a shape change is nontrivial.
IsNontrivialReshape(const HloInstruction * instruction)94 bool IsNontrivialReshape(const HloInstruction* instruction) {
95   return !ShapeUtil::IsScalar(instruction->shape()) &&
96          IsReshapeOrTranspose(instruction) &&
97          !CanTriviallyChangeShape(instruction->operand(0));
98 }
99 
100 // Finds the first operand of an instruction that is a non-trivial reshape or
101 // transpose. Returns such an operand or nullptr if not found.
FirstNonScalarAndNonTrivialReshapeOperand(const HloInstruction * hlo)102 HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand(
103     const HloInstruction* hlo) {
104   for (HloInstruction* operand : hlo->operands()) {
105     if (IsNontrivialReshape(operand)) {
106       VLOG(5) << "Found first non-trivial reshape operand of "
107               << hlo->ToString(HloPrintOptions().set_print_metadata(false))
108               << ":\n\t"
109               << operand->ToString(HloPrintOptions().set_print_metadata(false));
110       return operand;
111     }
112   }
113   return nullptr;
114 }
115 
116 // Returns whether `a` and `b` are equivalent reshapes/transposes.
AreEquivalentReshapes(const HloInstruction * a,const HloInstruction * b)117 bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
118   if (a->opcode() != b->opcode() ||
119       !ShapeUtil::SameDimensions(a->shape(), b->shape())) {
120     return false;
121   }
122   switch (a->opcode()) {
123     case HloOpcode::kTranspose:
124       return a->dimensions() == b->dimensions();
125     case HloOpcode::kReshape:
126       return ShapeUtil::SameDimensions(a->operand(0)->shape(),
127                                        b->operand(0)->shape());
128     default:
129       return false;
130   }
131 }
132 
133 // This function is called once we've decided to sink reshape/transpose operands
134 // across an instruction. It returns an updated `operand` with a shape that
135 // plays nicely with `new_operand_shape`; it has the same shape (of the
136 // correct type).
UpdateOperand(const HloInstruction * first_reshape_operand,const Shape & new_operand_shape,HloInstruction * operand)137 HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand,
138                               const Shape& new_operand_shape,
139                               HloInstruction* operand) {
140   HloComputation* computation = operand->parent();
141   const PrimitiveType element_type = operand->shape().element_type();
142   const Shape new_shape =
143       ShapeUtil::ChangeElementType(new_operand_shape, element_type);
144 
145   switch (operand->opcode()) {
146     case HloOpcode::kConstant: {
147       if (first_reshape_operand->opcode() == HloOpcode::kReshape) {
148         VLOG(5) << "Adding reshape to kConstant operand";
149         return computation->AddInstruction(
150             HloInstruction::CreateReshape(new_shape, operand));
151       } else {
152         CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose);
153         VLOG(5) << "Adding transpose to kConstant operand";
154         std::vector<int64> inverse_permutation =
155             InversePermutation(first_reshape_operand->dimensions());
156         return computation->AddInstruction(HloInstruction::CreateTranspose(
157             new_shape, operand, inverse_permutation));
158       }
159     }
160     case HloOpcode::kRng: {
161       CHECK_EQ(operand->user_count(), 1);
162       VLOG(5) << "Cloning kRng operand with new shape";
163       return computation->AddInstruction(
164           operand->CloneWithNewOperands(new_shape, operand->operands()));
165     }
166     case HloOpcode::kReshape:
167     case HloOpcode::kTranspose: {
168       VLOG(5) << "Using existing operand of kReshape or kTranspose";
169       return operand->mutable_operand(0);
170     }
171     case HloOpcode::kBroadcast: {
172       CHECK(ShapeUtil::IsScalar(operand->operand(0)->shape()));
173       HloInstruction* inst = computation->AddInstruction(
174           operand->CloneWithNewOperands(new_shape, operand->operands()));
175       VLOG(5) << "Changing broadcast from " << operand->ToString() << " to "
176               << inst->ToString();
177       return inst;
178     }
179 
180     default:
181       LOG(FATAL) << "Unexpected operand opcode during update: " << operand;
182   }
183 }
184 
185 // Actually performs the reshape-move transformation -- that is, sinks the
186 // reshape or transpose operands of `instruction` across it.
PerformSinkReshapeOrTranspose(HloInstruction * instruction,const HloInstruction * first_reshape_operand)187 StatusOr<bool> PerformSinkReshapeOrTranspose(
188     HloInstruction* instruction, const HloInstruction* first_reshape_operand) {
189   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
190   // At this point we've decided to sink reshape/transpose operands.
191   const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape();
192   VLOG(3) << "** Sinking reshape or transpose: "
193           << instruction->ToString(print_no_metadata)
194           << "\n\tfirst reshape operand: "
195           << first_reshape_operand->ToString(print_no_metadata)
196           << "\n\tnew operand shape: "
197           << ShapeUtil::HumanString(new_operand_shape);
198 
199   auto operands = instruction->operands();
200   for (size_t i = 0; i < operands.size(); ++i) {
201     // All scalar operands remain as-is, even if they're reshape or transpose,
202     // to simplify handling wrt special scalar broadcast rules for ops like
203     // Select. Scalar reshapes should be cheap anyways.
204     if (ShapeUtil::IsScalar(operands[i]->shape())) {
205       continue;
206     }
207     VLOG(3) << "Updating operand #" << i << ": "
208             << operands[i]->ToString(print_no_metadata);
209     operands[i] =
210         UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]);
211   }
212   if (HloOpcode::kFusion == instruction->opcode()) {
213     // Here we already know `instruction` is elementwise, and all the fused
214     // instructions have the same dimensions.
215     for (const auto& fused_instruction : instruction->fused_instructions()) {
216       Shape* shape = fused_instruction->mutable_shape();
217       shape->clear_dimensions();
218       for (int64 i : new_operand_shape.dimensions()) {
219         shape->add_dimensions(i);
220       }
221       *shape->mutable_layout() = new_operand_shape.layout();
222     }
223   }
224   HloComputation* computation = instruction->parent();
225   HloInstruction* new_elementwise =
226       computation->AddInstruction(instruction->CloneWithNewOperands(
227           // `instruction` may change the element type, e.g., from
228           //   operands[0] -> reshape -> convert (`instruction`)
229           // to
230           //   operands[0] -> convert' -> reshape'
231           //
232           // In this case, convert' should have the same element type as
233           // `convert` and the same dimensions as operands[0].
234           ShapeUtil::ChangeElementType(new_operand_shape,
235                                        instruction->shape().element_type()),
236           operands));
237 
238   std::unique_ptr<HloInstruction> new_reshape;
239   switch (first_reshape_operand->opcode()) {
240     case HloOpcode::kReshape:
241       VLOG(3) << "Creating new reshape for new elementwise op: "
242               << new_elementwise->ToString(print_no_metadata);
243       new_reshape =
244           HloInstruction::CreateReshape(instruction->shape(), new_elementwise);
245       break;
246     case HloOpcode::kTranspose:
247       new_reshape =
248           HloInstruction::CreateTranspose(instruction->shape(), new_elementwise,
249                                           first_reshape_operand->dimensions());
250       break;
251     default:
252       LOG(FATAL) << "Bad opcode";
253   }
254   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
255       instruction, std::move(new_reshape)));
256   return true;
257 }
258 
259 // Returns true if the instruction is a reshape-move candidate.
260 //
261 // An instruction is a reshape-move candidate if the instruction is elementwise,
262 // has at least one nontrivial reshape/transpose operand, and its operands are
263 // either trivially reshapable or are equivalent nontrivial reshapes/transposes.
IsReshapeMoveCandidate(HloInstruction * instruction)264 bool IsReshapeMoveCandidate(HloInstruction* instruction) {
265   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
266   VLOG(5) << "** Checking instruction: "
267           << instruction->ToString(print_no_metadata);
268 
269   // Only perform reshape-move for live elementwise instructions with operands.
270   const bool is_dead = instruction->user_count() == 0 &&
271                        instruction != instruction->parent()->root_instruction();
272   if (!instruction->IsElementwise() || instruction->operands().empty() ||
273       is_dead) {
274     return false;
275   }
276 
277   // Check whether all operands:
278   //    0. Have the same dimensions as the output.
279   //
280   // And one of the following:
281   //    1. Are reshapes or transposes that have the same input and
282   //       output shapes as all other reshaped or transposed operands.
283   //     or
284   //    2. Are one of kConstant, kRng, broadcast of a scalar value.
285   const HloInstruction* first_reshape_operand = nullptr;
286   for (const HloInstruction* operand : instruction->operands()) {
287     if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
288       VLOG(5) << "Operand shape differs from output shape; so preventing "
289                  "movement\n\toperand: "
290               << operand->ToString(print_no_metadata) << "\n\tinstruction: "
291               << instruction->ToString(print_no_metadata);
292       return false;
293     }
294 
295     if (CanTriviallyChangeShape(operand)) {
296       VLOG(5) << "Operand can trivially change shape: "
297               << operand->ToString(print_no_metadata);
298       continue;
299     }
300 
301     if (!IsNontrivialReshape(operand)) {
302       VLOG(5) << "Operand can't trivially change shape: "
303               << operand->ToString(print_no_metadata);
304       return false;
305     }
306 
307     if (first_reshape_operand == nullptr) {
308       first_reshape_operand = operand;
309       VLOG(5) << "First reshape operand "
310               << operand->ToString(print_no_metadata);
311     } else if (AreEquivalentReshapes(first_reshape_operand, operand)) {
312       VLOG(5)
313           << "Operand is an equivalent reshape of the first reshape operand "
314           << operand->ToString(print_no_metadata);
315     } else {
316       // TODO(someone): Look into supporting general ops for the operands as
317       // well.
318       VLOG(5) << "Operand is a reshape but is not equivalent to the first "
319                  "Reshape operand"
320               << operand->ToString(print_no_metadata);
321       return false;
322     }
323   }
324 
325   if (first_reshape_operand) {
326     VLOG(5) << "All operands have easy shape changes: "
327             << instruction->ToString(print_no_metadata);
328   }
329 
330   return first_reshape_operand != nullptr;
331 }
332 
333 // Reshape-moves all qualifying instructions in reshape_candidates.  Returns
334 // true if it makes changes.
335 //
336 // `reshape_candidates` is a set of HloInstructions with nontrivial reshape
337 // operands, and a instruction in the set can be reshape-moved iff all the users
338 // of its nontrivial reshape operands can also be reshaped-moved.
339 //
340 // The algorithm here iteratively finds the nontrivial operands with users that
341 // are outside the set of `reshape_candidates`, and removes their users from
342 // `reshape_candidates`, until either `reshape_candidates` becomes empty or none
343 // of the remaining nontrivial operands have users outside `reshape_candidates`.
344 // In the later case, all the remaining instructions in `reshape_candidates`
345 // are reshape-moved and the routine returns true.
TryReshapeMoveOnCandidates(HloInstructionSet * reshape_candidates)346 StatusOr<bool> TryReshapeMoveOnCandidates(
347     HloInstructionSet* reshape_candidates) {
348   bool removed = true;
349   while (!reshape_candidates->empty() && removed) {
350     if (VLOG_IS_ON(5)) {
351       for (const HloInstruction* instruction : *reshape_candidates) {
352         VLOG(5) << "candidate " << instruction->ToString();
353       }
354     }
355     ConstHloInstructionSet nontrivial_operands;
356     for (const HloInstruction* instruction : *reshape_candidates) {
357       for (const auto* operand : instruction->operands()) {
358         if (IsNontrivialReshape(operand)) {
359           nontrivial_operands.insert(operand);
360         }
361       }
362     }
363 
364     removed = false;
365     for (auto operand : nontrivial_operands) {
366       if (absl::c_any_of(operand->users(), [&](HloInstruction* user) {
367             return !reshape_candidates->count(user);
368           })) {
369         for (auto* user : operand->users()) {
370           removed |= reshape_candidates->erase(user) > 0;
371         }
372       }
373     }
374   }
375 
376   if (reshape_candidates->empty()) {
377     return false;
378   }
379   for (HloInstruction* instruction : *reshape_candidates) {
380     const HloInstruction* first_reshape_operand =
381         FirstNonScalarAndNonTrivialReshapeOperand(instruction);
382     TF_ASSIGN_OR_RETURN(
383         bool did_change,
384         PerformSinkReshapeOrTranspose(instruction, first_reshape_operand));
385     CHECK(did_change);
386   }
387   return true;
388 }
389 
390 }  // namespace
391 
Run(HloModule * module)392 StatusOr<bool> ReshapeMover::Run(HloModule* module) {
393   bool changed = false;
394   VLOG(2) << "Pre ReshapeMover HLO:";
395   XLA_VLOG_LINES(2, module->ToString());
396   for (auto* comp : module->MakeNonfusionComputations()) {
397     HloInstructionSet reshape_candidates;
398     for (HloInstruction* instruction : comp->instructions()) {
399       if (IsReshapeMoveCandidate(instruction)) {
400         reshape_candidates.insert(instruction);
401       }
402     }
403     TF_ASSIGN_OR_RETURN(bool did_change,
404                         TryReshapeMoveOnCandidates(&reshape_candidates));
405     changed |= did_change;
406   }
407   VLOG(2) << "Post ReshapeMover HLO:";
408   XLA_VLOG_LINES(2, module->ToString());
409   return changed;
410 }
411 
412 }  // namespace xla
413