1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <functional>
21 #include <iterator>
22 #include <memory>
23 #include <numeric>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/container/inlined_vector.h"
32 #include "absl/memory/memory.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/types/optional.h"
35 #include "absl/types/span.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/literal.h"
38 #include "tensorflow/compiler/xla/literal_util.h"
39 #include "tensorflow/compiler/xla/primitive_util.h"
40 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
44 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/hlo_query.h"
48 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
49 #include "tensorflow/compiler/xla/shape.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/status_macros.h"
52 #include "tensorflow/compiler/xla/types.h"
53 #include "tensorflow/compiler/xla/util.h"
54 #include "tensorflow/compiler/xla/window_util.h"
55 #include "tensorflow/compiler/xla/xla_data.pb.h"
56 #include "tensorflow/core/lib/core/bits.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/logging.h"
60 #include "tensorflow/core/platform/types.h"
61 
62 namespace xla {
63 
64 namespace {
65 
66 namespace m = match;
67 
IsAll(const HloInstruction * op,int8 value)68 bool IsAll(const HloInstruction* op, int8 value) {
69   switch (op->opcode()) {
70     case HloOpcode::kBroadcast:
71       return IsAll(op->operand(0), value);
72     case HloOpcode::kConstant:
73       return op->literal().IsAll(value);
74     default:
75       return false;
76   }
77 }
78 
79 // Checks whether `op` is a floating-point constant or broadcast of a constant
80 // of the form +/- 2^k for some integer k positive, negative, or zero.  Such
81 // values are interesting because multiplying by a power of 2 just moves the
82 // exponent.
IsAllFpConstantPowerOf2(const HloInstruction * op)83 bool IsAllFpConstantPowerOf2(const HloInstruction* op) {
84   // Unwrap the broadcast if necessary.
85   const HloInstruction* c;
86   if (!Match(op, m::ConstantEffectiveScalar(&c)) &&
87       !Match(op, m::Broadcast(m::Constant(&c).WithShape(
88                      m::Shape().IsEffectiveScalar())))) {
89     return false;
90   }
91   auto val = [&]() -> absl::optional<double> {
92     switch (c->shape().element_type()) {
93       case BF16:
94         return static_cast<double>(c->literal().GetFirstElement<bfloat16>());
95       case F16:
96         return static_cast<double>(c->literal().GetFirstElement<Eigen::half>());
97       case F32:
98         return c->literal().GetFirstElement<float>();
99       case F64:
100         return c->literal().GetFirstElement<double>();
101       default:
102         // Cowardly refuse to consider complex types.
103         return absl::nullopt;
104     }
105   }();
106   if (!val) {
107     return false;
108   }
109 
110   int exp;
111   double mantissa = std::frexp(*val, &exp);
112   // frexp returns a value in the range (-1, -0.5] U [0.5, 1).  A return value
113   // of +/-0.5 therefore indicates that the floating point value is a power of
114   // 2.
115   return mantissa == 0.5 || mantissa == -0.5;
116 }
117 
118 // Returns whether the given transpose produces a result which is bit-wise
119 // identical to its operand and thus may be replaced with a bitcast.
TransposeIsBitcast(const HloInstruction * transpose)120 bool TransposeIsBitcast(const HloInstruction* transpose) {
121   CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
122   const HloInstruction* operand = transpose->operand(0);
123   return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
124                                        transpose->dimensions());
125 }
126 
127 // Recursive helper for method below.
BitcastingOperandOfReshapeOrCopyChainHelper(HloInstruction * instr,HloInstruction * operand,const AlgebraicSimplifierOptions & options)128 HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper(
129     HloInstruction* instr, HloInstruction* operand,
130     const AlgebraicSimplifierOptions& options) {
131   // Can't replace chain of copies and reshapes with bitcasts if the compiler
132   // used a memory layout which isn't compatible.
133   if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) {
134     return operand;
135   }
136 
137   // If the operand is a copy or reshape try to see if the operand's operand
138   // would produce a bitcast with initial instruction.
139   if (HloOpcode::kReshape == operand->opcode() ||
140       HloOpcode::kCopy == operand->opcode()) {
141     return BitcastingOperandOfReshapeOrCopyChainHelper(
142         instr, operand->mutable_operand(0), options);
143   }
144   return nullptr;
145 }
146 
147 // Returns an operand of a chain of reshapes and copies that is bit-wise
148 // identical to first reshape or copy in the chain.
BitcastingOperandOfReshapeOrCopyChain(HloInstruction * instr,const AlgebraicSimplifierOptions & options)149 HloInstruction* BitcastingOperandOfReshapeOrCopyChain(
150     HloInstruction* instr, const AlgebraicSimplifierOptions& options) {
151   if (!options.is_layout_sensitive()) {
152     return nullptr;
153   }
154   CHECK(HloOpcode::kReshape == instr->opcode() ||
155         HloOpcode::kCopy == instr->opcode());
156   return BitcastingOperandOfReshapeOrCopyChainHelper(
157       instr, instr->mutable_operand(0), options);
158 }
159 
IsUnstridedSlice(const HloInstruction * hlo)160 bool IsUnstridedSlice(const HloInstruction* hlo) {
161   return absl::c_all_of(hlo->slice_strides(),
162                         [](int64 stride) { return stride == 1; });
163 }
164 
165 // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
166 // algebraic expressions to simplified forms. Note: This only supports
167 // simplifications that simply look at the operands of an instruction. For the
168 // more general case a worklist based approach would be needed.
169 class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
170  public:
171   // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)172   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
173     return Status::OK();
174   }
175 
176   Status HandleAdd(HloInstruction* add) override;
177 
178   Status HandleAnd(HloInstruction* logical_and) override;
179 
180   Status HandleBitcast(HloInstruction* bitcast) override;
181 
182   Status HandleBitcastConvert(HloInstruction* bitcast) override;
183 
184   Status HandleBroadcast(HloInstruction* broadcast) override;
185 
186   Status HandleConcatenate(HloInstruction* concatenate) override;
187 
188   Status HandleConstant(HloInstruction* constant) override;
189 
190   Status HandleCopy(HloInstruction* copy) override;
191 
192   Status HandleConvert(HloInstruction* convert) override;
193 
194   Status HandleComplex(HloInstruction* complex) override;
195 
196   Status HandleReal(HloInstruction* real) override;
197 
198   Status HandleImag(HloInstruction* imag) override;
199 
200   Status HandleIota(HloInstruction* instruction) override;
201 
202   Status HandleConvolution(HloInstruction* convolution) override;
203 
204   Status HandleDivide(HloInstruction* divide) override;
205 
206   Status HandleDot(HloInstruction* dot) override;
207 
208   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
209 
210   Status HandleLog(HloInstruction* log) override;
211 
212   Status HandleMultiply(HloInstruction* multiply) override;
213 
214   Status HandleNegate(HloInstruction* negate) override;
215 
216   Status HandleNot(HloInstruction* logical_not) override;
217 
218   Status HandleOr(HloInstruction* logical_or) override;
219 
220   Status HandlePad(HloInstruction* pad) override;
221 
222   Status HandlePower(HloInstruction* power) override;
223 
224   Status HandleRemainder(HloInstruction* remainder) override;
225 
226   Status HandleReshape(HloInstruction* reshape) override;
227 
228   Status HandleReduce(HloInstruction* reduce) override;
229 
230   Status HandleReduceWindow(HloInstruction* reduce_window) override;
231 
232   Status HandleReverse(HloInstruction* reverse) override;
233   Status HandleSlice(HloInstruction* slice) override;
234   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
235   Status HandleDynamicUpdateSlice(
236       HloInstruction* dynamic_update_slice) override;
237 
238   Status HandleSelect(HloInstruction* select) override;
239 
240   Status HandleSort(HloInstruction* sort) override;
241 
242   Status HandleTranspose(HloInstruction* transpose) override;
243 
244   Status HandleSubtract(HloInstruction* sub) override;
245 
246   Status HandleMap(HloInstruction* map) override;
247 
248   // Returns whether algebraic simplification has occurred.
changed() const249   const bool changed() const { return changed_; }
250 
251   // Runs the visitor on a computation.
252   static bool Run(HloComputation* computation,
253                   const AlgebraicSimplifierOptions& options);
254 
255  private:
AlgebraicSimplifierVisitor(HloComputation * computation,const AlgebraicSimplifierOptions & options)256   explicit AlgebraicSimplifierVisitor(HloComputation* computation,
257                                       const AlgebraicSimplifierOptions& options)
258       : computation_(computation), options_(options) {}
259 
260   // Transforms Dots where at least one input is a vector or has a degenerate
261   // dimension and converts it into a multiply and reduce. This should enable
262   // more fusion than leaving the nodes as Dot operations.
263   StatusOr<bool> HandleDotStrengthReduction(HloInstruction* dot);
264 
265   // Removes dimension dim from hlo.
StripDim(HloInstruction * hlo,int64 dim)266   HloInstruction* StripDim(HloInstruction* hlo, int64 dim) {
267     CHECK_EQ(hlo->shape().dimensions(dim), 1);
268     return computation_->AddInstruction(HloInstruction::CreateReshape(
269         ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo));
270   }
271 
272   // Reshapes an instruction to rank 1 if it is not already rank 1.
Flatten(HloInstruction * hlo)273   HloInstruction* Flatten(HloInstruction* hlo) {
274     if (hlo->shape().rank() == 1) {
275       return hlo;
276     }
277     return computation_->AddInstruction(HloInstruction::CreateReshape(
278         ShapeUtil::MakeShape(hlo->shape().element_type(),
279                              {ShapeUtil::ElementsIn(hlo->shape())}),
280         hlo));
281   }
282 
283   // Converts to primitive type if the input hlo is not that type, otherwise
284   // returns the original hlo.
AsType(HloInstruction * hlo,const PrimitiveType element_type)285   HloInstruction* AsType(HloInstruction* hlo,
286                          const PrimitiveType element_type) {
287     if (hlo->shape().element_type() == element_type) {
288       return hlo;
289     }
290     return computation_->AddInstruction(HloInstruction::CreateConvert(
291         ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
292   }
293 
294   // Transposes a dot operand such that the batch dimensions are the msot major,
295   // and the contracting dimensions are most minor.
NormalizeDotOperandToBatchMajorAndContractingMinor(HloInstruction * dot_operand,absl::Span<const int64> batch_dimensions,absl::Span<const int64> contracting_dimensions)296   StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
297       HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions,
298       absl::Span<const int64> contracting_dimensions) {
299     std::vector<int64> transpose_dimensions(batch_dimensions.begin(),
300                                             batch_dimensions.end());
301     for (int64 i = 0; i < dot_operand->shape().rank(); ++i) {
302       if (!(absl::c_linear_search(batch_dimensions, i) ||
303             absl::c_linear_search(contracting_dimensions, i))) {
304         transpose_dimensions.push_back(i);
305       }
306     }
307     transpose_dimensions.insert(transpose_dimensions.end(),
308                                 contracting_dimensions.begin(),
309                                 contracting_dimensions.end());
310     return MakeTransposeHlo(dot_operand, transpose_dimensions);
311   }
312 
313   // Helper method to perform and add reduction on a list of dimensions.
AddReduce(HloInstruction * hlo,absl::Span<const int64> dims)314   HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) {
315     HloInstruction* zero =
316         computation_->AddInstruction(HloInstruction::CreateConstant(
317             LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
318     HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
319     Shape shape = ShapeUtil::FilterDimensions(
320         [&](int64 dim) { return !absl::c_linear_search(dims, dim); },
321         hlo->shape());
322     return computation_->AddInstruction(HloInstruction::CreateReduce(
323         shape, hlo, zero, dims, AddReduce_computation));
324   }
325 
AddReduce(HloInstruction * hlo,int64 dim)326   HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
327     return AddReduce(hlo, std::vector<int64>{dim});
328   }
329 
330   // Convenience method for replacing an instruction with a bitcast. If operand
331   // is not null, then the bitcast will use the specified operand instead of the
332   // operand of the instruction.
333   void ReplaceWithBitcast(HloInstruction* instruction,
334                           HloInstruction* operand = nullptr);
335 
336   // Replace old instruction with new instruction if old and new instructions
337   // have the same shape. Updates uses and root instruction. Returns whether a
338   // replacement was made.
339   bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
340                                      HloInstruction* new_instruction);
341 
342   // Returns whether the shape of the output of the given instructions are the
343   // same for the purposes of simplification. If options_.is_layout_sensitive()
344   // is true, then this tests shape equality including layout
345   // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the
346   // tests shape compatibility (ShapeUtil::Compatible).
347   bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;
348 
349   // Returns whether it was possible to transform `root` to a clamp instruction.
350   // With min a minimum instruction, max a maximum instruction, min_operand a
351   // operand of min and max_operand a operand of max.
352   // Precondition: root is either a minimum or a maximum.
353   bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
354                                    HloInstruction* min_operand,
355                                    HloInstruction* operand, HloInstruction* max,
356                                    HloInstruction* max_operand);
357 
358   // A Broadcast that feeds an element-wise operation with a unique non-scalar
359   // operand can sink to after the operation.
360   StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
361       HloInstruction* broadcast);
362 
363   // Replaces the existing HLO instruction old_instruction, with
364   // new_instruction, and marks the optimizer status as changed.
365   // Returns the Status representing the result of the replace operation.
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)366   Status ReplaceWithNewInstruction(
367       HloInstruction* old_instruction,
368       std::unique_ptr<HloInstruction> new_instruction) {
369     VLOG(3) << "Replacing instruction:";
370     VLOG(3) << "  old: " << old_instruction->ToString();
371     VLOG(3) << "  new: " << new_instruction->ToString();
372     TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
373         old_instruction, std::move(new_instruction)));
374     changed_ = true;
375     return Status::OK();
376   }
377 
378   // Replaces the existing HLO instruction old_instruction, with
379   // new_instruction, and marks the optimizer status as changed.
380   // Returns the Status representing the result of the replace operation.
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)381   Status ReplaceInstruction(HloInstruction* old_instruction,
382                             HloInstruction* new_instruction) {
383     VLOG(3) << "Replacing instruction:";
384     VLOG(3) << "  old: " << old_instruction->ToString();
385     VLOG(3) << "  new: " << new_instruction->ToString();
386     TF_RETURN_IF_ERROR(
387         computation_->ReplaceInstruction(old_instruction, new_instruction));
388     changed_ = true;
389     return Status::OK();
390   }
391 
392   StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
393   StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
394       const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
395       HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);
396 
397   StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
398 
GetOrCreateScalarAddComputation()399   HloComputation* GetOrCreateScalarAddComputation() {
400     if (scalar_add_computation_) {
401       return scalar_add_computation_;
402     }
403 
404     HloComputation::Builder b("scalar_add_computation");
405     Shape shape = ShapeUtil::MakeShape(F32, {});
406     auto scalar_lhs = b.AddInstruction(
407         HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
408     auto scalar_rhs = b.AddInstruction(
409         HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
410     auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
411         shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
412     scalar_add_computation_ =
413         computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
414     return scalar_add_computation_;
415   }
416 
417   // Tries to fold a kPad in the input or filter into the convolution
418   // instruction's window.
419   StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
420   StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);
421 
422   // Tries to use a kDot in place of the given convolution.
423   StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
424 
425   // Tries to simplify a slice where the result of the slice is a scalar.
426   StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice);
427 
428   // Tries to convert slice(reshape(X)) into reshape(slice(X))
429   StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);
430 
431   // Current HloComputation instance the AlgebraicSimplifierVisitor is
432   // traversing.
433   HloComputation* computation_;
434 
435   // The backend-specific options selected for the algebraic simplifier.
436   const AlgebraicSimplifierOptions& options_;
437 
438   // Whether algebraic simplification has occurred.
439   bool changed_ = false;
440 
441   // Cached computation for adding two scalar F32.
442   HloComputation* scalar_add_computation_ = nullptr;
443 };
444 
445 }  // namespace
446 
Run(HloComputation * computation,const AlgebraicSimplifierOptions & options)447 bool AlgebraicSimplifierVisitor::Run(
448     HloComputation* computation, const AlgebraicSimplifierOptions& options) {
449   AlgebraicSimplifierVisitor visitor(computation, options);
450   TF_CHECK_OK(computation->Accept(&visitor));
451   return visitor.changed_;
452 }
453 
SameShape(const HloInstruction * lhs,const HloInstruction * rhs) const454 bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
455                                            const HloInstruction* rhs) const {
456   if (options_.is_layout_sensitive()) {
457     return ShapeUtil::Equal(lhs->shape(), rhs->shape());
458   } else {
459     return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
460   }
461 }
462 
ReplaceWithBitcast(HloInstruction * instruction,HloInstruction * operand)463 void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
464                                                     HloInstruction* operand) {
465   CHECK_EQ(1, instruction->operand_count());
466   if (operand == nullptr) {
467     operand = instruction->mutable_operand(0);
468   }
469   CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
470            ShapeUtil::ElementsIn(operand->shape()));
471   CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
472            ShapeUtil::ByteSizeOf(operand->shape()));
473 
474   auto bitcast = computation_->AddInstruction(HloInstruction::CreateUnary(
475       instruction->shape(), HloOpcode::kBitcast, operand));
476   TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
477 }
478 
ReplaceInstructionIfSameShape(HloInstruction * old_instruction,HloInstruction * new_instruction)479 bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
480     HloInstruction* old_instruction, HloInstruction* new_instruction) {
481   if (!SameShape(old_instruction, new_instruction)) {
482     return false;
483   }
484   TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction));
485   return true;
486 }
487 
HandleAdd(HloInstruction * add)488 Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
489   HloInstruction *lhs, *rhs;
490   CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));
491 
492   // A + 0 => A
493   VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
494   if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
495     return Status::OK();
496   }
497   // 0 + A => A
498   VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
499   if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
500     return Status::OK();
501   }
502 
503   // Canonicalization: Put constants on the right.  This makes the reassociation
504   // rules below simpler.
505   VLOG(10) << "trying transform [Const + A => A + Const]";
506   if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
507     return ReplaceWithNewInstruction(
508         add,
509         HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
510   }
511 
512   // Reassociate to allow constant folding.
513   //
514   // Note: This is not general.  For example, we won't reassociate
515   //
516   //   (A + C1) + (B + C2) =>  A + B + (C1 + C2).
517   //
518   VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
519   HloInstruction *a, *c1, *c2;
520   if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
521                         m::Constant(&c2)))) {
522     TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
523                         MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
524     return ReplaceWithNewInstruction(
525         add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
526                                           sum_of_constants));
527   }
528 
529   // A*C + B*C => (A+B)*C
530   //
531   //  - If A, B, and C are integers, do this unconditionally. Proof of
532   //    correctness: https://rise4fun.com/Alive/u9X.
533   //
534   //  - If A, B, and C are floating point, do this if C is a scalar constant or
535   //    broadcast of scalar constant and is equal to +/- 2^k for some (possibly
536   //    negative) integer k.
537   //
538   //    Multiplying by a power of 2 just moves the exponent, so our answer is
539   //    exact modulo rounding of intermediate results so long as
540   //
541   //     - none of the three products has an exponent which underflows (so the
542   //       result is 0 or denormal), and
543   //     - none of the three products overflows to inf.
544   //
545   //    Proof: See algebraic_simplifier_proof_distributive_property.py.
546   //
547   //    We deem these differences in rounding, underflow, and overflow
548   //    acceptable in the ML context.
549   HloInstruction *b, *c;
550   if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) &&
551         Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) ||
552        (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
553         Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
554       (ShapeUtil::ElementIsIntegral(add->shape()) ||
555        IsAllFpConstantPowerOf2(c))) {
556     return ReplaceWithNewInstruction(
557         add, HloInstruction::CreateBinary(
558                  add->shape(), HloOpcode::kMultiply,
559                  computation_->AddInstruction(HloInstruction::CreateBinary(
560                      add->shape(), HloOpcode::kAdd, a, b)),
561                  c));
562   }
563   return Status::OK();
564 }
565 
HandleAnd(HloInstruction * logical_and)566 Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
567   HloInstruction *lhs, *rhs;
568   CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
569   // Simplify logical and
570   if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
571       ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
572     // A && True => A
573     VLOG(10) << "trying transform [A && True => A]: "
574              << logical_and->ToString();
575     if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
576       return Status::OK();
577     }
578     // True && A => A
579     VLOG(10) << "trying transform [True && A => A]: "
580              << logical_and->ToString();
581     if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
582       return Status::OK();
583     }
584 
585     // A && False => False
586     VLOG(10) << "trying transform [A && False => False]: "
587              << logical_and->ToString();
588     if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
589       return Status::OK();
590     }
591 
592     // False && A => False
593     VLOG(10) << "trying transform [False && A => False]: "
594              << logical_and->ToString();
595     if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
596       return Status::OK();
597     }
598   }
599 
600   return Status::OK();
601 }
602 
HandleBitcast(HloInstruction * bitcast)603 Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
604   // If a bitcast feeds a bitcast, make it a single bitcast.
605   HloInstruction* op;
606   if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
607     return ReplaceWithNewInstruction(
608         bitcast,
609         HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op));
610   }
611   // All bitcasts can be eliminated (assuming layout constraints are
612   // satisified).
613   ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
614   return Status::OK();
615 }
616 
HandleBitcastConvert(HloInstruction * bitcast)617 Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
618     HloInstruction* bitcast) {
619   // Eliminate bitcast converts between same shape.
620   ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
621   return Status::OK();
622 }
623 
HandleCopy(HloInstruction * copy)624 Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
625   // If a copy feeds a copy, make it a single copy.
626   HloInstruction* op;
627   if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
628     return ReplaceWithNewInstruction(
629         copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
630   }
631   // All copies can be eliminated (assuming layout constraints are satisified).
632   if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) {
633     return Status::OK();
634   }
635 
636   if (HloInstruction* bitcast_operand =
637           BitcastingOperandOfReshapeOrCopyChain(copy, options_)) {
638     ReplaceWithBitcast(copy, bitcast_operand);
639   }
640 
641   return Status::OK();
642 }
643 
HandleConcatenate(HloInstruction * concatenate)644 Status AlgebraicSimplifierVisitor::HandleConcatenate(
645     HloInstruction* concatenate) {
646   absl::Span<HloInstruction* const> operands(concatenate->operands());
647   if (operands.size() == 1) {
648     // Unary concatenates are useless.
649     ReplaceInstructionIfSameShape(concatenate, operands[0]);
650     return Status::OK();
651   }
652   // Filter out and remove empty operands.
653   std::vector<HloInstruction*> nonempty_operands;
654   for (HloInstruction* operand : operands) {
655     if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
656       nonempty_operands.push_back(operand);
657     }
658   }
659   if (nonempty_operands.size() < operands.size()) {
660     HloInstruction* replacement;
661     if (nonempty_operands.empty()) {
662       replacement = operands[0];
663     } else if (nonempty_operands.size() == 1) {
664       replacement = nonempty_operands[0];
665     } else {
666       replacement =
667           computation_->AddInstruction(concatenate->CloneWithNewOperands(
668               concatenate->shape(), nonempty_operands));
669     }
670     VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
671              << replacement->ToString();
672     ReplaceInstructionIfSameShape(concatenate, replacement);
673     return Status::OK();
674   }
675 
676   // Check if we can merge "adjacent" slice operands which take slices from the
677   // same other op. For simplicity we only merge unstrided slices.
678   int64 concatenate_dimension = concatenate->concatenate_dimension();
679   for (int64 i = 0; i < operands.size(); ++i) {
680     if (operands[i]->opcode() != HloOpcode::kSlice ||
681         !IsUnstridedSlice(operands[i])) {
682       continue;
683     }
684     int64 slice_end = operands[i]->slice_limits(concatenate_dimension);
685     HloInstruction* slice_operand = operands[i]->mutable_operand(0);
686     int64 j = i + 1;
687     while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice &&
688            IsUnstridedSlice(operands[j]) &&
689            operands[j]->operand(0) == slice_operand &&
690            operands[j]->slice_starts(concatenate_dimension) == slice_end) {
691       // Check that all the slice_start values are the same in all other
692       // dimensions. This implies that the slice_limit values are also the same,
693       // because operands of concatenate need to have the same shape, and we
694       // already checked that the slices are unstrided.
695       bool same_other_starts = true;
696       for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) {
697         if (k == concatenate_dimension) {
698           continue;
699         }
700         if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
701           same_other_starts = false;
702           break;
703         }
704       }
705       if (!same_other_starts) {
706         break;
707       }
708       slice_end = operands[j]->slice_limits(concatenate_dimension);
709       ++j;
710     }
711     if (j - i > 1) {
712       Shape new_slice_shape = operands[i]->shape();
713       new_slice_shape.set_dimensions(
714           concatenate_dimension,
715           slice_end - operands[i]->slice_starts(concatenate_dimension));
716       auto new_limit_indices = operands[i]->slice_limits();
717       new_limit_indices[concatenate_dimension] = slice_end;
718       auto new_slice_op =
719           computation_->AddInstruction(HloInstruction::CreateSlice(
720               new_slice_shape, slice_operand,
721               /*start_indices=*/operands[i]->slice_starts(),
722               /*limit_indices=*/new_limit_indices,
723               /*strides=*/operands[i]->slice_strides()));
724       std::vector<HloInstruction*> new_operands;
725       for (int64 k = 0; k < i; ++k) {
726         new_operands.push_back(operands[k]);
727       }
728       new_operands.push_back(new_slice_op);
729       for (int64 k = j; k < operands.size(); ++k) {
730         new_operands.push_back(operands[k]);
731       }
732       auto replacement =
733           computation_->AddInstruction(concatenate->CloneWithNewOperands(
734               concatenate->shape(), new_operands));
735       ReplaceInstructionIfSameShape(concatenate, replacement);
736       return Status::OK();
737     }
738   }
739 
740   if (operands.size() == 2) {
741     // A binary concat with a broadcasted scalar as an operand can be converted
742     // into a pad which is simpler to fold into other operations.
743     bool is_effective_low_pad = Match(
744         operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
745     bool is_effective_high_pad = Match(
746         operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
747     if (!is_effective_low_pad && !is_effective_high_pad) {
748       return Status::OK();
749     }
750     PaddingConfig padding_config;
751     for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) {
752       auto padding_config_dim = padding_config.add_dimensions();
753       padding_config_dim->set_edge_padding_high(0);
754       padding_config_dim->set_edge_padding_low(0);
755       padding_config_dim->set_interior_padding(0);
756       if (dim == concatenate_dimension) {
757         if (is_effective_low_pad) {
758           padding_config_dim->set_edge_padding_low(
759               operands[0]->shape().dimensions(dim));
760         } else {
761           padding_config_dim->set_edge_padding_high(
762               operands[1]->shape().dimensions(dim));
763         }
764       }
765     }
766     int64 operand_to_pad = is_effective_low_pad ? 1 : 0;
767     int64 pad_value_operand = is_effective_low_pad ? 0 : 1;
768     HloInstruction* pad =
769         computation_->AddInstruction(HloInstruction::CreatePad(
770             concatenate->shape(), operands[operand_to_pad],
771             operands[pad_value_operand]->mutable_operand(0), padding_config));
772     return ReplaceInstruction(concatenate, pad);
773   }
774   return Status::OK();
775 }
776 
BuildTupleConstant(HloComputation * computation,const LiteralSlice & literal)777 static HloInstruction* BuildTupleConstant(HloComputation* computation,
778                                           const LiteralSlice& literal) {
779   if (literal.shape().IsTuple()) {
780     std::vector<HloInstruction*> elems;
781     elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
782     for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
783       elems.push_back(
784           BuildTupleConstant(computation, LiteralSlice(literal, {i})));
785     }
786     return computation->AddInstruction(HloInstruction::CreateTuple(elems));
787   } else {
788     return computation->AddInstruction(
789         HloInstruction::CreateConstant(literal.Clone()));
790   }
791 }
792 
HandleConstant(HloInstruction * constant)793 Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
794   // Tuple constants aren't directly supported by any backend. Expand them into
795   // explicit Tuple instructions.
796   if (constant->shape().IsTuple()) {
797     return ReplaceInstruction(
798         constant, BuildTupleConstant(computation_, constant->literal()));
799   }
800 
801   if (constant->shape().element_type() == TOKEN) {
802     return Status::OK();
803   }
804 
805   // If a literal is all the same element replace it with a scalar broadcast.
806   if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
807       constant->literal().IsAllFirst()) {
808     Literal unique_scalar(
809         LiteralUtil::GetFirstScalarLiteral(constant->literal()));
810     HloInstruction* scalar = computation_->AddInstruction(
811         HloInstruction::CreateConstant(std::move(unique_scalar)));
812     return ReplaceWithNewInstruction(
813         constant,
814         HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
815   }
816 
817   // If a literal is an increasing sequence from zero, replace it with an iota.
818   if (constant->shape().rank() == 1 &&
819       ShapeUtil::ElementsIn(constant->shape()) > 1 &&
820       constant->literal().IsR1Iota()) {
821     return ReplaceWithNewInstruction(
822         constant, HloInstruction::CreateIota(constant->shape(), 0));
823   }
824   return Status::OK();
825 }
826 
HandleSubtract(HloInstruction * sub)827 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
828   HloInstruction *lhs, *rhs;
829   CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
830   // A - 0 => A
831   VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
832   if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
833     return Status::OK();
834   }
835 
836   // Canonicalize subtraction of a constant to addition.
837   VLOG(10) << "trying transform [A - Const => A + (-Const)]";
838   if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) {
839     HloInstruction* negative_const = computation_->AddInstruction(
840         HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
841     return ReplaceWithNewInstruction(
842         sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
843                                           negative_const));
844   }
845 
846   return Status::OK();
847 }
848 namespace {
849 template <typename T>
InvertConstant(const HloInstruction & constant,Literal * result)850 Status InvertConstant(const HloInstruction& constant, Literal* result) {
851   return result->Populate<T>([&](absl::Span<const int64> indices) {
852     return T{1.0} / constant.literal().Get<T>(indices);
853   });
854 }
855 
856 template <typename T>
TryDivideToShift(HloInstruction * divide,HloComputation * computation)857 std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
858                                                  HloComputation* computation) {
859   HloInstruction *a, *b, *c;
860   CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
861 
862   if (ShapeUtil::ElementIsIntegral(divide->shape()) &&
863       !Match(b, m::ConstantEffectiveScalar(&c)) &&
864       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
865     return nullptr;
866   }
867 
868   if (ShapeUtil::ElementIsSigned(divide->shape())) {
869     int64 b_value = c->literal().GetFirstElement<T>();
870     if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
871       // Handle negative dividends by negating the result of the division.
872       HloInstruction* zero_like_a = BroadcastZeros(
873           computation, a->shape().element_type(), a->shape().dimensions());
874 
875       auto* dividend_is_negative =
876           computation->AddInstruction(HloInstruction::CreateCompare(
877               ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
878               ComparisonDirection::kLt));
879 
880       auto* negated_dividend = computation->AddInstruction(
881           HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
882 
883       auto* abs_dividend =
884           computation->AddInstruction(HloInstruction::CreateTernary(
885               a->shape(), HloOpcode::kSelect, dividend_is_negative,
886               negated_dividend, a));
887 
888       int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
889 
890       auto* shift_amount =
891           computation->AddInstruction(HloInstruction::CreateConstant(
892               LiteralUtil::CreateR0<T>(log2_abs_b_value)));
893       if (!ShapeUtil::IsScalar(b->shape())) {
894         shift_amount = computation->AddInstruction(
895             HloInstruction::CreateBroadcast(b->shape(), shift_amount, {}));
896       }
897 
898       auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
899           divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
900           shift_amount));
901 
902       auto* neqated_quotient =
903           computation->AddInstruction(HloInstruction::CreateUnary(
904               quotient->shape(), HloOpcode::kNegate, quotient));
905 
906       return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect,
907                                            dividend_is_negative,
908                                            neqated_quotient, quotient);
909     }
910   } else {
911     uint64 b_value = c->literal().GetFirstElement<T>();
912     if (IsPowerOfTwo(b_value)) {
913       int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
914       HloInstruction* shift_amount =
915           computation->AddInstruction(HloInstruction::CreateConstant(
916               LiteralUtil::CreateR0<T>(log2_abs_b_value)));
917       if (!ShapeUtil::IsScalar(b->shape())) {
918         shift_amount = computation->AddInstruction(
919             HloInstruction::CreateBroadcast(b->shape(), shift_amount, {}));
920       }
921       return HloInstruction::CreateBinary(
922           divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount);
923     }
924   }
925 
926   return nullptr;
927 }
928 }  // namespace
929 
HandleDivide(HloInstruction * divide)930 Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
931   HloInstruction *a, *b, *c, *d;
932   CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
933   // A/1 => A
934   VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
935   if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) {
936     return Status::OK();
937   }
938 
939   // A / B => A >> log2(B) if B is a power of 2.
940   switch (divide->shape().element_type()) {
941     case S8:
942       if (std::unique_ptr<HloInstruction> shift =
943               TryDivideToShift<int8>(divide, computation_)) {
944         return ReplaceWithNewInstruction(divide, std::move(shift));
945       }
946       break;
947     case S16:
948       if (std::unique_ptr<HloInstruction> shift =
949               TryDivideToShift<int16>(divide, computation_)) {
950         return ReplaceWithNewInstruction(divide, std::move(shift));
951       }
952       break;
953     case S32:
954       if (std::unique_ptr<HloInstruction> shift =
955               TryDivideToShift<int32>(divide, computation_)) {
956         return ReplaceWithNewInstruction(divide, std::move(shift));
957       }
958       break;
959     case S64:
960       if (std::unique_ptr<HloInstruction> shift =
961               TryDivideToShift<int64>(divide, computation_)) {
962         return ReplaceWithNewInstruction(divide, std::move(shift));
963       }
964       break;
965     case U8:
966       if (std::unique_ptr<HloInstruction> shift =
967               TryDivideToShift<uint8>(divide, computation_)) {
968         return ReplaceWithNewInstruction(divide, std::move(shift));
969       }
970       break;
971     case U16:
972       if (std::unique_ptr<HloInstruction> shift =
973               TryDivideToShift<uint16>(divide, computation_)) {
974         return ReplaceWithNewInstruction(divide, std::move(shift));
975       }
976       break;
977     case U32:
978       if (std::unique_ptr<HloInstruction> shift =
979               TryDivideToShift<uint32>(divide, computation_)) {
980         return ReplaceWithNewInstruction(divide, std::move(shift));
981       }
982       break;
983     case U64:
984       if (std::unique_ptr<HloInstruction> shift =
985               TryDivideToShift<uint64>(divide, computation_)) {
986         return ReplaceWithNewInstruction(divide, std::move(shift));
987       }
988       break;
989     default:
990       break;
991   }
992 
993   Shape* shape;
994   // exp(A)/exp(B) => exp(A-B)
995   if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
996                         .WithShape(m::Shape(&shape)))) {
997     VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
998     HloInstruction* subtract = computation_->AddInstruction(
999         HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
1000     return ReplaceWithNewInstruction(
1001         divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
1002   }
1003 
1004   // A/exp(B) => A*exp(-B)
1005   if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
1006     VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
1007     HloInstruction* negate = computation_->AddInstruction(
1008         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
1009     HloInstruction* new_exp = computation_->AddInstruction(
1010         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
1011     return ReplaceWithNewInstruction(
1012         divide, HloInstruction::CreateBinary(divide->shape(),
1013                                              HloOpcode::kMultiply, a, new_exp));
1014   }
1015 
1016   // A/pow(B,C) => A*pow(B,-C)
1017   if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
1018     VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
1019     // The output shape of the created negate operator should be the same as the
1020     // input.
1021     const Shape& negate_shape = c->shape();
1022     HloInstruction* negate = computation_->AddInstruction(
1023         HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
1024     // And the power operator should retain the output shape of the old one.
1025     const Shape& new_power_shape = b->shape();
1026     HloInstruction* new_power =
1027         computation_->AddInstruction(HloInstruction::CreateBinary(
1028             new_power_shape, HloOpcode::kPower, b, negate));
1029     return ReplaceWithNewInstruction(
1030         divide, HloInstruction::CreateBinary(
1031                     divide->shape(), HloOpcode::kMultiply, a, new_power));
1032   }
1033 
1034   // A/sqrt(B) => A*rsqrt(X).
1035   if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) {
1036     auto* rsqrt = computation_->AddInstruction(
1037         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b));
1038     return ReplaceWithNewInstruction(
1039         divide, HloInstruction::CreateBinary(rsqrt->shape(),
1040                                              HloOpcode::kMultiply, a, rsqrt));
1041   }
1042 
1043   // A/rsqrt(B) => A*sqrt(B).
1044   if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) {
1045     auto* sqrt = computation_->AddInstruction(
1046         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b));
1047     return ReplaceWithNewInstruction(
1048         divide, HloInstruction::CreateBinary(sqrt->shape(),
1049                                              HloOpcode::kMultiply, a, sqrt));
1050   }
1051 
1052   // Simplifying integral division would produce unexpected results.
1053   if (ShapeUtil::ElementIsIntegral(divide->shape())) {
1054     return Status::OK();
1055   }
1056 
1057   // A / Const => A * (1 / Const)
1058   //
1059   // (Backends can do this transformation, but generally only if the constant is
1060   // a scalar.)
1061   if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
1062     Shape result_shape = b->literal().shape();
1063     Literal new_literal(result_shape);
1064     switch (result_shape.element_type()) {
1065       case F16:
1066         TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
1067         break;
1068       case F32:
1069         TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal));
1070         break;
1071       case BF16:
1072         TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal));
1073         break;
1074       case F64:
1075         TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal));
1076         break;
1077       case C64:
1078         TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
1079         break;
1080       case C128:
1081         TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal));
1082         break;
1083       default:
1084         return Status::OK();
1085     }
1086     auto inverse = computation_->AddInstruction(
1087         HloInstruction::CreateConstant((new_literal.Clone())));
1088     TF_ASSIGN_OR_RETURN(auto new_divide,
1089                         MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
1090     return ReplaceInstruction(divide, new_divide);
1091   }
1092 
1093   // (A / B) / (C / D)  =>  (A / B)*(D / C) => (A * D) / (B * C)
1094   if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
1095                               m::Divide(m::Op(&c), m::Op(&d))))) {
1096     TF_ASSIGN_OR_RETURN(auto a_times_d,
1097                         MakeBinaryHlo(HloOpcode::kMultiply, a, d));
1098     TF_ASSIGN_OR_RETURN(auto b_times_c,
1099                         MakeBinaryHlo(HloOpcode::kMultiply, b, c));
1100     TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
1101                                                        a_times_d, b_times_c));
1102 
1103     return ReplaceInstruction(divide, new_divide);
1104   }
1105 
1106   // (A / B) / C => A / (B * C)
1107   if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
1108     TF_ASSIGN_OR_RETURN(auto b_times_c,
1109                         MakeBinaryHlo(HloOpcode::kMultiply, b, c));
1110     TF_ASSIGN_OR_RETURN(auto new_divide,
1111                         MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
1112     return ReplaceInstruction(divide, new_divide);
1113   }
1114 
1115   // A / (B / C) => (A*C) / B
1116   if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
1117     TF_ASSIGN_OR_RETURN(auto a_times_c,
1118                         MakeBinaryHlo(HloOpcode::kMultiply, a, c));
1119     TF_ASSIGN_OR_RETURN(auto new_divide,
1120                         MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
1121     return ReplaceInstruction(divide, new_divide);
1122   }
1123 
1124   return Status::OK();
1125 }
1126 
HandleDotStrengthReduction(HloInstruction * dot)1127 StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
1128     HloInstruction* dot) {
1129   HloInstruction *lhs, *rhs;
1130   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
1131 
1132   const auto kept_dim = [](int64 rank, int64 contracting_dimension,
1133                            absl::Span<const int64> batch_dimensions) -> int64 {
1134     for (int64 i = 0; i < rank; ++i) {
1135       if (i != contracting_dimension &&
1136           !absl::c_linear_search(batch_dimensions, i)) {
1137         return i;
1138       }
1139     }
1140     return -1;
1141   };
1142 
1143   const int64 dot_rank = dot->shape().rank();
1144   const int64 rhs_rank = rhs->shape().rank();
1145   const int64 lhs_rank = lhs->shape().rank();
1146   const auto& dnums = dot->dot_dimension_numbers();
1147   if (dnums.rhs_contracting_dimensions_size() != 1) {
1148     return false;
1149   }
1150   if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) {
1151     return false;
1152   }
1153   int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0);
1154   int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim,
1155                                 AsInt64Slice(dnums.lhs_batch_dimensions()));
1156   // If there is no non-contracting dimension in rank 2, do not strength reduce.
1157   if (lhs_kept_dim == -1 && lhs_rank > 1) {
1158     return false;
1159   }
1160   if (lhs->IsRank2Transpose()) {
1161     lhs = lhs->mutable_operand(0);
1162     std::swap(lhs_collapsing_dim, lhs_kept_dim);
1163   }
1164 
1165   int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0);
1166   int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim,
1167                                 AsInt64Slice(dnums.rhs_batch_dimensions()));
1168   // If there is no non-contracting dimension in rank 2, do not strength reduce.
1169   if (rhs_kept_dim == -1 && rhs_rank > 1) {
1170     return false;
1171   }
1172   if (rhs->IsRank2Transpose()) {
1173     rhs = rhs->mutable_operand(0);
1174     std::swap(rhs_collapsing_dim, rhs_kept_dim);
1175   }
1176 
1177   auto reshape_if_necessary = [&](HloInstruction* hlo) {
1178     hlo = AsType(hlo, dot->shape().element_type());
1179     if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
1180       hlo = computation_->AddInstruction(
1181           HloInstruction::CreateReshape(dot->shape(), hlo));
1182     }
1183     return hlo;
1184   };
1185 
1186   auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
1187     return AddReduce(AsType(hlo, F32), dim);
1188   };
1189 
1190   auto broadcast = [&](HloInstruction* hlo, const Shape& shape,
1191                        absl::Span<const int64> dims) {
1192     return computation_->AddInstruction(
1193         HloInstruction::CreateBroadcast(shape, hlo, dims));
1194   };
1195 
1196   auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
1197                               int64 dim) {
1198     return broadcast(hlo, shape, {dim});
1199   };
1200 
1201   auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) {
1202     return computation_->AddInstruction(HloInstruction::CreateBinary(
1203         local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs));
1204   };
1205 
1206   // Strength reduce dot(a[K] , b[K]) =
1207   //  reshape(result.shape,
1208   //          reduce_sum(multiply(a, b), {0}))
1209   if (rhs_rank == 1 && lhs_rank == 1) {
1210     TF_RETURN_IF_ERROR(ReplaceInstruction(
1211         dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0))));
1212     return true;
1213   }
1214 
1215   if (ShapeUtil::IsEffectiveScalar(rhs->shape()) &&
1216       ShapeUtil::IsEffectiveScalar(lhs->shape())) {
1217     TF_RETURN_IF_ERROR(ReplaceInstruction(
1218         dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs)))));
1219     return true;
1220   }
1221 
1222   // Simplify outer product into multiply with broadcasting.
1223   //
1224   // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
1225   if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) {
1226     TF_RETURN_IF_ERROR(ReplaceInstruction(
1227         dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0),
1228                       broadcast_to_dim(Flatten(rhs), dot->shape(), 1))));
1229     return true;
1230   }
1231 
1232   // Strength reduce dot(a[1, K], b) =
1233   //    reshape(result.shape,
1234   //      reduce_sum(
1235   //        multiply(broadcast(reshape(a, [K]), {0}), b),
1236   //        {0})
1237   //      )
1238   //    )
1239   if (lhs_rank == 1 ||
1240       (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) {
1241     if (rhs->shape().rank() == 1) {
1242       TF_RETURN_IF_ERROR(
1243           ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
1244                                       multiply(Flatten(lhs), rhs), 0))));
1245       return true;
1246     }
1247     TF_RETURN_IF_ERROR(ReplaceInstruction(
1248         dot, reshape_if_necessary(add_reduce_in_f32(
1249                  multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
1250                                            rhs_collapsing_dim),
1251                           rhs),
1252                  rhs_collapsing_dim))));
1253     return true;
1254   }
1255 
1256   // Strength reduce dot(a, b[K, 1]) =
1257   //  reshape(result.shape,
1258   //    reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0})
1259   //  )
1260   if (rhs_rank == 1 ||
1261       (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) {
1262     TF_RETURN_IF_ERROR(ReplaceInstruction(
1263         dot, reshape_if_necessary(add_reduce_in_f32(
1264                  multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
1265                                                 lhs_collapsing_dim)),
1266                  lhs_collapsing_dim))));
1267     return true;
1268   }
1269 
1270   // Only consider kDot with batch dimension.
1271   if (dot_rank <= 2) {
1272     return false;
1273   }
1274 
1275   CHECK_EQ(rhs_rank, lhs_rank);
1276   CHECK_EQ(dot_rank, lhs_rank);
1277   // If there is more than one non-contracting dimension or the batch dimensions
1278   // are not equal, bail out since transposes may be required to do a strength
1279   // reduction.
1280   if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank ||
1281       !absl::c_equal(dnums.lhs_batch_dimensions(),
1282                      dnums.rhs_batch_dimensions())) {
1283     return false;
1284   }
1285 
1286   auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) {
1287     absl::InlinedVector<int64, 8> dims;
1288     for (int64 i = 0; i < rank; ++i) {
1289       if (i != non_broadcast_dim) {
1290         dims.push_back(i);
1291       }
1292     }
1293     return dims;
1294   };
1295 
1296   // If the contracting dimension is 1, remove the degnerate dimnensions from
1297   // the lhs and rhs, broadcast each to the result shape and multiply.
1298   if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 &&
1299       (rhs_kept_dim == rhs_rank - 1 ||
1300        (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) {
1301     CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1);
1302     const int64 lhs_kept_dim_in_output =
1303         lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim;
1304     absl::InlinedVector<int64, 8> lhs_broadcast_dims;
1305     for (const int64 dim : dnums.lhs_batch_dimensions()) {
1306       lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? (dim - 1) : dim);
1307     }
1308     absl::InlinedVector<int64, 8> rhs_broadcast_dims = lhs_broadcast_dims;
1309     lhs_broadcast_dims.push_back(lhs_kept_dim_in_output);
1310     absl::c_sort(lhs_broadcast_dims);
1311     rhs_broadcast_dims.push_back(dot_rank - 1);
1312     absl::c_sort(rhs_broadcast_dims);
1313     TF_RETURN_IF_ERROR(ReplaceInstruction(
1314         dot, reshape_if_necessary(
1315                  multiply(broadcast(StripDim(lhs, lhs_collapsing_dim),
1316                                     dot->shape(), lhs_broadcast_dims),
1317                           broadcast(StripDim(rhs, rhs_collapsing_dim),
1318                                     dot->shape(), rhs_broadcast_dims)))));
1319     return true;
1320   }
1321 
1322   // If the lhs and rhs non-contracting dimensions are both one, strip each one,
1323   // multiply and then reduce the collapsing dimension
1324   if (lhs->shape().dimensions(lhs_kept_dim) == 1 &&
1325       rhs->shape().dimensions(rhs_kept_dim) == 1 &&
1326       lhs_kept_dim == rhs_kept_dim) {
1327     auto new_lhs = StripDim(lhs, lhs_kept_dim);
1328     auto new_rhs = StripDim(rhs, rhs_kept_dim);
1329     const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim
1330                                  ? (rhs_collapsing_dim - 1)
1331                                  : rhs_collapsing_dim;
1332     TF_RETURN_IF_ERROR(
1333         ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
1334                                     multiply(new_lhs, new_rhs), reduce_dim))));
1335     return true;
1336   }
1337 
1338   // If the lhs  non-contracting dimensions is one, strip the one, brodcast to
1339   // the rhs shape, multiply and then reduce the collapsing dimension
1340   if (lhs->shape().dimensions(lhs_kept_dim) == 1) {
1341     auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(),
1342                              broadcast_dims(rhs_rank, rhs_kept_dim));
1343     TF_RETURN_IF_ERROR(ReplaceInstruction(
1344         dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs),
1345                                                     rhs_collapsing_dim))));
1346     return true;
1347   }
1348 
1349   // If the rhs  non-contracting dimensions is one, strip the one, brodcast to
1350   // the lhs shape, multiply and then reduce the collapsing dimension
1351   if (rhs->shape().dimensions(rhs_kept_dim) == 1) {
1352     auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(),
1353                              broadcast_dims(lhs_rank, lhs_kept_dim));
1354     TF_RETURN_IF_ERROR(ReplaceInstruction(
1355         dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs),
1356                                                     lhs_collapsing_dim))));
1357     return true;
1358   }
1359 
1360   return false;
1361 }
1362 
OptimizeDotOfConcat(HloInstruction * dot)1363 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
1364     HloInstruction* dot) {
1365   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1366   if (dnums.lhs_contracting_dimensions_size() != 1 ||
1367       dnums.lhs_batch_dimensions_size() != 0) {
1368     return nullptr;
1369   }
1370 
1371   const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
1372   const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
1373   HloInstruction *lhs, *rhs;
1374   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
1375 
1376   TF_ASSIGN_OR_RETURN(
1377       HloInstruction * optimized_lhs_concat,
1378       OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
1379                                 rhs_contracting_dim, /*swapped=*/false));
1380   if (optimized_lhs_concat) {
1381     return optimized_lhs_concat;
1382   }
1383 
1384   return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
1385                                    lhs_contracting_dim, /*swapped=*/true);
1386 }
1387 
OptimizeDotOfConcatHelper(const HloInstruction & dot,HloInstruction * lhs,int64 lhs_contracting_dim,HloInstruction * rhs,int64 rhs_contracting_dim,bool swapped)1388 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
1389     const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
1390     HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) {
1391   bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
1392                       lhs->concatenate_dimension() == lhs_contracting_dim &&
1393                       rhs->opcode() == HloOpcode::kConstant;
1394   if (!can_optimize) {
1395     return nullptr;
1396   }
1397 
1398   // We're replacing this:
1399   //
1400   //   +-----+-----+-----+      +-------------------+
1401   //   |     |     |     |      |                   |
1402   //   |     |     |     |      |        R_0        |
1403   //   |     |     |     |      |                   |
1404   //   |     |     |     |      +-------------------+
1405   //   |     |     |     |      |                   |
1406   //   | L_0 | L_1 | L_2 |   *  |        R_1        |
1407   //   |     |     |     |      |                   |
1408   //   |     |     |     |      +-------------------+
1409   //   |     |     |     |      |                   |
1410   //   |     |     |     |      |        R_2        |
1411   //   |     |     |     |      |                   |
1412   //   +-----+-----+-----+      +-------------------+
1413   //
1414   // with this:
1415   //
1416   // [Sum over i]
1417   //
1418   //   +-----+     +-------------------+
1419   //   |     |     |                   |
1420   //   |     |  *  |        R_i        |
1421   //   |     |     |                   |
1422   //   |     |     +-------------------+
1423   //   |     |
1424   //   | L_i |
1425   //   |     |
1426   //   |     |
1427   //   |     |
1428   //   |     |
1429   //   |     |
1430   //   +-----+
1431   //
1432   // where the LHS is a concatenate operation (so we can "split" the LHS tensor
1433   // for free) and the RHS is a constant tensor (and thus can be split at
1434   // compile time).  In the future, we may also want to do this when both the
1435   // LHS and the RHS are concatenate operations that line up along the dimension
1436   // being contracted over.
1437   //
1438   // We should be able to generalize this transform to work on a non-constant
1439   // RHS when/if we have in-place slices or support input-fusing slices into
1440   // Dots.
1441 
1442   // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
1443   // the diagram above).
1444   DotDimensionNumbers new_dot_dnums;
1445   new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
1446                                                        : lhs_contracting_dim);
1447   new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
1448                                                        : rhs_contracting_dim);
1449 
1450   // Here we use the MKN notation, where the contracted dimension has K
1451   // elements and the two non-contracted dimensions have M and N elements.
1452   HloInstruction* add_result = nullptr;
1453   int64 rhs_contracting_dim_offset = 0;
1454   int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim);
1455   for (HloInstruction* concat_op : lhs->operands()) {
1456     int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
1457     Shape rhs_slice_shape(rhs->shape());
1458     rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
1459 
1460     std::array<int64, 2> start_indices;
1461     start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
1462     start_indices[1 - rhs_contracting_dim] = 0;
1463 
1464     std::array<int64, 2> limit_indices;
1465     limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
1466     limit_indices[1 - rhs_contracting_dim] = n;
1467 
1468     HloInstruction* rhs_slice =
1469         computation_->AddInstruction(HloInstruction::CreateSlice(
1470             rhs_slice_shape, rhs, /*start_indices=*/start_indices,
1471             /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));
1472 
1473     // TODO(b/69062148): We can get rid of `swapped` once all backends support
1474     // "non-canonical" contraction dimensions (that contracts dimension 1 of the
1475     // LHS with dimension 0 of the RHS).  But for now we keep the same
1476     // contraction dimensions as the incoming dot operation to ensure the new
1477     // dot operations can be lowered.
1478     HloInstruction *new_dot_lhs, *new_dot_rhs;
1479     if (swapped) {
1480       new_dot_lhs = rhs_slice;
1481       new_dot_rhs = concat_op;
1482     } else {
1483       new_dot_lhs = concat_op;
1484       new_dot_rhs = rhs_slice;
1485     }
1486 
1487     auto* new_dot = computation_->AddInstruction(
1488         HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
1489                                   new_dot_dnums, dot.precision_config()));
1490 
1491     if (add_result) {
1492       add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
1493           dot.shape(), HloOpcode::kAdd, add_result, new_dot));
1494     } else {
1495       add_result = new_dot;
1496     }
1497 
1498     rhs_contracting_dim_offset += sub_k;
1499   }
1500 
1501   return add_result;
1502 }
1503 
OptimizeDotOfGather(HloInstruction * dot)1504 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
1505     HloInstruction* dot) {
1506   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1507   if (dnums.lhs_contracting_dimensions_size() != 1 ||
1508       dnums.rhs_contracting_dimensions_size() != 1 ||
1509       dnums.lhs_batch_dimensions_size() != 0 ||
1510       dnums.rhs_batch_dimensions_size() != 0 ||
1511       dot->shape().dimensions_size() != 2) {  // dot output 2D
1512     VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
1513     return nullptr;
1514   }
1515 
1516   // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)).
1517   // Currently a Gather is a DynamicSlice.
1518   auto is_dynamic_slice_constant_combination =
1519       [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
1520         // First operand is a DynamicSlice(Constant).
1521         if (a->opcode() != HloOpcode::kDynamicSlice) {
1522           return false;
1523         }
1524         auto* dynamic_slice_op = a->operand(0);
1525         if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
1526           return false;
1527         }
1528         // Second operand is a Constant.
1529         if (b->opcode() != HloOpcode::kConstant) {
1530           return false;
1531         }
1532         // The DynamicSlice output is a vector.
1533         const Shape& dynamic_slice_shape = a->shape();
1534         if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
1535           return false;
1536         }
1537         // Constant size is the same before and after slice in the contracting
1538         // dimension, otherwise we either must precompute for all possible slice
1539         // indices or dot is invalid.
1540         const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
1541         if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
1542             dynamic_slice_shape.dimensions(a_contracting_dimension)) {
1543           return false;
1544         }
1545         return true;
1546       };
1547 
1548   HloInstruction* lhs = dot->mutable_operand(0);
1549   HloInstruction* rhs = dot->mutable_operand(1);
1550   int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
1551   int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
1552 
1553   if (!is_dynamic_slice_constant_combination(
1554           lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
1555       !is_dynamic_slice_constant_combination(
1556           rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
1557     VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or "
1558                 "dot(ctB, DS(ctA)), where the two constants have equal "
1559                 "contracting dimensions.";
1560     return nullptr;
1561   }
1562 
1563   // LHS is DynamicSlice:
1564   // input: dot(DS(ctA), ctB))
1565   // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
1566   // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
1567   // output: DS(dot(ctA, ctB))
1568   // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.
1569 
1570   // RHS is DynamicSlice:
1571   // input: dot(ctA, DS(ctB))
1572   // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
1573   // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
1574   // output: DS(dot(ctA, ctB))
1575   // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
1576 
1577   bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
1578   HloDynamicSliceInstruction* dynamic_slice =
1579       lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
1580                            : Cast<HloDynamicSliceInstruction>(rhs);
1581 
1582   // ctA:
1583   HloInstruction* left_operand =
1584       lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
1585   // ctB:
1586   HloInstruction* right_operand =
1587       lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
1588   // Build ctA x ctB.
1589   const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
1590   const int n =
1591       right_operand->shape().dimensions(1 - rhs_contracting_dimension);
1592   auto memoized_shape =
1593       ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
1594   auto* memoized_inst = computation_->AddInstruction(
1595       HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
1596                                 dnums, dot->precision_config()));
1597   // Get pair {start, 0} or {0, start}.
1598   // Position of start:
1599   int index_of_non_zero_start = lhs_is_dynamic_slice
1600                                     ? 1 - lhs_contracting_dimension
1601                                     : 1 - rhs_contracting_dimension;
1602   // Position of zero:
1603   int index_of_zero_start = 1 - index_of_non_zero_start;
1604 
1605   // Slice out start and 0 components and reorder if necessary.
1606   auto indices_type = dynamic_slice->operand(1)->shape().element_type();
1607   Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
1608   Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
1609   HloInstruction* non_zero_start =
1610       dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
1611   HloInstruction* zero_start =
1612       dynamic_slice->mutable_operand(1 + index_of_zero_start);
1613   std::vector<HloInstruction*> new_start_indices;
1614   if (lhs_is_dynamic_slice) {
1615     new_start_indices = {non_zero_start, zero_start};
1616   } else {
1617     new_start_indices = {zero_start, non_zero_start};
1618   }
1619 
1620   // Build DynamicSlice(ctA x ctB).
1621   const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
1622   const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
1623   auto* memoized_lookup =
1624       computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
1625           dot->shape(), memoized_inst, new_start_indices,
1626           {new_slice_m, new_slice_n}));
1627 
1628   return memoized_lookup;
1629 }
1630 
HandleDot(HloInstruction * dot)1631 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
1632   HloInstruction *lhs, *rhs;
1633   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
1634   if (options_.is_layout_sensitive()) {
1635     return Status::OK();
1636   }
1637   // Replace a zero element dot with a broadcast of the constant 0.
1638   if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
1639       ShapeUtil::IsZeroElementArray(lhs->shape()) ||
1640       ShapeUtil::IsZeroElementArray(rhs->shape())) {
1641     auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
1642         LiteralUtil::Zero(dot->shape().element_type())));
1643     return ReplaceWithNewInstruction(
1644         dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
1645   }
1646 
1647   // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are
1648   // rank 2 or below.
1649   if (dot->shape().element_type() != F32 &&
1650       dot->shape().element_type() != BF16) {
1651     return Status::OK();
1652   }
1653 
1654   // If there are no contracting dimensions, a dot can be rewritten as
1655   // mul(broadcast(transpose(x)),broadcast(transpose(y)))
1656   if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
1657     TF_ASSIGN_OR_RETURN(
1658         HloInstruction * new_lhs,
1659         NormalizeDotOperandToBatchMajorAndContractingMinor(
1660             lhs,
1661             AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
1662             AsInt64Slice(
1663                 dot->dot_dimension_numbers().lhs_contracting_dimensions())));
1664     if (dot->shape().rank() != lhs->shape().rank()) {
1665       std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
1666       absl::c_iota(lhs_broadcast_dims, 0);
1667       new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1668           dot->shape(), new_lhs, lhs_broadcast_dims));
1669     }
1670     TF_ASSIGN_OR_RETURN(
1671         HloInstruction * new_rhs,
1672         NormalizeDotOperandToBatchMajorAndContractingMinor(
1673             rhs,
1674             AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
1675             AsInt64Slice(
1676                 dot->dot_dimension_numbers().rhs_contracting_dimensions())));
1677     if (dot->shape().rank() != rhs->shape().rank()) {
1678       std::vector<int64> rhs_broadcast_dims(
1679           dot->dot_dimension_numbers().lhs_batch_dimensions_size());
1680       absl::c_iota(rhs_broadcast_dims, 0);
1681       for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
1682         rhs_broadcast_dims.push_back(i);
1683       }
1684       new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1685           dot->shape(), new_rhs, rhs_broadcast_dims));
1686     }
1687     return ReplaceWithNewInstruction(
1688         dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
1689                                           new_lhs, new_rhs));
1690   }
1691 
1692   // If the lhs or rhs have only batch and contracting dimensions, a dot can be
1693   // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
1694   if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
1695            dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
1696        lhs->shape().rank()) ||
1697       (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
1698            dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
1699        rhs->shape().rank())) {
1700     TF_ASSIGN_OR_RETURN(
1701         HloInstruction * new_lhs,
1702         NormalizeDotOperandToBatchMajorAndContractingMinor(
1703             lhs,
1704             AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
1705             AsInt64Slice(
1706                 dot->dot_dimension_numbers().lhs_contracting_dimensions())));
1707     TF_ASSIGN_OR_RETURN(
1708         HloInstruction * new_rhs,
1709         NormalizeDotOperandToBatchMajorAndContractingMinor(
1710             rhs,
1711             AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
1712             AsInt64Slice(
1713                 dot->dot_dimension_numbers().rhs_contracting_dimensions())));
1714 
1715     int64 lhs_outer_dims =
1716         lhs->shape().rank() -
1717         (dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
1718          dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
1719     int64 rhs_outer_dims =
1720         rhs->shape().rank() -
1721         (dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
1722          dot->dot_dimension_numbers().rhs_contracting_dimensions_size());
1723     CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
1724     if (rhs_outer_dims > 0) {
1725       std::vector<int64> lhs_broadcast_dims(
1726           dot->dot_dimension_numbers().lhs_batch_dimensions_size());
1727       absl::c_iota(lhs_broadcast_dims, 0);
1728       lhs_broadcast_dims.resize(lhs->shape().rank());
1729       std::iota(lhs_broadcast_dims.begin() +
1730                     dot->dot_dimension_numbers().lhs_batch_dimensions_size(),
1731                 lhs_broadcast_dims.end(),
1732                 dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
1733                     rhs_outer_dims);
1734       new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1735           new_rhs->shape(), new_lhs, lhs_broadcast_dims));
1736     } else if (lhs_outer_dims > 0) {
1737       std::vector<int64> rhs_broadcast_dims(
1738           dot->dot_dimension_numbers().rhs_batch_dimensions_size());
1739       absl::c_iota(rhs_broadcast_dims, 0);
1740       rhs_broadcast_dims.resize(rhs->shape().rank());
1741       std::iota(rhs_broadcast_dims.begin() +
1742                     dot->dot_dimension_numbers().rhs_batch_dimensions_size(),
1743                 rhs_broadcast_dims.end(),
1744                 dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
1745                     lhs_outer_dims);
1746       new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1747           new_lhs->shape(), new_rhs, rhs_broadcast_dims));
1748     }
1749 
1750     TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
1751                         MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
1752     std::vector<int64> reduce_dims(
1753         dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
1754     new_dot = AsType(new_dot, F32);
1755     const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
1756     absl::c_iota(
1757         reduce_dims,
1758         outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size());
1759     new_dot = AddReduce(new_dot, reduce_dims);
1760     new_dot = AsType(new_dot, dot->shape().element_type());
1761     return ReplaceInstruction(dot, new_dot);
1762   }
1763 
1764   if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||
1765       dot->shape().rank() > 2) {
1766     if (options_.enable_dot_strength_reduction() &&
1767         !options_.is_layout_sensitive()) {
1768       TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status());
1769     }
1770     return Status::OK();
1771   }
1772 
1773   TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
1774                       OptimizeDotOfConcat(dot));
1775   if (dot_of_concat_optimized) {
1776     VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
1777                 "constant)...)";
1778     return ReplaceInstruction(dot, dot_of_concat_optimized);
1779   }
1780 
1781   // Simplify dot(ConstA, Gather(Index, ConstB)) to:
1782   // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
1783   // batched version of dot.
1784   TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
1785                       OptimizeDotOfGather(dot));
1786   if (dot_of_gather_optimized) {
1787     VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
1788                 "gather(i, dot*(constA, constB))";
1789     return ReplaceInstruction(dot, dot_of_gather_optimized);
1790   }
1791 
1792   if (options_.enable_dot_strength_reduction() &&
1793       !options_.is_layout_sensitive()) {
1794     TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
1795                         HandleDotStrengthReduction(dot));
1796     if (did_strength_reduction) {
1797       return Status::OK();
1798     }
1799   }
1800 
1801   // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
1802   if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 &&
1803       dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 &&
1804       dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 &&
1805       dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 &&
1806       lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
1807     DotDimensionNumbers dot_dimension_numbers;
1808     dot_dimension_numbers.add_lhs_contracting_dimensions(1);
1809     dot_dimension_numbers.add_rhs_contracting_dimensions(0);
1810     auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
1811         ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
1812         rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
1813         dot->precision_config()));
1814     return ReplaceWithNewInstruction(
1815         dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
1816   }
1817 
1818   return Status::OK();
1819 }
1820 
HandleMultiply(HloInstruction * multiply)1821 Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
1822   HloInstruction *lhs, *rhs;
1823   CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
1824   // A*1 => A
1825   VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
1826   if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
1827     return Status::OK();
1828   }
1829   // 1*A => A
1830   VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString();
1831   if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) {
1832     return Status::OK();
1833   }
1834 
1835   // 0*A => 0. Only applies for integral types for correct NaN-handling.
1836   if (IsAll(lhs, 0) &&
1837       primitive_util::IsIntegralType(multiply->shape().element_type()) &&
1838       ReplaceInstructionIfSameShape(multiply, lhs)) {
1839     return Status::OK();
1840   }
1841   // A*0 => 0
1842   if (IsAll(rhs, 0) &&
1843       primitive_util::IsIntegralType(multiply->shape().element_type()) &&
1844       ReplaceInstructionIfSameShape(multiply, rhs)) {
1845     return Status::OK();
1846   }
1847 
1848   // exp(A) * exp(B) => exp(A+B)
1849   if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
1850     auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
1851         multiply->shape(), HloOpcode::kAdd, lhs, rhs));
1852     return ReplaceWithNewInstruction(
1853         multiply,
1854         HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
1855   }
1856   return Status::OK();
1857 }
1858 
HandleNegate(HloInstruction * negate)1859 Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) {
1860   // negate(negate(x)) => x
1861   HloInstruction* x;
1862   if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) &&
1863       ReplaceInstructionIfSameShape(negate, x)) {
1864     return Status::OK();
1865   }
1866   return Status::OK();
1867 }
1868 
HandleNot(HloInstruction * logical_not)1869 Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) {
1870   // not(not(x)) => x
1871   HloInstruction* x;
1872   if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) &&
1873       ReplaceInstructionIfSameShape(logical_not, x)) {
1874     return Status::OK();
1875   }
1876   return Status::OK();
1877 }
1878 
HandleOr(HloInstruction * logical_or)1879 Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) {
1880   HloInstruction *lhs, *rhs;
1881   CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs))));
1882 
1883   // Simplify logical or
1884   if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
1885       ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
1886     // A || True => True
1887     VLOG(10) << "trying transform [A || True => True]: "
1888              << logical_or->ToString();
1889     if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
1890       return Status::OK();
1891     }
1892     // True || A => True
1893     VLOG(10) << "trying transform [True || A => True]: "
1894              << logical_or->ToString();
1895     if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
1896       return Status::OK();
1897     }
1898 
1899     // A || False => A
1900     VLOG(10) << "trying transform [A || False => A]: "
1901              << logical_or->ToString();
1902     if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
1903       return Status::OK();
1904     }
1905 
1906     // False || A => A
1907     VLOG(10) << "trying transform [False || A => A]: "
1908              << logical_or->ToString();
1909     if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
1910       return Status::OK();
1911     }
1912   }
1913 
1914   return Status::OK();
1915 }
1916 
HandleLog(HloInstruction * log)1917 Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
1918   // ln(exp(A)) => A
1919   VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
1920   HloInstruction *a, *b;
1921   if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
1922       ReplaceInstructionIfSameShape(log, a)) {
1923     return Status::OK();
1924   }
1925 
1926   // ln(pow(A,B)) => B*ln(A)
1927   if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
1928     auto new_log = computation_->AddInstruction(
1929         HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
1930     return ReplaceWithNewInstruction(
1931         log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
1932                                           new_log, b));
1933   }
1934 
1935   return Status::OK();
1936 }
1937 
HandleGetTupleElement(HloInstruction * get_tuple_element)1938 Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
1939     HloInstruction* get_tuple_element) {
1940   auto operand = get_tuple_element->mutable_operand(0);
1941   if (operand->opcode() == HloOpcode::kTuple) {
1942     // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
1943     VLOG(10) << "trying transform "
1944              << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
1945              << get_tuple_element->ToString();
1946     if (ReplaceInstructionIfSameShape(
1947             get_tuple_element,
1948             operand->mutable_operand(get_tuple_element->tuple_index()))) {
1949       return Status::OK();
1950     }
1951   }
1952   return Status::OK();
1953 }
1954 
1955 namespace {
1956 
1957 // Return whether the given reshape instruction leaves the dimensions at the
1958 // given input indices unmodified, and returns their output indices.
1959 //
1960 // Example:
1961 //   input_dim_indices = {2, 3}
1962 //   input  shape = T[a, b, x, y, cd]
1963 //   output shape = T[ab, x, 1, y, c, d]
1964 //   return value = {1, 3}
1965 //
1966 // Precondition: input_dim_indices is sorted.
ReshapeLeavesDimensionsUnmodified(const HloInstruction * hlo,absl::Span<const int64> input_dim_indices)1967 absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
1968     const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
1969   CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
1970   CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
1971 
1972   std::vector<int64> output_dim_indices;
1973   std::vector<std::pair<int64, int64>> unmodified_dims =
1974       ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
1975                                                hlo->shape());
1976   size_t i = 0;  // index to unmodified_dims
1977   for (int64 input_dim_index : input_dim_indices) {
1978     // Search unmodified_dims for input_dim_index. We can search from the last
1979     // matching position because input_dim_indices is guaranteed to be sorted.
1980     while (i < unmodified_dims.size() &&
1981            unmodified_dims[i].first < input_dim_index) {
1982       ++i;
1983     }
1984     if (i >= unmodified_dims.size() ||
1985         unmodified_dims[i].first != input_dim_index) {
1986       return absl::nullopt;
1987     }
1988     output_dim_indices.push_back(unmodified_dims[i].second);
1989   }
1990   return output_dim_indices;
1991 }
1992 
1993 // Returns true if the output of "instruction" is a permutation of the
1994 // elements of "operand". Precondition: "operand" is an operand of
1995 // "instruction".
OutputIsPermutationOfOperandElements(HloInstruction * instruction,HloInstruction * operand)1996 bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
1997                                           HloInstruction* operand) {
1998   DCHECK(!instruction->OperandIndices(operand).empty());
1999   switch (instruction->opcode()) {
2000     case HloOpcode::kReshape:
2001     case HloOpcode::kReverse:
2002     case HloOpcode::kTranspose:
2003       return true;
2004     case HloOpcode::kSort:
2005       return (!instruction->shape().IsTuple());
2006     default:
2007       return false;
2008   }
2009 }
2010 
2011 // Returns true if the output of "instruction" is a subset of the elements of
2012 // "operand". Precondition: "operand" is an operand of "instruction".
OutputIsSubsetOfOperandElements(HloInstruction * instruction,HloInstruction * operand)2013 bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
2014                                      HloInstruction* operand) {
2015   std::vector<int64> operand_indices = instruction->OperandIndices(operand);
2016   CHECK(!operand_indices.empty());
2017   if (operand_indices.size() != 1) {
2018     return false;
2019   }
2020   int64 operand_index = operand_indices[0];
2021   switch (instruction->opcode()) {
2022     case HloOpcode::kSlice:
2023       CHECK_EQ(0, operand_index);
2024       return true;
2025     case HloOpcode::kDynamicSlice:
2026       return operand_index == 0;
2027     default:
2028       return false;
2029   }
2030 }
2031 
2032 }  // namespace
2033 
HandleBroadcast(HloInstruction * broadcast)2034 Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
2035   HloInstruction* operand;
2036   CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
2037   auto dims = broadcast->dimensions();
2038   // A degenerate broadcast of a reshape that does not change the number of
2039   // elements can be replaced by a reshape.
2040   if (std::is_sorted(dims.begin(), dims.end()) &&
2041       ShapeUtil::ElementsIn(broadcast->shape()) ==
2042           ShapeUtil::ElementsIn(operand->shape())) {
2043     VLOG(10) << "transform broadcast(X) -> reshape(X) where "
2044                 "n(broadcast(X)) == n(X)";
2045     return ReplaceWithNewInstruction(
2046         broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
2047   }
2048 
2049   // A degenerate broadcast that has the same input and output rank can be
2050   // converted into a transpose.
2051   if (broadcast->shape().rank() == operand->shape().rank() &&
2052       ShapeUtil::ElementsIn(broadcast->shape()) ==
2053           ShapeUtil::ElementsIn(operand->shape())) {
2054     VLOG(10) << "transform broadcast(X) -> transpose(X) where "
2055                 "n(broadcast(X)) == n(X)";
2056     return ReplaceWithNewInstruction(
2057         broadcast,
2058         HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
2059   }
2060 
2061   // A broadcast of a reshape which merely inserts 1-sized dimensions can
2062   // elide its operand.
2063   {
2064     bool merely_inserts_or_deletes_1_sized_dimensions;
2065     std::vector<int64> inserted_indices, deleted_indices;
2066     std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
2067              inserted_indices) =
2068         operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
2069     if (merely_inserts_or_deletes_1_sized_dimensions &&
2070         deleted_indices.empty()) {
2071       std::reverse(inserted_indices.begin(), inserted_indices.end());
2072       for (auto inserted_index : inserted_indices) {
2073         dims.erase(dims.begin() + inserted_index);
2074       }
2075       return ReplaceWithNewInstruction(
2076           broadcast,
2077           HloInstruction::CreateBroadcast(broadcast->shape(),
2078                                           operand->mutable_operand(0), dims));
2079     }
2080   }
2081 
2082   // A Broadcast that feeds a unary element-wise operation can sink the
2083   // broadcast after the unary element-wise operation.
2084   TF_ASSIGN_OR_RETURN(
2085       bool sink_succeeded,
2086       TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
2087   changed_ |= sink_succeeded;
2088   if (sink_succeeded) {
2089     return Status::OK();
2090   }
2091 
2092   // A scalar broadcast feeding an instruction which only permutes (reshape,
2093   // transpose, sort, reverse) or selects a subset of operand elements (slice,
2094   // dynamic slice) can be replaced with a broadcast directly to the output
2095   // shape of the instruction.
2096   if (ShapeUtil::IsScalar(operand->shape())) {
2097     for (HloInstruction* user : broadcast->users()) {
2098       // Skip if the broadcast user has no uses itself.
2099       if (user->user_count() == 0 && user != computation_->root_instruction()) {
2100         continue;
2101       }
2102       if (OutputIsPermutationOfOperandElements(user, broadcast) ||
2103           OutputIsSubsetOfOperandElements(user, broadcast)) {
2104         VLOG(10) << "transform permuting/subset  of a scalar broadcast into "
2105                  << "a single broadcast";
2106         HloInstruction* new_broadcast = computation_->AddInstruction(
2107             HloInstruction::CreateBroadcast(user->shape(), operand, {}));
2108         // Use HloInstruction::ReplaceAllUsesWith instead of
2109         // HloComputation::ReplaceWithNewInstruction because we are replacing an
2110         // instruction other than the visited instruction.
2111         changed_ = true;
2112         return user->ReplaceAllUsesWith(new_broadcast);
2113       }
2114     }
2115     return Status::OK();
2116   }
2117 
2118   // broadcast(iota) -> iota.
2119   if (operand->opcode() == HloOpcode::kIota) {
2120     return ReplaceWithNewInstruction(
2121         broadcast,
2122         HloInstruction::CreateIota(
2123             broadcast->shape(),
2124             dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
2125   }
2126 
2127   // Merge two consecutive broadcasts into a single one.
2128   if (operand->opcode() == HloOpcode::kBroadcast) {
2129     std::vector<int64> new_dimensions;
2130     for (auto dim : operand->dimensions()) {
2131       new_dimensions.push_back(dims[dim]);
2132     }
2133     return ReplaceWithNewInstruction(
2134         broadcast,
2135         HloInstruction::CreateBroadcast(
2136             broadcast->shape(), operand->mutable_operand(0), new_dimensions));
2137   }
2138   return Status::OK();
2139 }
2140 
2141 // A conversion to the same element type as the operand is a nop and can be
2142 // removed.  A conversion of a constant can be simplified by making a new
2143 // constant.
HandleConvert(HloInstruction * convert)2144 Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
2145   PrimitiveType src_type = convert->operand(0)->shape().element_type();
2146   PrimitiveType dest_type = convert->shape().element_type();
2147   if (src_type == dest_type) {
2148     return ReplaceInstruction(convert, convert->mutable_operand(0));
2149   }
2150   return Status::OK();
2151 }
2152 
2153 // Complex(Real(c), Imag(c)) -> c
HandleComplex(HloInstruction * complex)2154 Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
2155   HloInstruction *c0, *c1;
2156   if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
2157       c0 == c1) {
2158     return ReplaceInstruction(complex, c0);
2159   }
2160   return Status::OK();
2161 }
2162 
2163 // Real(Complex(r, i)) -> r
HandleReal(HloInstruction * real)2164 Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
2165   HloInstruction* op;
2166   if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
2167     return ReplaceInstruction(real, op);
2168   }
2169   return Status::OK();
2170 }
2171 
2172 // Imag(Complex(r, i)) -> i
HandleImag(HloInstruction * imag)2173 Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
2174   HloInstruction* op;
2175   if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
2176     return ReplaceInstruction(imag, op);
2177   }
2178   return Status::OK();
2179 }
2180 
HandleIota(HloInstruction * instruction)2181 Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
2182   // iota -> zero if the iota dimension never produces an element other than
2183   // zero.
2184   auto* iota = Cast<HloIotaInstruction>(instruction);
2185   if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
2186     auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
2187         LiteralUtil::Zero(iota->shape().element_type()).Clone()));
2188     return ReplaceWithNewInstruction(
2189         iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
2190   }
2191   return Status::OK();
2192 }
2193 
HandlePad(HloInstruction * pad)2194 Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
2195   if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
2196     return ReplaceWithNewInstruction(
2197         pad, HloInstruction::CreateBroadcast(pad->shape(),
2198                                              pad->mutable_operand(1), {}));
2199   }
2200 
2201   // Interior padding on one sized dimensions have no effect. As a result it
2202   // makes other simplifications possible if there is no interior padding.
2203   if (HasInteriorPadding(pad->padding_config())) {
2204     PaddingConfig padding_config = pad->padding_config();
2205     bool cleared_interior_padding = false;
2206     for (int64 i = 0; i < pad->shape().rank(); ++i) {
2207       if (padding_config.dimensions(i).interior_padding() > 0 &&
2208           pad->operand(0)->shape().dimensions(i) == 1) {
2209         cleared_interior_padding = true;
2210         padding_config.mutable_dimensions(i)->set_interior_padding(0);
2211       }
2212     }
2213     if (cleared_interior_padding) {
2214       return ReplaceWithNewInstruction(
2215           pad,
2216           HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
2217                                     pad->mutable_operand(1), padding_config));
2218     }
2219   }
2220 
2221   // Eliminate nop pads (padding all zero), and replace a pad with negative
2222   // padding with a pad with non-negative padding followed by a slice.
2223   bool all_zero = true;
2224   bool has_negative = false;
2225   for (auto& padding_dimension : pad->padding_config().dimensions()) {
2226     if (padding_dimension.edge_padding_low() < 0 ||
2227         padding_dimension.edge_padding_high() < 0) {
2228       has_negative = true;
2229     }
2230     if (padding_dimension.edge_padding_low() != 0 ||
2231         padding_dimension.edge_padding_high() != 0) {
2232       all_zero = false;
2233     }
2234   }
2235 
2236   if (all_zero) {
2237     ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0));
2238     return Status::OK();
2239   }
2240 
2241   if (has_negative) {
2242     // Pad has negative padding. Replace with a pad with the non-negative
2243     // padding followed by a slice which effectively performs the negative
2244     // padding.
2245     // TODO(b/34628603): Add support for negative padding in the backends, or
2246     // change kPad semantics to disallow negative padding and use slice
2247     // instead.
2248 
2249     // First construct the padding config with non-negative entries and the
2250     // compute the shape of this new pad instruction.
2251     PaddingConfig nonzero_padding = pad->padding_config();
2252     for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
2253       PaddingConfig::PaddingConfigDimension* padding_dimension =
2254           nonzero_padding.mutable_dimensions(i);
2255       // Set negative padding to zero.
2256       if (padding_dimension->edge_padding_low() < 0) {
2257         padding_dimension->set_edge_padding_low(0);
2258       }
2259       if (padding_dimension->edge_padding_high() < 0) {
2260         padding_dimension->set_edge_padding_high(0);
2261       }
2262     }
2263 
2264     TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
2265                         MakePadHlo(pad->mutable_operand(0),
2266                                    pad->mutable_operand(1), nonzero_padding));
2267     // Copy the layout from the original pad instructions. The new pad and the
2268     // slice instruction should all have the same layout.
2269     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
2270         pad->shape(), nonzero_pad->mutable_shape()));
2271 
2272     // Second, construct the slice instruction to perform the negative padding.
2273     std::vector<int64> start_indices;
2274     std::vector<int64> end_indices;
2275     std::vector<int64> strides;
2276     for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) {
2277       const PaddingConfig::PaddingConfigDimension& padding_dimension =
2278           pad->padding_config().dimensions(i);
2279       int64 start = 0;
2280       if (padding_dimension.edge_padding_low() < 0) {
2281         start = -1 * padding_dimension.edge_padding_low();
2282       }
2283       int64 end = nonzero_pad->shape().dimensions(i);
2284       if (padding_dimension.edge_padding_high() < 0) {
2285         end += padding_dimension.edge_padding_high();
2286       }
2287       start_indices.push_back(start);
2288       end_indices.push_back(end);
2289       strides.push_back(1);
2290     }
2291 
2292     TF_ASSIGN_OR_RETURN(
2293         HloInstruction * slice,
2294         MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
2295 
2296     // Verify that the slice shape matches the pad shape.
2297     TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape()));
2298 
2299     return ReplaceInstruction(pad, slice);
2300   }
2301 
2302   return Status::OK();
2303 }
2304 
HandlePower(HloInstruction * power)2305 Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
2306   VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
2307   HloInstruction *lhs, *rhs;
2308   CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
2309   if (IsAll(rhs, 0)) {
2310     auto one = HloInstruction::CreateConstant(
2311         LiteralUtil::One(power->shape().element_type()).Clone());
2312     std::unique_ptr<HloInstruction> ones;
2313     if (ShapeUtil::IsScalar(power->shape())) {
2314       ones = std::move(one);
2315     } else {
2316       ones = HloInstruction::CreateBroadcast(
2317           power->shape(), computation_->AddInstruction(std::move(one)), {});
2318     }
2319     return ReplaceWithNewInstruction(power, std::move(ones));
2320   }
2321 
2322   VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
2323   if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
2324     return Status::OK();
2325   }
2326 
2327   // pow(exp(A),B) => exp(A*B)
2328   HloInstruction *a, *b;
2329   if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
2330     auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
2331         power->shape(), HloOpcode::kMultiply, a, b));
2332     return ReplaceWithNewInstruction(
2333         power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
2334                                            a_times_b));
2335   }
2336   VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
2337   if (IsAll(rhs, 2)) {
2338     return ReplaceWithNewInstruction(
2339         power, HloInstruction::CreateBinary(power->shape(),
2340                                             HloOpcode::kMultiply, lhs, lhs));
2341   }
2342 
2343   VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
2344   if (IsAll(rhs, -1)) {
2345     auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
2346         LiteralUtil::One(rhs->shape().element_type()).Clone()));
2347 
2348     // Explicitly broadcast scalar 1 to the output shape, to avoid implicit
2349     // broadcast in divide HLO as we are trying to eliminate implicit
2350     // broadcasting at HLO level.
2351     auto* broadcast_one = computation_->AddInstruction(
2352         HloInstruction::CreateBroadcast(power->shape(), one, {}));
2353     return ReplaceWithNewInstruction(
2354         power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
2355                                             broadcast_one, lhs));
2356   }
2357 
2358   VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: "
2359            << power->ToString();
2360 
2361   // Don't perform this optimization if either of the exponents is complex; this
2362   // identity is true only for real-valued exponents.  In addition, we cowardly
2363   // refuse to do this transformation if the two expontents have different
2364   // element types.
2365   if (lhs->opcode() == HloOpcode::kPower &&
2366       !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) &&
2367       !ShapeUtil::ElementIsComplex(rhs->shape()) &&
2368       ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) {
2369     auto exponent_product =
2370         computation_->AddInstruction(HloInstruction::CreateBinary(
2371             rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
2372     return ReplaceWithNewInstruction(
2373         power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower,
2374                                             lhs->mutable_operand(0),
2375                                             exponent_product));
2376   }
2377 
2378   return Status::OK();
2379 }
2380 
2381 StatusOr<bool>
TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(HloInstruction * broadcast)2382 AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
2383     HloInstruction* broadcast) {
2384   TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
2385   bool changed = false;
2386   if (ShapeUtil::IsScalar(broadcast->shape())) {
2387     return false;
2388   }
2389   HloInstruction* operand = broadcast->mutable_operand(0);
2390   for (HloInstruction* user : broadcast->users()) {
2391     if (user->user_count() == 0 && user != computation_->root_instruction()) {
2392       continue;
2393     }
2394     // Do not move reshapes or broadcasts past copies since the shape the copy
2395     // will operate on will change.
2396     if (user->opcode() == HloOpcode::kCopy) {
2397       continue;
2398     }
2399     // Do not change the shape of fusion nodes in case there a multiple shapes
2400     // inside the fusion node already.
2401     if (user->opcode() == HloOpcode::kFusion) {
2402       continue;
2403     }
2404     if (!user->IsElementwise()) {
2405       continue;
2406     }
2407 
2408     // Find the unique non-scalar operand or continue if there isn't one.
2409     int64 scalar_broadcast_count = 0;
2410     int64 broadcast_use_count = 0;
2411     for (HloInstruction* user_operand : user->operands()) {
2412       if (user_operand->opcode() == HloOpcode::kBroadcast &&
2413           ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
2414         ++scalar_broadcast_count;
2415       } else if (broadcast == user_operand) {
2416         ++broadcast_use_count;
2417       }
2418     }
2419     if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) {
2420       continue;
2421     }
2422     std::vector<HloInstruction*> new_operands;
2423     new_operands.reserve(user->operand_count());
2424 
2425     for (HloInstruction* user_operand : user->operands()) {
2426       if (user_operand->opcode() == HloOpcode::kBroadcast &&
2427           ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
2428         new_operands.push_back(
2429             computation_->AddInstruction(HloInstruction::CreateBroadcast(
2430                 ShapeUtil::ChangeElementType(
2431                     operand->shape(), user_operand->shape().element_type()),
2432                 user_operand->mutable_operand(0), {})));
2433       } else {
2434         CHECK_EQ(broadcast, user_operand);
2435         new_operands.push_back(operand);
2436       }
2437     }
2438     VLOG(4) << "Sinking broadcast after user:";
2439     VLOG(4) << "  old broadcast: " << broadcast->ToString();
2440     VLOG(4) << "  old user: " << user->ToString();
2441     HloInstruction* new_user =
2442         computation_->AddInstruction(user->CloneWithNewOperands(
2443             ShapeUtil::ChangeElementType(operand->shape(),
2444                                          user->shape().element_type()),
2445             new_operands));
2446     VLOG(4) << "  new user: " << new_user->ToString();
2447     HloInstruction* new_broadcast =
2448         computation_->AddInstruction(HloInstruction::CreateBroadcast(
2449             user->shape(), new_user, broadcast->dimensions()));
2450     VLOG(4) << "  new broadcast: " << new_broadcast->ToString();
2451     TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
2452     changed = true;
2453   }
2454   return changed;
2455 }
2456 
2457 namespace {
2458 template <typename T>
TryRemainderToAnd(HloInstruction * remainder,HloComputation * computation)2459 std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
2460                                                   HloComputation* computation) {
2461   HloInstruction *a, *b, *c;
2462   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
2463 
2464   if (ShapeUtil::ElementIsIntegral(remainder->shape()) &&
2465       !Match(b, m::ConstantEffectiveScalar(&c)) &&
2466       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
2467     return nullptr;
2468   }
2469 
2470   if (ShapeUtil::ElementIsSigned(remainder->shape())) {
2471     int64 b_value = c->literal().GetFirstElement<T>();
2472     if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
2473       // Handle negative dividends by negating the result of the division.
2474       HloInstruction* zero_like_a = BroadcastZeros(
2475           computation, a->shape().element_type(), a->shape().dimensions());
2476 
2477       auto* dividend_is_negative =
2478           computation->AddInstruction(HloInstruction::CreateCompare(
2479               ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
2480               ComparisonDirection::kLt));
2481 
2482       auto* negated_dividend = computation->AddInstruction(
2483           HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
2484 
2485       auto* abs_dividend =
2486           computation->AddInstruction(HloInstruction::CreateTernary(
2487               a->shape(), HloOpcode::kSelect, dividend_is_negative,
2488               negated_dividend, a));
2489 
2490       auto* mask_amount =
2491           computation->AddInstruction(HloInstruction::CreateConstant(
2492               LiteralUtil::CreateR0<T>(b_value - 1)));
2493       if (!ShapeUtil::IsScalar(b->shape())) {
2494         mask_amount = computation->AddInstruction(
2495             HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
2496       }
2497 
2498       auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
2499           remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount));
2500 
2501       auto* neqated_quotient =
2502           computation->AddInstruction(HloInstruction::CreateUnary(
2503               quotient->shape(), HloOpcode::kNegate, quotient));
2504 
2505       return HloInstruction::CreateTernary(
2506           remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
2507           neqated_quotient, quotient);
2508     }
2509   } else {
2510     uint64 b_value = c->literal().GetFirstElement<T>();
2511     if (IsPowerOfTwo(b_value)) {
2512       HloInstruction* mask_amount =
2513           computation->AddInstruction(HloInstruction::CreateConstant(
2514               LiteralUtil::CreateR0<T>(b_value - 1)));
2515       if (!ShapeUtil::IsScalar(b->shape())) {
2516         mask_amount = computation->AddInstruction(
2517             HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
2518       }
2519       return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd,
2520                                           a, mask_amount);
2521     }
2522   }
2523   return nullptr;
2524 }
2525 }  // namespace
2526 
HandleRemainder(HloInstruction * remainder)2527 Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
2528   HloInstruction *a, *b;
2529   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
2530 
2531   // A % B => A & (B - 1) if B is a power of 2.
2532   switch (remainder->shape().element_type()) {
2533     case S8:
2534       if (std::unique_ptr<HloInstruction> shift =
2535               TryRemainderToAnd<int8>(remainder, computation_)) {
2536         return ReplaceWithNewInstruction(remainder, std::move(shift));
2537       }
2538       break;
2539     case S16:
2540       if (std::unique_ptr<HloInstruction> shift =
2541               TryRemainderToAnd<int16>(remainder, computation_)) {
2542         return ReplaceWithNewInstruction(remainder, std::move(shift));
2543       }
2544       break;
2545     case S32:
2546       if (std::unique_ptr<HloInstruction> shift =
2547               TryRemainderToAnd<int32>(remainder, computation_)) {
2548         return ReplaceWithNewInstruction(remainder, std::move(shift));
2549       }
2550       break;
2551     case S64:
2552       if (std::unique_ptr<HloInstruction> shift =
2553               TryRemainderToAnd<int64>(remainder, computation_)) {
2554         return ReplaceWithNewInstruction(remainder, std::move(shift));
2555       }
2556       break;
2557     case U8:
2558       if (std::unique_ptr<HloInstruction> shift =
2559               TryRemainderToAnd<uint8>(remainder, computation_)) {
2560         return ReplaceWithNewInstruction(remainder, std::move(shift));
2561       }
2562       break;
2563     case U16:
2564       if (std::unique_ptr<HloInstruction> shift =
2565               TryRemainderToAnd<uint16>(remainder, computation_)) {
2566         return ReplaceWithNewInstruction(remainder, std::move(shift));
2567       }
2568       break;
2569     case U32:
2570       if (std::unique_ptr<HloInstruction> shift =
2571               TryRemainderToAnd<uint32>(remainder, computation_)) {
2572         return ReplaceWithNewInstruction(remainder, std::move(shift));
2573       }
2574       break;
2575     case U64:
2576       if (std::unique_ptr<HloInstruction> shift =
2577               TryRemainderToAnd<uint64>(remainder, computation_)) {
2578         return ReplaceWithNewInstruction(remainder, std::move(shift));
2579       }
2580       break;
2581     default:
2582       break;
2583   }
2584 
2585   return Status::OK();
2586 }
2587 
HandleReshape(HloInstruction * reshape)2588 Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
2589   auto operand = reshape->mutable_operand(0);
2590 
2591   // Reshape directly to empty constant if the shape contains zero-element
2592   // dimension.
2593   if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
2594     // If the instruction doesn't have a layout, use a default layout for
2595     // the literal result.
2596     Shape reshaped_shape = reshape->shape();
2597     if (!LayoutUtil::HasLayout(reshaped_shape)) {
2598       LayoutUtil::SetToDefaultLayout(&reshaped_shape);
2599     }
2600     auto empty_constant = HloInstruction::CreateConstant(
2601         Literal::CreateFromShape(reshaped_shape));
2602 
2603     return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
2604   }
2605 
2606   // Delete no-op reshapes, i.e. where shape = operand shape.
2607   if (SameShape(reshape, operand)) {
2608     VLOG(10) << "deleting no-op reshape";
2609     return ReplaceInstruction(reshape, operand);
2610   }
2611 
2612   // Merge reshapes.
2613   if (HloOpcode::kReshape == operand->opcode()) {
2614     return ReplaceWithNewInstruction(
2615         reshape, HloInstruction::CreateReshape(reshape->shape(),
2616                                                operand->mutable_operand(0)));
2617   }
2618 
2619   if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
2620     *operand->mutable_shape() = reshape->shape();
2621     return ReplaceInstruction(reshape, operand);
2622   }
2623 
2624   if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
2625     auto opt_dims = ReshapeLeavesDimensionsUnmodified(
2626         reshape, reshape->operand(0)->dimensions());
2627     if (opt_dims.has_value()) {
2628       return ReplaceWithNewInstruction(
2629           reshape,
2630           HloInstruction::CreateBroadcast(
2631               reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
2632               *opt_dims));
2633     }
2634   }
2635 
2636   // reshape(iota) -> iota.
2637   if (operand->opcode() == HloOpcode::kIota) {
2638     auto* iota = Cast<HloIotaInstruction>(operand);
2639     auto opt_dims =
2640         ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
2641     if (opt_dims.has_value()) {
2642       CHECK_EQ(opt_dims->size(), 1);
2643       return ReplaceWithNewInstruction(
2644           reshape,
2645           HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
2646     }
2647   }
2648 
2649   // Make this a bitcast if possible.
2650   if (HloInstruction* bitcast_operand =
2651           BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {
2652     ReplaceWithBitcast(reshape, bitcast_operand);
2653   }
2654   return Status::OK();
2655 }
2656 
HandleReverse(HloInstruction * reverse)2657 Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
2658   // When all the dimensions to reverse are trivial (i.e. the bound is 1),
2659   // there is nothing to be done.
2660   auto dim_is_one = [&](int64 i) -> bool {
2661     return reverse->shape().dimensions(i) == 1;
2662   };
2663   if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
2664     return ReplaceInstruction(reverse, reverse->mutable_operand(0));
2665   }
2666   return Status::OK();
2667 }
2668 
TrySimplifyScalarSlice(HloInstruction * slice)2669 StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
2670     HloInstruction* slice) {
2671   // Only try to do this for effective scalars. We could do the same for slicing
2672   // out larger pieces of padding (replacing with a broadcast of the padding
2673   // value), but this is probably not worth it.
2674   if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
2675     return false;
2676   }
2677 
2678   if (slice->operand(0)->opcode() == HloOpcode::kPad) {
2679     VLOG(10) << "Trying to simplify scalar slice of pad";
2680     // Check there's no internal padding. Again, we could handle that too, since
2681     // everything is statically known, but it's not worth it.
2682     auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0));
2683     auto padding_config = pad->padding_config();
2684     int64 rank = padding_config.dimensions_size();
2685     if (HasInteriorPadding(padding_config)) {
2686       VLOG(10) << "Not folding scalar slice of pad, pad has interior padding";
2687       return false;
2688     }
2689 
2690     // Check whether the scalar we're slicing out falls into the padding.
2691     bool in_padding = [&]() {
2692       for (int64 i = 0; i < rank; ++i) {
2693         int64 start = slice->slice_starts(i);
2694         int64 low = padding_config.dimensions(i).edge_padding_low();
2695         int64 data = pad->operand(0)->shape().dimensions(i);
2696         if (start < low || start >= low + data) {
2697           return true;
2698         }
2699       }
2700       return false;
2701     }();
2702 
2703     if (in_padding) {
2704       VLOG(10) << "Folding scalar slice of pad into padding value";
2705       TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
2706           slice, HloInstruction::CreateReshape(slice->shape(),
2707                                                pad->mutable_padding_value())));
2708       return true;
2709     } else {
2710       // We already know the output of the slice is scalar. If the padded
2711       // value is scalar, and it's not in the padding, then it's exactly the
2712       // output value.
2713       bool replaced =
2714           ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0));
2715       if (replaced) {
2716         VLOG(10) << "Folding scalar slice of pad into padded value";
2717       } else {
2718         VLOG(10) << "Not folding scalar slice of pad into padded value as they "
2719                     "have different shapes.";
2720       }
2721       return replaced;
2722     }
2723   }
2724 
2725   if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
2726     VLOG(10) << "Trying to simplify scalar slice of concat";
2727     // Only do this for R1, there's no chance of this being useful otherwise.
2728     if (slice->shape().rank() != 1) {
2729       VLOG(10) << "Not folding, slice is not rank 1";
2730       return false;
2731     }
2732     HloConcatenateInstruction* concat =
2733         Cast<HloConcatenateInstruction>(slice->mutable_operand(0));
2734     int64 operand_start = 0;
2735     int64 operand_num = 0;
2736     // Weird loop structure to avoid annoying off-by-one errors.
2737     while (true) {
2738       TF_RET_CHECK(operand_num < concat->operand_count());
2739       const HloInstruction* operand = concat->operand(operand_num);
2740       int64 next_operand_start = operand_start + operand->shape().dimensions(0);
2741       if (next_operand_start > slice->slice_starts(0)) {
2742         break;
2743       }
2744       operand_start = next_operand_start;
2745       operand_num++;
2746     }
2747 
2748     bool replaced = ReplaceInstructionIfSameShape(
2749         slice, concat->mutable_operand(operand_num));
2750     if (replaced) {
2751       VLOG(10) << "Folding scalar slice of concat into concat operand";
2752     } else {
2753       VLOG(10) << "Folding scalar slice of concat into slice of concat operand";
2754       TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
2755           slice, HloInstruction::CreateSlice(
2756                      slice->shape(), concat->mutable_operand(operand_num),
2757                      {slice->slice_starts(0) - operand_start},
2758                      {slice->slice_starts(0) - operand_start + 1},
2759                      slice->slice_strides())));
2760     }
2761     return true;
2762   }
2763 
2764   return false;
2765 }
2766 
TryToReorderSliceAndReshape(HloInstruction * slice)2767 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
2768     HloInstruction* slice) {
2769   CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
2770   if (!IsUnstridedSlice(slice)) {
2771     return false;
2772   }
2773   HloInstruction* reshape = slice->mutable_operand(0);
2774   if (reshape->opcode() != HloOpcode::kReshape) {
2775     return false;
2776   }
2777   HloInstruction* new_slice_operand = reshape->mutable_operand(0);
2778   int64 slice_rank = slice->shape().rank();
2779   std::vector<int64> sliced_dims;
2780   for (int64 i = 0; i < slice_rank; ++i) {
2781     if (slice->slice_starts(i) != 0 ||
2782         slice->slice_limits(i) != reshape->shape().dimensions(i)) {
2783       sliced_dims.push_back(i);
2784     }
2785   }
2786 
2787   if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
2788       slice->slice_starts(0) == 0) {
2789     const Shape& new_slice_shape = new_slice_operand->shape();
2790     const int64 rank = new_slice_shape.rank();
2791     std::vector<int64> new_slice_starts(rank, 0);
2792     std::vector<int64> new_slice_stides(rank, 1);
2793     std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(),
2794                                         new_slice_shape.dimensions().end());
2795     int64 slice_elements = ShapeUtil::ElementsIn(slice->shape());
2796     for (int64 i = rank - 1; i >= 0; --i) {
2797       if (slice_elements >= new_slice_limits[i]) {
2798         if (slice_elements % new_slice_limits[i] != 0) {
2799           return false;
2800         }
2801         slice_elements /= new_slice_limits[i];
2802       } else {
2803         new_slice_limits[i] = slice_elements;
2804         slice_elements = 1;
2805       }
2806     }
2807     HloInstruction* new_slice =
2808         computation_->AddInstruction(HloInstruction::CreateSlice(
2809             ShapeUtil::MakeShape(new_slice_shape.element_type(),
2810                                  new_slice_limits),
2811             new_slice_operand, new_slice_starts, new_slice_limits,
2812             new_slice_stides));
2813     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
2814         slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
2815     return true;
2816   }
2817   return false;
2818 }
2819 
HandleSlice(HloInstruction * slice)2820 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
2821   // Delete no-op slices, i.e. where shape = operand shape.
2822   if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
2823     return Status::OK();
2824   }
2825 
2826   if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
2827       IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) {
2828     HloInstruction* operand_slice = slice->mutable_operand(0);
2829     std::vector<int64> new_slice_starts = slice->slice_starts();
2830     std::vector<int64> new_slice_limits = slice->slice_limits();
2831     for (int64 i = 0; i < new_slice_starts.size(); ++i) {
2832       new_slice_starts[i] += operand_slice->slice_starts(i);
2833       new_slice_limits[i] += operand_slice->slice_starts(i);
2834     }
2835     return ReplaceWithNewInstruction(
2836         slice, HloInstruction::CreateSlice(
2837                    slice->shape(), operand_slice->mutable_operand(0),
2838                    new_slice_starts, new_slice_limits, slice->slice_strides()));
2839   }
2840 
2841   auto only_broadcast_dims_sliced = [&] {
2842     if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) {
2843       return false;
2844     }
2845     for (int64 dim : slice->operand(0)->dimensions()) {
2846       if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
2847           slice->slice_limits(dim) !=
2848               slice->operand(0)->shape().dimensions(dim)) {
2849         return false;
2850       }
2851     }
2852     return true;
2853   };
2854   if (only_broadcast_dims_sliced()) {
2855     return ReplaceWithNewInstruction(
2856         slice,
2857         HloInstruction::CreateBroadcast(
2858             slice->shape(), slice->mutable_operand(0)->mutable_operand(0),
2859             slice->mutable_operand(0)->dimensions()));
2860   }
2861 
2862   TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice));
2863   if (replaced) {
2864     return Status::OK();
2865   }
2866 
2867   TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
2868   if (replaced) {
2869     return Status::OK();
2870   }
2871   return Status::OK();
2872 }
2873 
HandleDynamicSlice(HloInstruction * dynamic_slice)2874 Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
2875     HloInstruction* dynamic_slice) {
2876   auto operand = dynamic_slice->mutable_operand(0);
2877   if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
2878     return ReplaceInstruction(dynamic_slice, operand);
2879   }
2880   // DynamicSlice where operand has the same size as the output is simply equal
2881   // to operand.
2882   if (SameShape(operand, dynamic_slice)) {
2883     return ReplaceInstruction(dynamic_slice, operand);
2884   }
2885   return Status::OK();
2886 }
2887 
HandleDynamicUpdateSlice(HloInstruction * dynamic_update_slice)2888 Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
2889     HloInstruction* dynamic_update_slice) {
2890   auto update = dynamic_update_slice->mutable_operand(1);
2891 
2892   // DynamicUpdateSlice where operand and update have the same size is simply
2893   // equal to update.
2894   if (SameShape(dynamic_update_slice, update)) {
2895     return ReplaceInstruction(dynamic_update_slice, update);
2896   }
2897 
2898   // If any dimension of update is 0, elide the DynamicUpdateSlice.  This
2899   // optimization becomes invalid should we later prefer to warn about out of
2900   // bound indices.
2901   if (ShapeUtil::IsZeroElementArray(update->shape())) {
2902     return ReplaceInstruction(dynamic_update_slice,
2903                               dynamic_update_slice->mutable_operand(0));
2904   }
2905   return Status::OK();
2906 }
2907 
HandleReduce(HloInstruction * hlo)2908 Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
2909   HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
2910   bool multi_output_reduce = reduce->shape().IsTuple();
2911 
2912   // For tuple reduce, we require all reduce shapes to be the same, up to the
2913   // element types, so we can just the first operand and the first result as a
2914   // representative.
2915   auto arg = reduce->inputs()[0];
2916   auto init_value = reduce->init_values()[0];
2917   const Shape& reduce_result_shape =
2918       multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();
2919 
2920   absl::Span<const int64> dimensions(reduce->dimensions());
2921   HloComputation* function = reduce->to_apply();
2922   if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
2923       ShapeUtil::IsZeroElementArray(reduce_result_shape)) {
2924     if (multi_output_reduce) {
2925       std::vector<HloInstruction*> broadcast_inits;
2926       int64 inputs = reduce->input_count();
2927       for (int64 i = 0; i < inputs; ++i) {
2928         broadcast_inits.push_back(computation_->AddInstruction(
2929             HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i),
2930                                             reduce->init_values()[i], {})));
2931       }
2932       return ReplaceWithNewInstruction(
2933           reduce, HloInstruction::CreateTuple(broadcast_inits));
2934     } else {
2935       return ReplaceWithNewInstruction(
2936           reduce,
2937           HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {}));
2938     }
2939   }
2940 
2941   // If the reduction results in the same number of elements, then the only
2942   // possible side effect would be a reshape. Since the init_value is an
2943   // identity of the reduction function, we can therefore replace the reduce
2944   // with a simple reshape, ignoring the reduction function completely.
2945   if (ShapeUtil::ElementsIn(reduce_result_shape) ==
2946       ShapeUtil::ElementsIn(arg->shape())) {
2947     if (multi_output_reduce) {
2948       std::vector<HloInstruction*> reshaped_args;
2949       int64 inputs = reduce->input_count();
2950       for (int64 i = 0; i < inputs; ++i) {
2951         reshaped_args.push_back(
2952             computation_->AddInstruction(HloInstruction::CreateReshape(
2953                 reduce->shape().tuple_shapes(i), reduce->inputs()[i])));
2954       }
2955       return ReplaceWithNewInstruction(
2956           reduce, HloInstruction::CreateTuple(reshaped_args));
2957     } else {
2958       return ReplaceWithNewInstruction(
2959           reduce, HloInstruction::CreateReshape(reduce_result_shape, arg));
2960     }
2961   }
2962 
2963   // TODO(b/112040122): Most of those optimizations below can be done for
2964   // multi-output reduces.
2965   if (multi_output_reduce) {
2966     return Status::OK();
2967   }
2968 
2969   // A Transpose feeding a reduce can simply permute the reduction dimensions
2970   // field if the output of the reduce is a vector or scalar. Higher ranked
2971   // result may require a transpose of the output.
2972   if (reduce_result_shape.rank() <= 1 &&
2973       arg->opcode() == HloOpcode::kTranspose) {
2974     auto transpose_dimensions = arg->dimensions();
2975     std::vector<int64> new_reduce_dimensions;
2976     for (auto dim : dimensions) {
2977       new_reduce_dimensions.push_back(transpose_dimensions[dim]);
2978     }
2979     return ReplaceWithNewInstruction(
2980         reduce, HloInstruction::CreateReduce(
2981                     reduce_result_shape, arg->mutable_operand(0), init_value,
2982                     new_reduce_dimensions, function));
2983   }
2984 
2985   // If a reduce feeds a reduce with the same computation and initial value,
2986   // they can be combined into a single reduce.
2987   if (arg->opcode() == HloOpcode::kReduce &&
2988       init_value->Identical(*arg->operand(1)) &&
2989       *function == *arg->to_apply()) {
2990     // Create a new reduce with the combined reduction dimensions of both
2991     // reduces.
2992     std::vector<int64> arg_dims = arg->dimensions();
2993     absl::c_sort(arg_dims);
2994     std::vector<int64> reduce_dims = reduce->dimensions();
2995     absl::c_sort(reduce_dims);
2996     // Transform reduce_dims to the same rank as the operand of the operand.
2997     for (int64 arg_dim : arg_dims) {
2998       for (int64& dim : reduce_dims) {
2999         if (dim >= arg_dim) {
3000           ++dim;
3001         }
3002       }
3003     }
3004     std::vector<int64> new_dimensions;
3005     new_dimensions.reserve(arg->dimensions().size() +
3006                            reduce->dimensions().size());
3007     std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
3008                reduce_dims.end(), std::back_inserter(new_dimensions));
3009     return ReplaceWithNewInstruction(
3010         reduce, HloInstruction::CreateReduce(
3011                     reduce_result_shape, arg->mutable_operand(0), init_value,
3012                     new_dimensions, function));
3013   }
3014 
3015   // A reshape that collapses multiple dimensions into a dimension being
3016   // reduced can just reduce all of those dimensions instead of doing a
3017   // collapsing reshape before a reduction.
3018   if (arg->opcode() == HloOpcode::kReshape) {
3019     std::vector<std::pair<int64, int64>> unmodified_dims =
3020         ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
3021                                                  arg->shape());
3022     std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
3023     std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
3024     for (auto dim : dimensions) {
3025       arg_dim_in_output[dim] = false;
3026     }
3027     for (auto dim_pair : unmodified_dims) {
3028       arg_dim_unmodified[dim_pair.second] = true;
3029     }
3030     // The goal is to verify that all dimensions that are not removed in the
3031     // reduce are unmodified by the reshape. For example:
3032     // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
3033     bool can_move_reshape_into_reduce = true;
3034     for (int64 i = 0; i < arg_dim_in_output.size(); ++i) {
3035       if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
3036         can_move_reshape_into_reduce = false;
3037       }
3038     }
3039     if (can_move_reshape_into_reduce) {
3040       changed_ = true;
3041       absl::flat_hash_set<int64> dimensions_not_to_reduce;
3042       for (auto dim_pair : unmodified_dims) {
3043         if (arg_dim_in_output[dim_pair.second]) {
3044           dimensions_not_to_reduce.insert(dim_pair.first);
3045         }
3046       }
3047       std::vector<int64> new_reduce_dimensions;
3048       for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) {
3049         if (!dimensions_not_to_reduce.contains(i)) {
3050           new_reduce_dimensions.push_back(i);
3051         }
3052       }
3053       return ReplaceWithNewInstruction(
3054           reduce, HloInstruction::CreateReduce(
3055                       reduce_result_shape, arg->mutable_operand(0), init_value,
3056                       new_reduce_dimensions, function));
3057     }
3058   }
3059   // Convert Reduce(concat({a,b,...})) to
3060   //  map(reduce(a),map(reduce(b),...,))
3061   //
3062   // This should make fusion easier or use less memory bandwidth in the unfused
3063   // case.
3064   if (arg->opcode() == HloOpcode::kConcatenate &&
3065       absl::c_linear_search(reduce->dimensions(),
3066                             arg->concatenate_dimension())) {
3067     HloInstruction* old_reduce = nullptr;
3068     for (HloInstruction* operand : arg->operands()) {
3069       HloInstruction* new_reduce = computation_->AddInstruction(
3070           HloInstruction::CreateReduce(reduce_result_shape, operand, init_value,
3071                                        reduce->dimensions(), function));
3072       if (old_reduce != nullptr) {
3073         new_reduce = computation_->AddInstruction(HloInstruction::CreateMap(
3074             reduce_result_shape, {old_reduce, new_reduce}, function));
3075       }
3076       old_reduce = new_reduce;
3077     }
3078     return ReplaceInstruction(reduce, old_reduce);
3079   }
3080   return Status::OK();
3081 }
3082 
HandleReduceWindow(HloInstruction * reduce_window)3083 Status AlgebraicSimplifierVisitor::HandleReduceWindow(
3084     HloInstruction* reduce_window) {
3085   if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
3086     return ReplaceWithNewInstruction(
3087         reduce_window,
3088         HloInstruction::CreateBroadcast(reduce_window->shape(),
3089                                         reduce_window->mutable_operand(1), {}));
3090   }
3091   auto operand = reduce_window->mutable_operand(0);
3092   const Window& window = reduce_window->window();
3093   auto function = reduce_window->to_apply();
3094   if (ShapeUtil::IsScalar(operand->shape())) {
3095     TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape()));
3096     return ReplaceWithNewInstruction(
3097         reduce_window,
3098         HloInstruction::CreateMap(reduce_window->shape(),
3099                                   {reduce_window->mutable_operand(1), operand},
3100                                   function));
3101   }
3102 
3103   if (options_.enable_window_reduce_to_reduce_replacement()) {
3104     // A reduce window can be expressed as a reduce and a reshape if all
3105     // dimensions either have a window size of one or the entire dimension. If
3106     // there is no stride, dilation, or padding, this is as easy as checking the
3107     // size of the output shape and window dimension.
3108     //
3109     // The reshape is a bitcast since it adds one-sized dimensions. Often these
3110     // ones are immediately removed as well with another reshape. The
3111     // implementation of reduce tends to be slightly more efficient at reducing
3112     // entire dimensions compared to reduce window.
3113     auto effective_reduce_dims = [&] {
3114       if (window_util::HasStride(window) || window_util::HasDilation(window) ||
3115           window_util::HasPadding(window)) {
3116         return absl::InlinedVector<int64, 8>{};
3117       }
3118       absl::InlinedVector<int64, 8> reduce_dims;
3119       for (int64 i = 0; i < window.dimensions_size(); ++i) {
3120         if (window.dimensions(i).size() == 1) {
3121           continue;
3122         } else if (reduce_window->shape().dimensions(i) == 1) {
3123           reduce_dims.push_back(i);
3124         } else {
3125           return absl::InlinedVector<int64, 8>{};
3126         }
3127       }
3128       return reduce_dims;
3129     }();
3130 
3131     // If a reduce window can be expressed as a reduce, do so and reshape the
3132     // output.
3133     if (!effective_reduce_dims.empty()) {
3134       Shape reduce_shape = ShapeUtil::FilterDimensions(
3135           [&](int64 dim) {
3136             return !absl::c_linear_search(effective_reduce_dims, dim);
3137           },
3138           reduce_window->shape());
3139       HloInstruction* reduce =
3140           computation_->AddInstruction(HloInstruction::CreateReduce(
3141               /*shape=*/reduce_shape,
3142               /*operand=*/operand,
3143               /*init_value=*/reduce_window->mutable_operand(1),
3144               /*dimensions_to_reduce=*/effective_reduce_dims,
3145               /*reduce_computation=*/function));
3146       return ReplaceWithNewInstruction(
3147           reduce_window,
3148           HloInstruction::CreateReshape(reduce_window->shape(), reduce));
3149     }
3150   }
3151 
3152   // This optimization folds a pad op into reduce_window.
3153   HloInstruction* pad;
3154   const HloInstruction* convert = nullptr;
3155   if (operand->opcode() == HloOpcode::kPad) {
3156     pad = operand;
3157   } else if (operand->opcode() == HloOpcode::kConvert &&
3158              operand->operand(0)->opcode() == HloOpcode::kPad) {
3159     convert = operand;
3160     pad = operand->mutable_operand(0);
3161   } else {
3162     VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
3163     return Status::OK();
3164   }
3165 
3166   // Bail on dilation.
3167   if (window_util::HasDilation(window)) {
3168     VLOG(10) << "Not folding pad into reduce-window as there is dilation.";
3169     return Status::OK();
3170   }
3171 
3172   VLOG(10) << "Considering folding Pad: " << pad->ToString()
3173            << "\ninto reduce-window: " << reduce_window->ToString()
3174            << (convert != nullptr
3175                    ? absl::StrCat("\nvia convert: ", convert->ToString())
3176                    : "");
3177 
3178   // Do not fold interior padding into ReduceWindow since the backends do not
3179   // support it.
3180   const PaddingConfig& pad_config = pad->padding_config();
3181   if (HasInteriorPadding(pad_config)) {
3182     VLOG(10) << "Not folding pad into reduce-window due to interior padding.";
3183     return Status::OK();
3184   }
3185 
3186   // If reduce_window already has padding, the pad value of the pad op and the
3187   // init value of reduce_window must match to allow folding the pad.
3188   const HloInstruction* pad_value = pad->operand(1);
3189   const HloInstruction* reduce_init_value = reduce_window->operand(1);
3190   if (pad_value != reduce_init_value) {
3191     auto literals_are_equivalent = [&] {
3192       auto& pad_literal = pad_value->literal();
3193       auto& reduce_init_literal = reduce_init_value->literal();
3194       if (pad_literal == reduce_init_literal) {
3195         return true;
3196       }
3197       auto converted_pad_literal =
3198           pad_literal.ConvertToShape(reduce_init_value->shape());
3199       if (!converted_pad_literal.ok()) {
3200         return false;
3201       }
3202       return converted_pad_literal.ValueOrDie() == reduce_init_literal;
3203     };
3204     // The pad value is usually a constant, so we handle that case and do not
3205     // try to get more fancy about proving equivalence in cases beyond that.
3206     if (pad_value->opcode() != HloOpcode::kConstant ||
3207         reduce_init_value->opcode() != HloOpcode::kConstant ||
3208         !literals_are_equivalent()) {
3209       VLOG(10) << "Not folding pad into reduce-window due to different pad "
3210                   "values.";
3211       return Status::OK();
3212     }
3213   }
3214 
3215   // If the pad puts a single non-identity value in each window that we're
3216   // reducing, then this is a broadcast.
3217   HloInstruction* pad_operand = pad->mutable_operand(0);
3218   auto is_effective_broadcast = [&] {
3219     if (window_util::HasStride(window)) {
3220       VLOG(10) << "Window has stride.";
3221       return false;
3222     }
3223     if (!window_util::HasSymmetricPadding(pad_config)) {
3224       VLOG(10) << "Window has uneven padding.";
3225       return false;
3226     }
3227     for (int64 i = 0; i < pad_config.dimensions_size(); ++i) {
3228       const auto& pad_dimension = pad_config.dimensions(i);
3229       if ((pad_dimension.edge_padding_low() != 0 ||
3230            pad_dimension.edge_padding_high() != 0) &&
3231           pad_operand->shape().dimensions(i) != 1) {
3232         VLOG(10) << "Found non-trivial dimension being padded: " << i;
3233         return false;
3234       }
3235     }
3236     VLOG(10) << "Found to be padding trivial dimensions only.";
3237 
3238     for (int64 i = 0; i < window.dimensions_size(); ++i) {
3239       const auto& pad_dimension = pad_config.dimensions(i);
3240       const WindowDimension& window_dimension = window.dimensions(i);
3241       bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
3242                                     pad_dimension.edge_padding_high() != 0);
3243       if (dimension_has_padding &&
3244           window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
3245         VLOG(10) << "Found window did not cover single unpadded element in "
3246                     "dimension: "
3247                  << i;
3248         return false;
3249       }
3250       if (pad_operand->shape().dimensions(i) != 1 &&
3251           window_dimension.size() != 1) {
3252         VLOG(10) << "Found window covers more than one element in non-trivial "
3253                     "dimension: "
3254                  << i;
3255         return false;
3256       }
3257     }
3258     VLOG(10) << "Found window covers a single unpadded element.";
3259     return true;
3260   };
3261 
3262   HloInstruction* new_reduce_window_operand;
3263   if (convert != nullptr) {
3264     new_reduce_window_operand =
3265         computation_->AddInstruction(HloInstruction::CreateConvert(
3266             ShapeUtil::ChangeElementType(pad_operand->shape(),
3267                                          convert->shape().element_type()),
3268             pad_operand));
3269   } else {
3270     new_reduce_window_operand = pad_operand;
3271   }
3272 
3273   if (is_effective_broadcast()) {
3274     VLOG(10) << "Replacing pad/reduce-window with broadcast.";
3275     auto fadd = [this](std::unique_ptr<HloInstruction> x) {
3276       return computation_->AddInstruction(std::move(x));
3277     };
3278     return ReplaceWithNewInstruction(
3279         reduce_window, HloInstruction::CreateBroadcastSequence(
3280                            /*output_shape=*/reduce_window->shape(),
3281                            /*operand=*/new_reduce_window_operand, fadd));
3282   }
3283 
3284   // Carry out the folding of the pad into reduce_window.
3285   VLOG(10) << "Folding pad into reduce-window.";
3286   Window new_window = window;
3287   const int64 rank = reduce_window->shape().rank();
3288   TF_RET_CHECK(pad_config.dimensions_size() == rank);
3289   TF_RET_CHECK(window.dimensions_size() == rank);
3290   for (int64 i = 0; i < rank; ++i) {
3291     const auto& pad_dim = pad_config.dimensions(i);
3292     auto& window_dim = *new_window.mutable_dimensions(i);
3293     window_dim.set_padding_low(window_dim.padding_low() +
3294                                pad_dim.edge_padding_low());
3295     window_dim.set_padding_high(window_dim.padding_high() +
3296                                 pad_dim.edge_padding_high());
3297   }
3298 
3299   return ReplaceWithNewInstruction(
3300       reduce_window, HloInstruction::CreateReduceWindow(
3301                          /*shape=*/reduce_window->shape(),
3302                          /*operand=*/new_reduce_window_operand,
3303                          /*init_value=*/reduce_window->mutable_operand(1),
3304                          /*window=*/new_window,
3305                          /*reduce_computation=*/function));
3306 }
3307 
HandleSelect(HloInstruction * select)3308 Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
3309   // select(x, y, y) -> y.
3310   if (select->operand(1) == select->operand(2)) {
3311     return ReplaceInstruction(select, select->mutable_operand(1));
3312   }
3313   // select(true, x, y) -> x.
3314   if (IsAll(select->operand(0), true)) {
3315     return ReplaceInstruction(select, select->mutable_operand(1));
3316   }
3317   // select(false, x, y) -> y.
3318   if (IsAll(select->operand(0), false)) {
3319     return ReplaceInstruction(select, select->mutable_operand(2));
3320   }
3321   return Status::OK();
3322 }
3323 
HandleSort(HloInstruction * sort)3324 Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
3325   auto operand = sort->mutable_operand(0);
3326   int64 dimension_to_sort = sort->dimensions(0);
3327   if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
3328       operand->shape().dimensions(dimension_to_sort) <= 1) {
3329     if (sort->operand_count() == 1) {
3330       return ReplaceInstruction(sort, operand);
3331     }
3332     // If it is key/value sort, the output of sort is a tuple.
3333     return ReplaceWithNewInstruction(
3334         sort, HloInstruction::CreateTuple(sort->operands()));
3335   }
3336   return Status::OK();
3337 }
3338 
3339 namespace {
OnlyPermutesMoreThanOneDegenerateDim(const Shape & shape,absl::Span<const int64> perm)3340 bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape,
3341                                           absl::Span<const int64> perm) {
3342   std::vector<int64> new_permutation;
3343   int64 degenerate_count = 0;
3344   for (int64 i = 0; i < perm.size(); ++i) {
3345     if (shape.dimensions(i) != 1) {
3346       new_permutation.push_back(perm[i]);
3347     } else {
3348       ++degenerate_count;
3349     }
3350   }
3351   return degenerate_count > 1 && absl::c_is_sorted(new_permutation);
3352 }
3353 }  // namespace
3354 
HandleTranspose(HloInstruction * transpose)3355 Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
3356   auto operand = transpose->mutable_operand(0);
3357   if (std::is_sorted(transpose->dimensions().begin(),
3358                      transpose->dimensions().end())) {
3359     VLOG(10) << "deleting no-op transpose";
3360     return ReplaceInstruction(transpose, operand);
3361   }
3362 
3363   if (HloOpcode::kTranspose == operand->opcode()) {
3364     return ReplaceWithNewInstruction(
3365         transpose, HloInstruction::CreateTranspose(
3366                        transpose->shape(), operand->mutable_operand(0),
3367                        ComposePermutations(operand->dimensions(),
3368                                            transpose->dimensions())));
3369   }
3370 
3371   // Replace transpose with a reshape if more than one degenerate method is
3372   // permuted.
3373   if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(),
3374                                            transpose->dimensions())) {
3375     return ReplaceWithNewInstruction(
3376         transpose, HloInstruction::CreateReshape(
3377                        transpose->shape(), transpose->mutable_operand(0)));
3378   }
3379 
3380   if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
3381     *operand->mutable_shape() = transpose->shape();
3382     return ReplaceInstruction(transpose, operand);
3383   }
3384 
3385   if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) {
3386     ReplaceWithBitcast(transpose);
3387     return Status::OK();
3388   }
3389 
3390   return Status::OK();
3391 }
3392 
FoldConvInputPad(HloInstruction * convolution)3393 StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
3394     HloInstruction* convolution) {
3395   auto* lhs = convolution->mutable_operand(0);
3396   auto* rhs = convolution->mutable_operand(1);
3397   const auto& window = convolution->window();
3398   const ConvolutionDimensionNumbers& dnums =
3399       convolution->convolution_dimension_numbers();
3400 
3401   if (lhs->opcode() != HloOpcode::kPad) {
3402     return false;
3403   }
3404 
3405   // Convolution's padding is always zero, so bail if the kPad is adding
3406   // something other than zero.
3407   if (!IsAll(lhs->operand(1), 0)) {
3408     return false;
3409   }
3410 
3411   const auto& padding = lhs->padding_config();
3412 
3413   // Can't pad batch or feature dims.
3414   for (int64 dim :
3415        {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
3416     const auto& p = padding.dimensions(dim);
3417     if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
3418         p.interior_padding() != 0) {
3419       return false;
3420     }
3421   }
3422 
3423   // Compute the window which is the result of merging the kPad and the
3424   // convolution's existing window.
3425   Window new_window = window;
3426   for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
3427     auto& w = *new_window.mutable_dimensions(dim);
3428     const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
3429     // Edge padding composes with itself in the straightforward way, but
3430     // composing interior padding is nontrivial, and we cowardly refuse to
3431     // think about it. If we see interior padding in either the kPad or conv,
3432     // bail if there's any sort of padding in the other.
3433     if (p.interior_padding() != 0 &&
3434         (w.padding_low() != 0 || w.padding_high() != 0 ||
3435          w.base_dilation() != 1)) {
3436       return false;
3437     }
3438     if (w.base_dilation() != 1 &&
3439         (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
3440          p.interior_padding() != 0)) {
3441       return false;
3442     }
3443 
3444     w.set_padding_low(w.padding_low() + p.edge_padding_low());
3445     w.set_padding_high(w.padding_high() + p.edge_padding_high());
3446     if (p.interior_padding() != 0) {
3447       CHECK_EQ(w.base_dilation(), 1);
3448       w.set_base_dilation(1 + p.interior_padding());
3449     }
3450   }
3451 
3452   auto new_conv = convolution->CloneWithNewOperands(
3453       convolution->shape(), {lhs->mutable_operand(0), rhs});
3454   new_conv->set_window(new_window);
3455   TF_RETURN_IF_ERROR(
3456       ReplaceWithNewInstruction(convolution, std::move(new_conv)));
3457   return true;
3458 }
3459 
FoldConvFilterPad(HloInstruction * convolution)3460 StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
3461     HloInstruction* convolution) {
3462   auto* lhs = convolution->mutable_operand(0);
3463   auto* rhs = convolution->mutable_operand(1);
3464   const ConvolutionDimensionNumbers& dnums =
3465       convolution->convolution_dimension_numbers();
3466 
3467   if (rhs->opcode() != HloOpcode::kPad) {
3468     return false;
3469   }
3470 
3471   // Convolution's padding is always zero, so bail if the kPad is adding
3472   // something other than zero.
3473   if (!IsAll(rhs->operand(1), 0)) {
3474     return false;
3475   }
3476 
3477   const auto& padding = rhs->padding_config();
3478 
3479   // Can't pad or dilate feature dims.
3480   for (int64 dim : {dnums.kernel_input_feature_dimension(),
3481                     dnums.kernel_output_feature_dimension()}) {
3482     const auto& p = padding.dimensions(dim);
3483     if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
3484         p.interior_padding() != 0) {
3485       return false;
3486     }
3487   }
3488 
3489   // Compute the window which is the result of merging the kPad and the
3490   // convolution's existing window.
3491   Window new_window = convolution->window();
3492   for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
3493     auto& w = *new_window.mutable_dimensions(dim);
3494     const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
3495 
3496     // We can only do this transformation if p adds dilation to the filter --
3497     // edge padding on the filter is not supported in conv.
3498     if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
3499       return false;
3500     }
3501 
3502     // Nothing to do if the kPad for this dim is entirely a nop.
3503     if (p.interior_padding() == 0) {
3504       continue;
3505     }
3506 
3507     // We cowardly refuse to think about how dilation composes with itself;
3508     // bail if both the kPad and conv have dilation on this dimension.
3509     if (w.window_dilation() > 1) {
3510       return false;
3511     }
3512     CHECK_EQ(w.window_dilation(), 1);
3513     w.set_window_dilation(1 + p.interior_padding());
3514     w.set_size(rhs->operand(0)->shape().dimensions(
3515         dnums.kernel_spatial_dimensions(dim)));
3516   }
3517 
3518   auto new_conv = convolution->CloneWithNewOperands(
3519       convolution->shape(), {lhs, rhs->mutable_operand(0)});
3520   new_conv->set_window(new_window);
3521   TF_RETURN_IF_ERROR(
3522       ReplaceWithNewInstruction(convolution, std::move(new_conv)));
3523   return true;
3524 }
3525 
SimplifyConvToDot(HloInstruction * convolution)3526 StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
3527     HloInstruction* convolution) {
3528   auto* lhs = convolution->mutable_operand(0);
3529   auto* rhs = convolution->mutable_operand(1);
3530   const auto& window = convolution->window();
3531   const ConvolutionDimensionNumbers& dnums =
3532       convolution->convolution_dimension_numbers();
3533 
3534   if (!options_.enable_conv_simplification()) {
3535     return false;
3536   }
3537 
3538   // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
3539   // layout-insensitive mode, for fear of adding nontrivial reshapes.
3540   if (!options_.is_layout_sensitive()) {
3541     return false;
3542   }
3543 
3544   const Shape& input_shape = lhs->shape();
3545   const Shape& filter_shape = rhs->shape();
3546   const Shape& convolution_shape = convolution->shape();
3547   TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
3548   TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
3549   TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));
3550 
3551   // Require the spatial dimensions in the kernel to have a bound of one.
3552   for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
3553     if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
3554       return false;
3555     }
3556   }
3557 
3558   // Stride ignores part of the output, which matrix multiplication does not do,
3559   // so require no stride. Padding and base (lhs) dilation both implicitly
3560   // extend the data, which matrix multiplication also does not do, so require
3561   // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
3562   // for a 1x1 window, so window dilation is no problem.
3563   if (window_util::HasStride(window) || window_util::HasPadding(window) ||
3564       window_util::HasBaseDilation(window)) {
3565     return false;
3566   }
3567 
3568   // Also, the shapes must align for a rowmajor matmul:
3569   // - the input and output have the same layout.
3570   // - for input/output, the channel dimension must be the most minor. Other
3571   //   spatial dims can be in any order.
3572   // - for filters, the input channel dimension must be more major than the
3573   //   output channel dimension. The width+height don't matter because
3574   //   they are 1.
3575   //
3576   // These constraints are harsh. If the channel dimension is the most major
3577   // and/or the layout of input/output feature dimensions are reversed, we can
3578   // still convert Conv into more efficient Matmul with operand transposition
3579   // (such as the transposition flags in cuBLAS SGEMM).
3580   if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
3581       LayoutUtil::Minor(input_shape.layout(), 0) !=
3582           dnums.input_feature_dimension() ||
3583       LayoutUtil::Minor(convolution_shape.layout(), 0) !=
3584           dnums.output_feature_dimension() ||
3585       // The input feature dimension should come later in the minor-to-major
3586       // order.
3587       (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
3588                            dnums.kernel_input_feature_dimension()) <
3589        PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
3590                            dnums.kernel_output_feature_dimension()))) {
3591     return false;
3592   }
3593 
3594   auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
3595     std::vector<int64> dims(operand->shape().dimensions_size());
3596     std::iota(dims.begin(), dims.end(), 0);
3597     return computation_->AddInstruction(
3598         HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand));
3599   };
3600 
3601   // Replace it with a dot, with bitcasts around it to get the right shape.
3602   const int64 input_channels =
3603       input_shape.dimensions(dnums.input_feature_dimension());
3604   const int64 output_channels =
3605       filter_shape.dimensions(dnums.kernel_output_feature_dimension());
3606 
3607   // Computes the product of the non-feature dimensions.
3608   int64 conv_width = 1;
3609   for (int i = 0; i < input_shape.dimensions_size(); ++i) {
3610     if (i != dnums.input_feature_dimension()) {
3611       conv_width *= input_shape.dimensions(i);
3612     }
3613   }
3614 
3615   // We already checked feature_dimension is most minor, so data in input_shape
3616   // and row-major {conv_width,input_channels} are bitwise identical.
3617   const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
3618       input_shape.element_type(), {conv_width, input_channels});
3619   // We already checked input_feature_dimension is more major than
3620   // output_feature_dimension, so data in filter_shape and row-major
3621   // {input_channels,output_channels} are bitwise identical.
3622   const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
3623       filter_shape.element_type(), {input_channels, output_channels});
3624   const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
3625       convolution_shape.element_type(), {conv_width, output_channels});
3626 
3627   auto new_lhs = add_bitcast(new_input_shape, lhs);
3628   auto new_rhs = add_bitcast(new_filter_shape, rhs);
3629   DotDimensionNumbers dot_dimension_numbers;
3630   dot_dimension_numbers.add_lhs_contracting_dimensions(1);
3631   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
3632   auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
3633       dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
3634       convolution->precision_config()));
3635 
3636   TF_RETURN_IF_ERROR(
3637       ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
3638   return true;
3639 }
3640 
HandleConvolution(HloInstruction * convolution)3641 Status AlgebraicSimplifierVisitor::HandleConvolution(
3642     HloInstruction* convolution) {
3643   // Zero-sized input or filter.
3644   if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
3645       ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
3646     return ReplaceWithNewInstruction(
3647         convolution,
3648         HloInstruction::CreateBroadcast(
3649             convolution->shape(),
3650             computation_->AddInstruction(HloInstruction::CreateConstant(
3651                 LiteralUtil::Zero(convolution->shape().element_type()))),
3652             {}));
3653   }
3654 
3655   // Try to merge padding/dilation of the input with the convolution's window.
3656   TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
3657   if (folded_input_pad) {
3658     return Status::OK();
3659   }
3660 
3661   // Try to merge dilation of the filter with the convolution's window.
3662   TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
3663   if (folded_filter_pad) {
3664     return Status::OK();
3665   }
3666 
3667   // Try to replace the convolution with a kDot instruction.
3668   TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
3669   if (replaced_with_dot) {
3670     return Status::OK();
3671   }
3672 
3673   return Status::OK();
3674 }
3675 
TransformToClampIfSameShape(HloInstruction * root,HloInstruction * min,HloInstruction * min_operand,HloInstruction * operand,HloInstruction * max,HloInstruction * max_operand)3676 bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
3677     HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
3678     HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) {
3679   // Ensure shapes of min and max operand are equal to match current shape
3680   // inference.
3681   if (!SameShape(min_operand, max_operand)) {
3682     return false;
3683   }
3684 
3685   auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
3686                                              max_operand, operand, min_operand);
3687   TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp)));
3688   return true;
3689 }
3690 
HandleMap(HloInstruction * map)3691 Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
3692   auto* map_computation = map->to_apply();
3693   auto* map_root = map_computation->root_instruction();
3694   if (map_root->opcode() == HloOpcode::kParameter) {
3695     ReplaceInstructionIfSameShape(
3696         map, map->mutable_operand(map_root->parameter_number()));
3697     return Status::OK();
3698   }
3699   if (map_root->opcode() == HloOpcode::kConstant) {
3700     if (!ShapeUtil::IsScalar(map_root->shape())) {
3701       return Status::OK();
3702     }
3703     auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
3704     if (ShapeUtil::IsScalar(map->shape())) {
3705       return ReplaceWithNewInstruction(map, std::move(clone));
3706     }
3707     return ReplaceWithNewInstruction(
3708         map,
3709         HloInstruction::CreateBroadcast(
3710             map->shape(), computation_->AddInstruction(std::move(clone)), {}));
3711   }
3712   // Inline the map if the map computation only contains an elementwise
3713   // operation that can accept arbitrary shapes.
3714   if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) {
3715     return Status::OK();
3716   }
3717   std::vector<HloInstruction*> new_operands;
3718   for (auto* root_operand : map_root->operands()) {
3719     if (root_operand->opcode() != HloOpcode::kParameter) {
3720       return Status::OK();
3721     }
3722     new_operands.push_back(
3723         map->mutable_operand(root_operand->parameter_number()));
3724   }
3725   auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
3726   return ReplaceWithNewInstruction(map, std::move(clone));
3727 }
3728 
Run(HloModule * module)3729 StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
3730   XLA_VLOG_LINES(2,
3731                  "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
3732   bool changed = false;
3733   for (auto* comp : module->MakeNonfusionComputations()) {
3734     if (AlgebraicSimplifierVisitor::Run(comp, options_)) {
3735       changed = true;
3736     }
3737   }
3738   XLA_VLOG_LINES(2,
3739                  "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
3740   return changed;
3741 }
3742 
3743 }  // namespace xla
3744