1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
17 
18 #include <numeric>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
24 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "mlir/IR/Matchers.h"  // from @llvm-project
27 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
28 #include "mlir/IR/TypeRange.h"  // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
35 #include "tensorflow/core/util/tensor_format.h"
36 
37 namespace mlir {
38 namespace TF {
39 namespace {
40 
41 // Returns 1D 64-bit dense elements attribute with the given values.
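// For example (illustrative), GetI64ElementsAttr({2, 3}, builder) produces
// dense<[2, 3]> : tensor<2xi64>.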
42 static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
43                                                Builder *builder) {
44   RankedTensorType ty = RankedTensorType::get(
45       {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
46   return DenseIntElementsAttr::get(ty, values);
47 }
48 
49 // Returns a 1-d i64 elements attribute populated with numbers from start
50 // (inclusive) to end (exclusive).
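// For example (illustrative), GetI64ElementsAttrForSeq(1, 4, builder) produces
// dense<[1, 2, 3]> : tensor<3xi64>.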
51 static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
52                                                      Builder *builder) {
53   int size = end - start;
54 
55   SmallVector<int64_t, 4> vals;
56   vals.resize(size);
57   std::iota(vals.begin(), vals.end(), start);
58 
59   TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
60   return DenseIntElementsAttr::get(ty, vals);
61 }
62 
63 static APFloat ConvertToAPFloat(double val, Type type) {
64   if (type.getIntOrFloatBitWidth() == 32) {
65     return APFloat(static_cast<float>(val));
66   }
67 
68   return APFloat(val);
69 }
70 
71 // Returns an int, float, or complex DenseElementsAttr of scalar shape with
72 // the given element type and value.
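// For example (illustrative), GetScalarOfType(builder.getF32Type(), 0.5)
// produces dense<0.5> : tensor<f32>.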
73 template <typename T>
74 static DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
75   RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
76   if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
77     FloatAttr attr = FloatAttr::get(float_ty, raw_value);
78     return DenseElementsAttr::get(scalar_ty, attr);
79   } else if (auto int_ty = ty.dyn_cast_or_null<IntegerType>()) {
80     IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
81     return DenseElementsAttr::get(scalar_ty, attr);
82   } else if (auto complex_ty = ty.dyn_cast_or_null<ComplexType>()) {
83     Type complex_element_ty = complex_ty.getElementType();
84     if (complex_element_ty.isF32()) {
85       return DenseElementsAttr::get(
86           scalar_ty, static_cast<std::complex<float>>(raw_value));
87     } else if (complex_element_ty.isF64()) {
88       return DenseElementsAttr::get(
89           scalar_ty, static_cast<std::complex<double>>(raw_value));
90     }
91   }
92   llvm_unreachable("unsupported type");
93 }
94 
95 // Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
96 // op.
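// For example (illustrative), for a rank 4 "NHWC" gradient the feature
// dimension is 3, so the reduction indices are dense<[0, 1, 2]> : tensor<3xi64>.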
97 DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
98                                                     StringAttr data_format,
99                                                     Builder *builder) {
100   tensorflow::TensorFormat format;
101   if (!FormatFromString(data_format.getValue().str(), &format)) return {};
102 
103   // Reduce along all dimensions except the feature dimension.
104   int64_t feature_dim = GetTensorFeatureDimIndex(rank, format);
105   llvm::SmallVector<int64_t, 4> dims_to_reduce(rank - 1);
106   std::iota(dims_to_reduce.begin(), dims_to_reduce.begin() + feature_dim, 0);
107   std::iota(dims_to_reduce.begin() + feature_dim, dims_to_reduce.end(),
108             feature_dim + 1);
109   return GetI64ElementsAttr(dims_to_reduce, builder);
110 }
111 
112 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_lower_tf.inc"
113 
114 // Infers ExpandDims op output type for the given input type `ty` and dimension
115 // to expand at the given `axis`.
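// For example (illustrative), ty = tensor<2x3xf32> with axis = -1 yields
// tensor<2x3x1xf32>, since negative axes count from the end of the expanded
// shape.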
116 Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
117   auto ranked_ty = ty.dyn_cast<RankedTensorType>();
118 
119   // Unranked type.
120   if (!ranked_ty) return ty;
121 
122   auto shape = llvm::to_vector<4>(ranked_ty.getShape());
123   if (axis < 0) axis += ranked_ty.getRank() + 1;
124 
125   shape.insert(shape.begin() + axis, 1);
126   return RankedTensorType::get(shape, ranked_ty.getElementType());
127 }
128 
129 // Converts individual Values to a tensor of rank 1. Each input Value has rank 1
130 // and size 1.
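// For example (illustrative), three tensor<1xi64> values are concatenated
// along axis 0 into a single tensor<3xi64> value.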
131 Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype,
132                     ArrayRef<Value> vals) {
133   int64_t length = vals.size();
134   auto type = RankedTensorType::get({length}, dtype);
135   auto axis = rewriter.create<ConstOp>(
136       loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
137   return rewriter.create<ConcatV2Op>(loc, type, ValueRange(vals), axis);
138 }
139 
140 // Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
141 //
142 // Note that to improve parallelism, the AddN lowering uses tree-based reduction.
143 // For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows:
144 //
145 //                 0     1     2     3     4
146 //                 |     |     |     |     |
147 //                 -------     -------     |
148 //                    |           |        |
149 //                    5           6        |
150 //                    |           |        |
151 //                    -------------        |
152 //                          |              |
153 //                          7              |
154 //                          |              |
155 //                          ----------------
156 //                                 |
157 //                                 8
158 //
159 // Example:
160 //
161 //   %result = "tf.AddN"(%0, %1, %2)
162 //
163 // is lowered to:
164 //
165 //   %sum0 = "tf.AddV2"(%0, %1)
166 //   %result = "tf.AddV2"(%sum0, %2)
167 //
168 // While
169 //
170 //   %result = "tf.AddN"(%0, %1, %2, %3, %4)
171 //
172 // is lowered to:
173 //
174 //   %sum0 = "tf.AddV2"(%0, %1)
175 //   %sum1 = "tf.AddV2"(%2, %3)
176 //   %sum2 = "tf.AddV2"(%sum0, %sum1)
177 //   %result = "tf.AddV2"(%sum2, %4)
178 //
179 class LowerAddNOp : public RewritePattern {
180  public:
181   explicit LowerAddNOp(MLIRContext *context)
182       : RewritePattern(AddNOp::getOperationName(),
183                        {AddV2Op::getOperationName()}, 1, context) {}
184 
185   LogicalResult matchAndRewrite(Operation *op,
186                                 PatternRewriter &rewriter) const override {
187     auto addn_op = cast<AddNOp>(op);
188 
189     // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't
190     // support variant type so variant types require special handling.
191     if (getElementTypeOrSelf(addn_op.getType()).isa<VariantType>())
192       return failure();
193     llvm::SmallVector<Value, 4> operands(addn_op.inputs().begin(),
194                                          addn_op.inputs().end());
195 
196     int64_t n = operands.size();
197     // Keep doing tree-based reduction while there is more than one operand.
198     while (n > 1) {
199       for (int64_t i = 0; i < n; i += 2) {
200         // Add two adjacent operands if applicable.
201         operands[i / 2] =
202             (i + 1 < n) ? rewriter.create<AddV2Op>(addn_op.getLoc(),
203                                                    operands[i], operands[i + 1])
204                         : operands[i];
205       }
206       n = (n + 1) / 2;
207     }
208 
209     rewriter.replaceOp(addn_op, operands[0]);
210     return success();
211   }
212 };
213 
214 // Lowers DynamicStitch op with constant indices and with static input and
215 // output shapes using Reshape, Unpack and Pack ops.
216 //
217 //   %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
218 //   %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]>
219 //     : tensor<2x2xi32>}
220 //   %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1)
221 //     : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>)
222 //     -> tensor<5x2xf32>
223 //
224 // is lowered to
225 //
226 //   %shape = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>}
227 //   %inp0 = "tf.Reshape"(%arg0, %shape)
228 //     : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
229 //   %inp1 = "tf.Reshape"(%arg1, %shape)
230 //     : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
231 //   %items0 = "tf.Unpack"(%inp0) {axis = 0 : i64}
232 //     : (tensor<1x2xf32>) -> tensor<2xf32>
233 //   %items1:4 = "tf.Unpack"(%inp1) {axis = 0 : i64}
234 //     : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
235 //     tensor<2xf32>)
236 //   %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
237 //   %0 = "tf.Pack"(%items1#3, %items1#2, %items1#1, %items1#0, %items0, %axis)
238 //     : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
239 //        tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
240 //
241 template <typename OpT>
242 class LowerDynamicStitchOp : public RewritePattern {
243  public:
244   explicit LowerDynamicStitchOp(MLIRContext *context)
245       : RewritePattern(
246             OpT::getOperationName(),
247             {ConstOp::getOperationName(), ReshapeOp::getOperationName(),
248              UnpackOp::getOperationName(), PackOp::getOperationName()},
249             1, context) {}
250 
251   LogicalResult matchAndRewrite(Operation *src_op,
252                                 PatternRewriter &rewriter) const override {
253     auto op = cast<OpT>(src_op);
254 
255     // Static output type is used to compute intermediate values. Note that the
256     // output type doesn't have to be static but if input types and indices are
257     // constant, then the output type can be statically determined.
258     RankedTensorType out_ty =
259         op.getType().template dyn_cast<RankedTensorType>();
260     if (!out_ty || !out_ty.hasStaticShape()) return failure();
261 
262     // Extract all the constant indices' attributes and verify that the data
263     // operands have static shapes.
264     SmallVector<DenseIntElementsAttr, 4> indices;
265     indices.reserve(op.N());
266     for (auto it : llvm::zip(op.indices(), op.data())) {
267       Value index = std::get<0>(it);
268       Value data = std::get<1>(it);
269 
270       DenseIntElementsAttr index_attr;
271       if (!matchPattern(index, m_Constant(&index_attr))) return failure();
272       indices.push_back(index_attr);
273 
274       RankedTensorType data_ty =
275           data.getType().template dyn_cast<RankedTensorType>();
276       if (!data_ty || !data_ty.hasStaticShape()) return failure();
277     }
278 
279     // Compute the type of each item and the shape to use when reshaping the
280     // inputs so that they can be unpacked into individual items.
281     ArrayRef<int64_t> item_shape = out_ty.getShape().drop_front(1);
282     auto item_ty = RankedTensorType::get(item_shape, out_ty.getElementType());
283 
284     SmallVector<int64_t, 4> packed_shape;
285     packed_shape.push_back(-1);
286     packed_shape.append(item_shape.begin(), item_shape.end());
287     Location loc = op.getLoc();
288     auto packed_shape_val = rewriter.create<ConstOp>(
289         loc, GetI64ElementsAttr(packed_shape, &rewriter));
290 
291     // Prepare each output item by unpacking the data and then placing it at
292     // the specified index.
293     SmallVector<Value, 8> values(out_ty.getDimSize(0));
294     for (auto it : llvm::zip(indices, op.data())) {
295       DenseIntElementsAttr index_attr = std::get<0>(it);
296       Value data = std::get<1>(it);
297 
298       auto reshaped_data =
299           rewriter.create<ReshapeOp>(loc, data, packed_shape_val);
300       auto num_items = reshaped_data.getType()
301                            .template cast<RankedTensorType>()
302                            .getShape()[0];
303       auto items = rewriter.create<UnpackOp>(
304           loc, SmallVector<Type, 4>(num_items, item_ty), reshaped_data,
305           /*axis=*/0);
306       for (auto index_item : llvm::zip(index_attr, items.getResults())) {
307         int64_t output_index = std::get<0>(index_item).getSExtValue();
308         Value item = std::get<1>(index_item);
309         values[output_index] = item;
310       }
311     }
312 
313     rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values);
314     return success();
315   }
316 };
317 
318 // This pattern performs a manual conversion of FakeQuantWithMinMaxVars,
319 // converting between floating-point and quantized space. It is designed to
320 // reproduce TF's implementation, mirroring the previous XLA implementation:
321 //
322 // 1. Compute proper quantized bounds. This involves nudging the input bounds.
323 // 2. Convert the clamped input values to quantized space, rounding values.
324 // 3. Convert back into floating-point space.
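//
// For example (illustrative numbers), with num_bits = 8, narrow_range = false,
// min = -1.0 and max = 1.0: the quantized range is [0, 255], the scale is
// 2/255, and the zero point nudges to 128, so the effective float range
// becomes roughly [-1.0039, 0.9961] and 0.0 maps exactly onto an integer.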
325 class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
326  public:
327   explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
328       : RewritePattern(
329             FakeQuantWithMinMaxVarsOp::getOperationName(),
330             {AddV2Op::getOperationName(), SubOp::getOperationName(),
331              ConstOp::getOperationName(), MulOp::getOperationName(),
332              FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
333              DivOp::getOperationName(), RoundOp::getOperationName()},
334             1, context) {}
335 
336   LogicalResult matchAndRewrite(Operation *src_op,
337                                 PatternRewriter &rewriter) const override {
338     auto op = cast<FakeQuantWithMinMaxVarsOp>(src_op);
339 
340     auto input = op.inputs();
341     auto input_ty = input.getType().cast<ShapedType>();
342     auto element_ty = input_ty.getElementType();
343     auto scalar_ty = RankedTensorType::get({}, element_ty);
344 
345     auto num_bits = op.num_bits();
346     auto narrow_range = op.narrow_range();
347     const double bits_min = narrow_range ? 1 : 0;
348     const double bits_max = (1 << num_bits) - 1;
349 
350     auto float_min = op.min();
351     auto float_max = op.max();
352 
353     auto float_diff = rewriter.create<SubOp>(op.getLoc(), float_max, float_min);
354 
355     // Compute the range when quantized.
356     auto quant_min = rewriter.create<ConstOp>(
357         op.getLoc(), DenseElementsAttr::get(
358                          scalar_ty, ConvertToAPFloat(bits_min, element_ty)));
359 
360     auto quant_max = rewriter.create<ConstOp>(
361         op.getLoc(), DenseElementsAttr::get(
362                          scalar_ty, ConvertToAPFloat(bits_max, element_ty)));
363 
364     auto quant_diff = rewriter.create<ConstOp>(
365         op.getLoc(),
366         DenseElementsAttr::get(
367             scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty)));
368 
369     auto quant_to_float =
370         rewriter.create<DivOp>(op.getLoc(), float_diff, quant_diff);
371 
372     auto float_to_quant =
373         rewriter.create<DivOp>(op.getLoc(), quant_diff, float_diff);
374 
375     // During quantization, the quantized min/max values may not line up
376     // perfectly with the specified min/max. Nudge them into the right range.
377     auto min_scaled =
378         rewriter.create<DivOp>(op.getLoc(), float_min, quant_to_float);
379     auto min_scaled_sub =
380         rewriter.create<SubOp>(op.getLoc(), quant_min, min_scaled);
381 
382     auto mid_rounded =
383         rewriter.create<RoundOp>(op.getLoc(), scalar_ty, min_scaled_sub);
384 
385     auto nudged_zero_point_val = rewriter.create<ClipByValueOp>(
386         op.getLoc(), scalar_ty, mid_rounded, quant_min, quant_max);
387 
388     auto quant_min_sub =
389         rewriter.create<SubOp>(op.getLoc(), quant_min, nudged_zero_point_val);
390     auto quant_max_sub =
391         rewriter.create<SubOp>(op.getLoc(), quant_max, nudged_zero_point_val);
392 
393     auto nudged_float_min =
394         rewriter.create<MulOp>(op.getLoc(), quant_min_sub, quant_to_float);
395 
396     auto nudged_float_max =
397         rewriter.create<MulOp>(op.getLoc(), quant_max_sub, quant_to_float);
398 
399     // Now quantize the input value with the approximated min/max values.
400 
401     // Move the input value into quantized space
402     Value quantized_input = rewriter.create<ClipByValueOp>(
403         op.getLoc(), input_ty, input, nudged_float_min, nudged_float_max);
404 
405     quantized_input = rewriter.create<SubOp>(op.getLoc(), input_ty,
406                                              quantized_input, nudged_float_min);
407 
408     quantized_input = rewriter.create<MulOp>(op.getLoc(), input_ty,
409                                              quantized_input, float_to_quant);
410 
411     // Always round the quantized input in the positive direction (add 0.5, floor).
412     auto half_val = rewriter.create<ConstOp>(
413         op.getLoc(),
414         DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
415 
416     quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
417                                                quantized_input, half_val);
418 
419     quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
420 
421     // Convert back into floating-point space.
422     Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
423                                           quantized_input, quant_to_float);
424 
425     output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
426                                       nudged_float_min);
427 
428     rewriter.replaceOp(op, {output});
429     return success();
430   }
431 };
432 
433 // Lowers InvertPermutation op to TensorScatterUpdate op.
434 //
435 // Example:
436 //
437 //   %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
438 //   "tf.InvertPermutation"(%x) : (tensor<5xi32>) -> tensor<5xi32>
439 //
440 // is lowered to
441 //
442 //   %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
443 //   %start = "tf.Const"() {value = dense<0> : tensor<i32>}
444 //   %limit = "tf.Const"() {value = dense<5> : tensor<i32>}
445 //   %delta = "tf.Const"() {value = dense<1> : tensor<i32>}
446 //   %updates = "tf.Range"(%start, %limit, %delta) :
447 //     (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xi32>
448 //   %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>}
449 //   %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32>, tensor<2xi32>) ->
450 //     tensor<5x1xi32>
451 //   "tf.TensorScatterUpdate"(%x, %indices, %updates) :
452 //     (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
453 //
454 class LowerInvertPermutationOp : public RewritePattern {
455  public:
456   explicit LowerInvertPermutationOp(MLIRContext *context)
457       : RewritePattern(
458             InvertPermutationOp::getOperationName(),
459             {ConstOp::getOperationName(), RangeOp::getOperationName(),
460              ReshapeOp::getOperationName(),
461              TensorScatterUpdateOp::getOperationName()},
462             1, context) {}
463 
464   LogicalResult matchAndRewrite(Operation *src_op,
465                                 PatternRewriter &rewriter) const override {
466     auto op = cast<InvertPermutationOp>(src_op);
467 
468     Location loc = op.getLoc();
469     auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
470     // x input must have static shape.
471     if (!x_type || !x_type.hasStaticShape()) {
472       return failure();
473     }
474     Type int_type = x_type.getElementType();  // Could be i32 or i64.
475 
476     auto result_type = x_type;
477     auto start = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 0));
478     Value limit = rewriter.create<ConstOp>(
479         loc, GetScalarOfType(int_type, x_type.getShape()[0]));
480     auto delta = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 1));
481     // Construct a sequence of numbers [0, 1, ... len(x)-1].
482     auto updates =
483         rewriter.create<RangeOp>(loc, result_type, start, limit, delta);
484 
485     auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32));
486     auto shape = rewriter.create<ConstOp>(
487         loc, DenseElementsAttr::get(
488                  shape_type, {static_cast<int>(x_type.getDimSize(0)), 1}));
489     auto indices = rewriter.create<ReshapeOp>(loc, op.x(), shape);
490 
491     rewriter.replaceOpWithNewOp<TensorScatterUpdateOp>(op, result_type, op.x(),
492                                                        indices, updates);
493     return success();
494   }
495 };
496 
497 // Approximates lgamma using Lanczos' approximation from
498 // "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical
499 // Analysis series B. Vol. 1:
500 // lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z)
501 //                 + log(A(z))
502 // t(z) = z + kLanczosGamma + 1/2
503 // A(z) = kBaseLanczosCoeff + sum(k = 1..n, kLanczosCoefficients[k - 1] / (z + k))
504 //
505 // Coefficients for the Lanczos approximation of the gamma function. The
506 // coefficients are uniquely determined by the choice of g and n
507 // (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below
508 // correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were
509 // evaluated and [7, 9] seemed to be the least sensitive to the quality of the
510 // log function. In particular, [5, 7] is the only choice where -1.5e-5 <=
511 // lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
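// (As a sanity check, note that lgamma(2) = log(1!) = 0 exactly, which is why
// it is a convenient probe for the accuracy of the approximation.)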
512 static constexpr double kLanczosGamma = 7;  // aka g
513 static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
514 static constexpr std::array<double, 8> kLanczosCoefficients = {
515     676.520368121885098567009190444019, -1259.13921672240287047156078755283,
516     771.3234287776530788486528258894,   -176.61502916214059906584551354,
517     12.507343278686904814458936853,     -0.13857109526572011689554707,
518     9.984369578019570859563e-6,         1.50563273514931155834e-7};
519 
520 class LowerLgammaOp : public RewritePattern {
521  public:
522   explicit LowerLgammaOp(MLIRContext *context)
523       : RewritePattern(LgammaOp::getOperationName(),
524                        {
525                            CastOp::getOperationName(),
526                            ConstOp::getOperationName(),
527                            NegOp::getOperationName(),
528                            SubOp::getOperationName(),
529                            SelectV2Op::getOperationName(),
530                            LessOp::getOperationName(),
531                            AddV2Op::getOperationName(),
532                            DivOp::getOperationName(),
533                            SubOp::getOperationName(),
534                            LogOp::getOperationName(),
535                            Log1pOp::getOperationName(),
536                            IsInfOp::getOperationName(),
537                            MulOp::getOperationName(),
538                            FloorOp::getOperationName(),
539                            AbsOp::getOperationName(),
540                            GreaterOp::getOperationName(),
541                            SinOp::getOperationName(),
542                            IsFiniteOp::getOperationName(),
543                        },
544                        1, context) {}
545 
546   LogicalResult matchAndRewrite(Operation *src_op,
547                                 PatternRewriter &rewriter) const override {
548     auto op = cast<LgammaOp>(src_op);
549 
550     Location loc = op.getLoc();
551     Value input = op.x();
552     TensorType original_tensor_type = op.x().getType().cast<TensorType>();
553 
554     // The approximation is not precise enough for float16. Do the computation
555     // in float32 for that case.
556     TensorType tensor_type = original_tensor_type;
557     FloatType float_type = tensor_type.getElementType().cast<FloatType>();
558     bool needs_cast = float_type.getWidth() < 32;
559     if (needs_cast) {
560       MLIRContext *context = rewriter.getContext();
561       float_type = FloatType::getF32(context);
562       if (original_tensor_type.hasRank()) {
563         tensor_type =
564             RankedTensorType::get(original_tensor_type.getShape(), float_type);
565       } else {
566         tensor_type = UnrankedTensorType::get(float_type);
567       }
568       input = rewriter.create<CastOp>(loc, tensor_type, input);
569     }
570 
571     // Helper lambda function for creating a ConstOp for a tensor filled with
572     // the given constant float value.
573     auto create_const_op = [&rewriter, loc, tensor_type,
574                             float_type](double value) {
575       return rewriter.create<ConstOp>(
576           loc, DenseElementsAttr::get(tensor_type,
577                                       FloatAttr::get(float_type, value)));
578     };
579 
580     Value one_half = create_const_op(0.5);
581     Value one = create_const_op(1.0);
582     Value infinity = create_const_op(std::numeric_limits<double>::infinity());
583     Value pi = create_const_op(M_PI);
584     Value log_pi = create_const_op(std::log(M_PI));
585     Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2);
586     Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5);
587     Value log_lanczos_gamma_plus_one_half =
588         create_const_op(std::log(kLanczosGamma + 0.5));
589     Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff);
590 
591     Value minus_input = rewriter.create<NegOp>(loc, input);
592     Value input_minus_one = rewriter.create<SubOp>(loc, input, one);
593 
594     // If the input is less than 0.5, use Euler's reflection formula:
595     // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
596     Value need_to_reflect = rewriter.create<LessOp>(loc, input, one_half);
597     Type tensor_bool_type = need_to_reflect.getType();
598     Value z = rewriter.create<SelectV2Op>(loc, need_to_reflect, minus_input,
599                                           input_minus_one);
600 
601     Value x = base_lanczos_coeff;
602     for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
603       Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]);
604       Value index = create_const_op(static_cast<double>(i));
605       Value z_plus_index = rewriter.create<AddV2Op>(loc, z, index);
606       Value z_plus_index_plus_one =
607           rewriter.create<AddV2Op>(loc, z_plus_index, one);
608       Value incr = rewriter.create<DivOp>(loc, lanczos_coefficient,
609                                           z_plus_index_plus_one);
610       x = rewriter.create<AddV2Op>(loc, x, incr);
611     }
612 
613     // To improve accuracy on platforms with less-precise log implementations,
614     // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
615     // the device.
616     // log(t) = log(kLanczosGamma + 0.5 + z)
617     //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
618     Value t = rewriter.create<AddV2Op>(loc, lanczos_gamma_plus_one_half, z);
619     Value z_div_lanczos_gamma_plus_one_half =
620         rewriter.create<DivOp>(loc, z, lanczos_gamma_plus_one_half);
621     Value log1p_z_div_lanczos_gamma_plus_one_half =
622         rewriter.create<Log1pOp>(loc, z_div_lanczos_gamma_plus_one_half);
623     Value log_t =
624         rewriter.create<AddV2Op>(loc, log_lanczos_gamma_plus_one_half,
625                                  log1p_z_div_lanczos_gamma_plus_one_half);
626 
627     // Compute the final result (modulo reflection).  t(z) may be large, and we
628     // need to be careful not to overflow to infinity in the first term of
629     //
630     //   (z + 1/2) * log(t(z)) - t(z).
631     //
632     // Therefore we compute this as
633     //
634     //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
635     //
636     // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
637     Value t_div_log_t = rewriter.create<DivOp>(loc, t, log_t);
638     Value one_half_minus_t_div_log_t =
639         rewriter.create<SubOp>(loc, one_half, t_div_log_t);
640     Value z_plus_one_half_minus_t_div_log_t =
641         rewriter.create<AddV2Op>(loc, z, one_half_minus_t_div_log_t);
642     Value z_plus_one_half_minus_t_div_log_t_mul_log_t =
643         rewriter.create<MulOp>(loc, z_plus_one_half_minus_t_div_log_t, log_t);
644     Value log_x = rewriter.create<LogOp>(loc, x);
645     Value log_y_rhs = rewriter.create<AddV2Op>(
646         loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x);
647     Value log_y = rewriter.create<AddV2Op>(loc, log_sqrt_two_pi, log_y_rhs);
648 
649     // Compute the reflected value, used when x < 0.5:
650     //
651     //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
652     //
653     // (The abs is because lgamma is the log of the absolute value of the gamma
654     // function.)
655     //
656     // We have to be careful when computing the final term above. gamma(x) goes
657     // to +/-inf at every integer x < 0, and this is controlled by the
658     // sin(pi * x) term.  The slope is large, so precision is particularly
659     // important.
660     //
661     // Because abs(sin(pi * x)) has period 1, we can equivalently use
662     // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x.  This
663     // is more numerically accurate: It doesn't overflow to inf like pi * x can,
664     // and if x is an integer, it evaluates to 0 exactly, which is significant
665     // because we then take the log of this value, and log(0) is inf.
666     //
667     // We don't have a frac(x) primitive in XLA and computing it is tricky, but
668     // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
669     // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
670     //
671     // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
672     // to 1.  To remedy this, we can use the fact that sin(pi * x) in the domain
673     // [0, 1] is symmetric across the line Y=0.5.
674     Value abs_input = rewriter.create<AbsOp>(loc, input);
675     Value abs_input_floor = rewriter.create<FloorOp>(loc, abs_input);
676     Value abs_frac_input =
677         rewriter.create<SubOp>(loc, abs_input, abs_input_floor);
678 
679     // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
680     // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
681     Value one_minus_abs_frac_input =
682         rewriter.create<SubOp>(loc, one, abs_frac_input);
683     Value abs_frac_input_gt_one_half =
684         rewriter.create<GreaterOp>(loc, abs_frac_input, one_half);
685     Value reduced_frac_input =
686         rewriter.create<SelectV2Op>(loc, abs_frac_input_gt_one_half,
687                                     one_minus_abs_frac_input, abs_frac_input);
688     Value pi_mul_reduced_frac_input =
689         rewriter.create<MulOp>(loc, pi, reduced_frac_input);
690     Value sin_pi_mul_reduced_frac_input =
691         rewriter.create<SinOp>(loc, pi_mul_reduced_frac_input);
692     Value reflection_denom =
693         rewriter.create<LogOp>(loc, sin_pi_mul_reduced_frac_input);
694 
695     // Avoid computing -inf - inf, which is nan.  If reflection_denom is +/-inf,
696     // then it "wins" and the result is +/-inf.
697     Value is_finite =
698         rewriter.create<IsFiniteOp>(loc, tensor_bool_type, reflection_denom);
699     Value neg_reflection_denom = rewriter.create<NegOp>(loc, reflection_denom);
700     Value log_pi_minus_reflection_denom =
701         rewriter.create<SubOp>(loc, log_pi, reflection_denom);
702     Value reflection_if_finite =
703         rewriter.create<SubOp>(loc, log_pi_minus_reflection_denom, log_y);
704     Value reflection = rewriter.create<SelectV2Op>(
705         loc, is_finite, reflection_if_finite, neg_reflection_denom);
706 
707     Value result =
708         rewriter.create<SelectV2Op>(loc, need_to_reflect, reflection, log_y);
709 
710     // lgamma(+/-inf) = +inf.
711     Value is_inf = rewriter.create<IsInfOp>(loc, tensor_bool_type, input);
712     result = rewriter.create<SelectV2Op>(loc, is_inf, infinity, result);
713 
714     if (needs_cast) {
715       result = rewriter.create<CastOp>(loc, original_tensor_type, result);
716     }
717 
718     rewriter.replaceOp(op, result);
719     return success();
720   }
721 };
722 
723 // Lowers Pack op to ConcatV2 op after changing shape of the inputs with
724 // ExpandDims op.
725 //
726 // Sample result with 2 inputs to pack:
727 //
728 //   %axis = "tf.Const"() {value = dense<1> : tensor<i64>}
729 //   %inp0 = "tf.ExpandDims"(%operand0, %axis): tensor<2xf32> -> tensor<2x1xf32>
730 //   %inp1 = "tf.ExpandDims"(%operand1, %axis): tensor<2xf32> -> tensor<2x1xf32>
731 //   %result = "tf.ConcatV2"(%inp0, %inp1, %axis) { N = 2 : i64 }:
732 //     (tensor<2x1xf32>, tensor<2x1xf32>, tensor<i64>) -> tensor<2x2xf32>
733 class LowerPackOp : public RewritePattern {
734  public:
735   explicit LowerPackOp(MLIRContext *context)
736       : RewritePattern(
737             PackOp::getOperationName(),
738             {ConstOp::getOperationName(), ConcatV2Op::getOperationName(),
739              ExpandDimsOp::getOperationName()},
740             1, context) {}
741 
742   LogicalResult matchAndRewrite(Operation *src_op,
743                                 PatternRewriter &rewriter) const override {
744     auto op = cast<PackOp>(src_op);
745 
746     Location loc = op.getLoc();
747     auto axis_value = rewriter.create<ConstOp>(
748         loc,
749         DenseElementsAttr::get(
750             RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis()));
751     int64_t axis = op.axis();
752 
753     Type prev_input_ty, inferred_ty;
754     SmallVector<Value, 4> expanded_inputs;
755     expanded_inputs.reserve(op.N());
756     for (Value input : op.values()) {
757       // If input type is different than the previous input type, infer the
758       // output type. Otherwise, use the already inferred output type from the
759       // previous iteration.
760       Type input_ty = input.getType();
761       if (input_ty != prev_input_ty) {
762         inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
763         prev_input_ty = input_ty;
764       }
765       expanded_inputs.push_back(
766           rewriter.create<ExpandDimsOp>(loc, inferred_ty, input, axis_value));
767     }
768 
769     rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), expanded_inputs,
770                                             axis_value);
771     return success();
772   }
773 };
774 
775 // Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))).
776 //
777 // Before rewrite:
778 //   output = SpaceToBatchND(input, block_shape, paddings)
779 // Let:
780 //   [batch] + spatial_shape + remaining_shape = input.shape
781 //   M = spatial_shape.rank
782 // After rewrite:
783 //   padded = zero-pad input with paddings
784 //     The spatial_shape component of input.shape is padded with paddings[*, 0]
785 //     before each dimension and paddings[*, 1] after each dimension.
786 //   reshaped = reshape padded to:
787 //     [batch]
788 //     + [padded.shape[1]/block_shape[0], block_shape[0], ...,
789 //        padded.shape[M]/block_shape[M-1], block_shape[M-1]]
790 //     + remaining_shape
791 //   permuted = transpose reshaped to:
792 //     block_shape
793 //     + [batch]
794 //     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
795 //     + remaining_shape
796 //   result = reshape permuted to:
797 //     [batch * product(block_shape)]
798 //     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
799 //     + remaining_shape
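//
// For example (illustrative), a tensor<1x4x4x1xf32> input with
// block_shape = [2, 2] and zero paddings is padded (a no-op here), reshaped to
// tensor<1x2x2x2x2x1xf32>, transposed to tensor<2x2x1x2x2x1xf32>, and finally
// reshaped into the tensor<4x2x2x1xf32> result.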
800 class LowerSpaceToBatchNDOp : public RewritePattern {
801  public:
802   explicit LowerSpaceToBatchNDOp(MLIRContext *context)
803       : RewritePattern(SpaceToBatchNDOp::getOperationName(),
804                        {
805                            CastOp::getOperationName(),
806                            ConstOp::getOperationName(),
807                            ConcatV2Op::getOperationName(),
808                            AddV2Op::getOperationName(),
809                            PadOp::getOperationName(),
810                            SplitOp::getOperationName(),
811                            UnpackOp::getOperationName(),
812                            DivOp::getOperationName(),
813                            MulOp::getOperationName(),
814                            ReshapeOp::getOperationName(),
815                            TransposeOp::getOperationName(),
816                        },
817                        1, context) {}
818 
819   LogicalResult matchAndRewrite(Operation *src_op,
820                                 PatternRewriter &rewriter) const override {
821     auto op = cast<SpaceToBatchNDOp>(src_op);
822 
823     Location loc = op.getLoc();
824     auto input_type = op.input().getType().cast<TensorType>();
825     if (!input_type.hasStaticShape()) {
826       return failure();
827     }
828     ArrayRef<int64_t> input_shape = input_type.getShape();
829     auto block_shape_type = op.block_shape().getType().cast<TensorType>();
830     if (!block_shape_type.hasStaticShape()) {
831       return failure();
832     }
833     auto paddings_type = op.paddings().getType().cast<ShapedType>();
834     if (!paddings_type.hasRank()) {
835       return failure();
836     }
837 
838     int64_t input_rank = input_type.getRank();
839     int64_t block_rank = block_shape_type.getNumElements();
840     int64_t remaining_rank = input_rank - 1 - block_rank;
841     if (remaining_rank < 0) {
842       // TODO(b/157475606): Move this check to ::Verify
843       return failure();
844     }
845 
846     auto block_shape_i64_type = RankedTensorType::get(
847         block_shape_type.getShape(), rewriter.getIntegerType(64));
848     auto block_shape_i64 =
849         rewriter.create<CastOp>(loc, block_shape_i64_type, op.block_shape());
850 
851     auto paddings_i64_type = RankedTensorType::get(paddings_type.getShape(),
852                                                    rewriter.getIntegerType(64));
853     auto paddings_i64 =
854         rewriter.create<CastOp>(loc, paddings_i64_type, op.paddings());
855 
856     auto pad00 = rewriter.create<ConstOp>(
857         loc, DenseElementsAttr::get<int64_t>(
858                  RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)),
859                  {0, 0}));
860     SmallVector<Value, 4> full_paddings_list{pad00, paddings_i64};
861     full_paddings_list.append(remaining_rank, pad00);
862     auto full_paddings_type =
863         RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64));
864     auto zero_i64 = rewriter.create<ConstOp>(
865         loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
866     // Extends paddings to all dimensions of input by adding 0s to non-block
867     // dimensions.
868     auto full_paddings = rewriter.create<ConcatV2Op>(
869         loc, full_paddings_type, full_paddings_list, zero_i64);
870 
871     // Compute the result type here instead of using shape inference because the
872     // full_paddings won't be available as a constant for shape inference.
873     ElementsAttr block_shape;
874     ElementsAttr paddings;
875     llvm::SmallVector<int64_t, 4> block_shape_ints;
876     auto padded_shape = llvm::to_vector<4>(input_shape);
877     if (matchPattern(op.block_shape(), m_Constant(&block_shape)) &&
878         matchPattern(op.paddings(), m_Constant(&paddings))) {
879       for (uint64_t i = 0; i < block_rank; i++) {
880         int64_t paddings_sum =
881             paddings.getValue({i, 0}).cast<IntegerAttr>().getInt() +
882             paddings.getValue({i, 1}).cast<IntegerAttr>().getInt();
883         int64_t block_shape_i =
884             block_shape.getValue({i}).cast<IntegerAttr>().getInt();
885         padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]);
886         block_shape_ints.push_back(block_shape_i);
887       }
888     } else {
889       for (int i = 0; i < block_rank; i++) {
890         padded_shape[i + 1] = ShapedType::kDynamicSize;
891       }
892       block_shape_ints.resize(block_shape_type.getNumElements(), -1);
893     }
894 
895     auto padded_type =
896         RankedTensorType::get(padded_shape, input_type.getElementType());
897     // padded = pad(input, full_paddings)
898     auto padded =
899         rewriter.create<PadOp>(loc, padded_type, op.input(), full_paddings);
900 
901     auto paddings_sum_type =
902         RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
903     // paddings_sum = paddings[*,0] + paddings[*,1]
904     auto paddings_split = rewriter.create<UnpackOp>(
905         loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
906         rewriter.getI64IntegerAttr(1));
907     auto paddings_sum = rewriter.create<AddV2Op>(
908         loc, paddings_split.getResult(0), paddings_split.getResult(1));
909 
910     auto input_shape_tensor = rewriter.create<ConstOp>(
911         loc,
912         DenseElementsAttr::get(
913             RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)),
914             input_shape));
915 
916     // padded_shape_tensor is the shape of padded.
917     auto padded_shape_tensor =
918         rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
919 
920     auto zero_i32 = rewriter.create<ConstOp>(
921         loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
922     SmallVector<Type, 4> padded_shape_splits_types(
923         input_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
924     SmallVector<Value, 4> padded_shape_splits(
925         rewriter
926             .create<SplitOp>(loc, padded_shape_splits_types, zero_i32,
927                              padded_shape_tensor)
928             .output());
929 
930     SmallVector<Type, 4> block_shape_splits_types(
931         block_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
932     SmallVector<Value, 4> block_shape_splits(
933         rewriter
934             .create<SplitOp>(loc, block_shape_splits_types, zero_i32,
935                              block_shape_i64)
936             .output());
937 
938     SmallVector<int64_t, 4> outer_shape_ints;
939     SmallVector<Value, 4> outer_shape_vals;
940     for (int64_t i = 0; i < block_rank; ++i) {
941       // TODO(b/157475606): Insert tf.Assert that the following division has
942       // remainder 0.
943       outer_shape_vals.push_back(rewriter.create<DivOp>(
944           loc, padded_shape_splits[1 + i], block_shape_splits[i]));
945 
946       auto padded_shape_i = padded_shape[1 + i];
947       auto block_shape_ints_i = block_shape_ints[i];
948 
949       // Compute the outer_shape constant values to infer the reshape.
950       if (padded_shape_i == -1 || block_shape_ints_i == -1) {
951         outer_shape_ints.push_back(-1);
952       } else {
953         outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i);
954       }
955     }
956 
957     SmallVector<Value, 6> reshaped_shape_vals{padded_shape_splits[0]};
958     SmallVector<int64_t, 6> reshaped_shape_ints{padded_shape[0]};
959     for (int64_t i = 0; i < block_rank; ++i) {
960       reshaped_shape_vals.push_back(outer_shape_vals[i]);
961       reshaped_shape_vals.push_back(block_shape_splits[i]);
962 
963       reshaped_shape_ints.push_back(outer_shape_ints[i]);
964       reshaped_shape_ints.push_back(block_shape_ints[i]);
965     }
966     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
967       reshaped_shape_vals.push_back(padded_shape_splits[i]);
968       reshaped_shape_ints.push_back(padded_shape[i]);
969     }
970     auto reshaped_shape = ValuesToRank1(
971         rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);
972 
973     auto reshaped = rewriter.create<ReshapeOp>(
974         loc,
975         RankedTensorType::get(reshaped_shape_ints, input_type.getElementType()),
976         padded, reshaped_shape);
977 
978     SmallVector<int64_t, 6> permutation_vals;
979     for (int64_t i = 0; i < block_rank; ++i) {
980       permutation_vals.push_back(2 + 2 * i);
981     }
982     permutation_vals.push_back(0);
983     for (int64_t i = 0; i < block_rank; ++i) {
984       permutation_vals.push_back(1 + 2 * i);
985     }
986     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
987       permutation_vals.push_back(block_rank + i);
988     }
989     auto permutation = rewriter.create<ConstOp>(
990         loc, GetI64ElementsAttr(permutation_vals, &rewriter));
991 
992     auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
993     auto output_batch = padded_shape_splits[0];
994     for (int64_t i = 0; i < block_rank; ++i) {
995       output_batch =
996           rewriter.create<MulOp>(loc, output_batch, block_shape_splits[i]);
997     }
998     SmallVector<Value, 4> output_shape_vals{output_batch};
999     for (int64_t i = 0; i < block_rank; ++i) {
1000       output_shape_vals.push_back(outer_shape_vals[i]);
1001     }
1002     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1003       output_shape_vals.push_back(padded_shape_splits[i]);
1004     }
1005     auto output_shape = ValuesToRank1(
1006         rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);
1007 
1008     // Sometimes the result type is more specific than what the reshape builder
1009     // can infer.
1010     auto result_type = op.getResult().getType();
1011     rewriter.replaceOpWithNewOp<ReshapeOp>(op, result_type, permuted,
1012                                            output_shape);
1013 
1014     return success();
1015   }
1016 };
1017 
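// Lowers BatchToSpaceND with constant block_shape and crops and a statically
// shaped input to Reshape, Transpose and Slice ops; this is the inverse of the
// SpaceToBatchND lowering above.
//
// For example (illustrative), a tensor<4x2x2x1xf32> input with
// block_shape = [2, 2] and zero crops is reshaped to tensor<2x2x1x2x2x1xf32>,
// transposed to tensor<1x2x2x2x2x1xf32>, reshaped to tensor<1x4x4x1xf32>, and
// finally sliced (a no-op here) to produce the tensor<1x4x4x1xf32> result.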
1018 class LowerBatchToSpaceND : public RewritePattern {
1019  public:
1020   explicit LowerBatchToSpaceND(MLIRContext *context)
1021       : RewritePattern(BatchToSpaceNDOp::getOperationName(),
1022                        {
1023                            ConstOp::getOperationName(),
1024                            ReshapeOp::getOperationName(),
1025                            SliceOp::getOperationName(),
1026                            TransposeOp::getOperationName(),
1027                        },
1028                        1, context) {}
1029 
1030   LogicalResult matchAndRewrite(Operation *src_op,
1031                                 PatternRewriter &rewriter) const override {
1032     auto op = cast<BatchToSpaceNDOp>(src_op);
1033     auto input = op.input();
1034     auto input_ty = input.getType().cast<ShapedType>();
1035     auto element_ty = input_ty.getElementType();
1036     if (!input_ty.hasStaticShape()) {
1037       return failure();
1038     }
1039 
1040     const int input_rank = input_ty.getRank();
1041     auto input_shape = input_ty.getShape();
1042 
1043     DenseIntElementsAttr block_shape;
1044     DenseIntElementsAttr crops;
1045     if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) ||
1046         !matchPattern(op.crops(), m_Constant(&crops))) {
1047       return failure();
1048     }
1049 
1050     auto block_shape_ty = block_shape.getType();
1051     if (!block_shape_ty.hasRank() || block_shape_ty.getRank() != 1) {
1052       return failure();
1053     }
1054 
1055     const int block_rank = block_shape_ty.getShape().front();
1056     auto remainder_shape = input_shape.drop_front(1 + block_rank);
1057 
1058     const int64_t batch_size = input_shape[0];
1059 
1060     // Compute the product of the block_shape values.
1061     int64_t block_num_elems = 1;
1062 
1063     for (auto val : block_shape.getIntValues()) {
1064       block_num_elems *= val.getSExtValue();
1065     }
1066 
1067     if (block_num_elems <= 0) {
1068       op.emitOpError()
1069           << "The product of the block dimensions must be positive";
1070       return failure();
1071     }
1072 
1073     // 1. Reshape `input` to `reshaped` of shape:
1074     //      [block_shape[0], ..., block_shape[M-1],
1075     //       batch / prod(block_shape),
1076     //       input_shape[1], ..., input_shape[N-1]]
1077     std::vector<int64_t> reshaped_shape;
1078     for (auto val : block_shape) {
1079       reshaped_shape.push_back(val.getSExtValue());
1080     }
1081     reshaped_shape.resize(input_rank + block_rank);
1082 
1083     reshaped_shape[block_rank] = batch_size / block_num_elems;
1084     std::copy(input_shape.begin() + 1, input_shape.end(),
1085               reshaped_shape.begin() + block_rank + 1);
1086 
1087     auto reshaped = rewriter.create<TF::ReshapeOp>(
1088         op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input,
1089         rewriter.create<ConstOp>(op.getLoc(),
1090                                  rewriter.getI64TensorAttr(reshaped_shape)));
1091 
1092     // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
1093     //      [batch / prod(block_shape),
1094     //
1095     //       input_shape[1], block_shape[0],
1096     //       ...,
1097     //       input_shape[M], block_shape[M-1],
1098     //
1099     //       input_shape[M+1], ..., input_shape[N-1]]
1100     std::vector<int64_t> permutation(reshaped_shape.size());
1101     permutation[0] = block_rank;
1102     for (int i = 0; i < block_rank; ++i) {
1103       permutation[1 + 2 * i] = block_rank + 1 + i;
1104       permutation[1 + 2 * i + 1] = i;
1105     }
1106     std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1107               1 + block_rank * 2);
1108 
1109     std::vector<int64_t> transpose_shape(permutation.size());
1110     for (auto it : llvm::enumerate(permutation)) {
1111       transpose_shape[it.index()] = reshaped_shape[it.value()];
1112     }
1113 
1114     auto permuted = rewriter.create<TF::TransposeOp>(
1115         op.getLoc(), RankedTensorType::get(transpose_shape, element_ty),
1116         reshaped,
1117         rewriter.create<ConstOp>(op.getLoc(),
1118                                  rewriter.getI64TensorAttr(permutation)));
1119 
1120     // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
1121     //      [batch / prod(block_shape),
1122     //
1123     //       input_shape[1] * block_shape[0],
1124     //       ...,
1125     //       input_shape[M] * block_shape[M-1],
1126     //
1127     //       input_shape[M+1],
1128     //       ...,
1129     //       input_shape[N-1]]
1130     std::vector<int64_t> reshaped_permuted_shape(input_rank);
1131     auto block_shape_values = llvm::to_vector<4>(block_shape.getIntValues());
1132     reshaped_permuted_shape[0] = batch_size / block_num_elems;
1133     for (int i = 0; i < block_rank; ++i) {
1134       reshaped_permuted_shape[1 + i] =
1135           block_shape_values[i].getSExtValue() * input_shape[1 + i];
1136     }
1137     std::copy(remainder_shape.begin(), remainder_shape.end(),
1138               reshaped_permuted_shape.begin() + 1 + block_rank);
1139 
1140     auto reshaped_permuted = rewriter.create<TF::ReshapeOp>(
1141         op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty),
1142         permuted,
1143         rewriter.create<ConstOp>(
1144             op.getLoc(), rewriter.getI64TensorAttr(reshaped_permuted_shape)));
1145 
1146     // 4. Crop the start and end of dimensions `[1, ..., M]` of
1147     //    `reshaped_permuted` according to `crops` to produce the output of
1148     //    shape:
1149     //      [batch / prod(block_shape),
1150     //
1151     //       input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
1152     //       ...,
1153     //       input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
1154     //
1155     //       input_shape[M+1], ..., input_shape[N-1]]
1156     std::vector<int64_t> start_indices(input_rank, 0);
1157     std::vector<int64_t> slice_sizes = reshaped_permuted_shape;
1158     std::vector<int64_t> strides(input_rank, 1);
1159     auto crop_values = llvm::to_vector<4>(crops.getIntValues());
1160     for (int i = 0; i < block_rank; ++i) {
1161       int64_t crop_start = crop_values[i * 2].getSExtValue();
1162       int64_t crop_end = crop_values[i * 2 + 1].getSExtValue();
1163 
1164       if (crop_start < 0 || crop_end < 0) {
1165         op.emitOpError() << "Crops must be non-negative";
1166         return failure();
1167       }
1168 
1169       start_indices[i + 1] = crop_start;
1170       slice_sizes[i + 1] -= crop_start + crop_end;
1171 
1172       if (slice_sizes[i + 1] < 0) {
1173         op.emitOpError() << "Cropped size must be non-negative: start: "
1174                          << crop_start << " end: " << crop_end << " size "
1175                          << reshaped_permuted_shape[1 + i];
1176       }
1177     }
1178 
1179     rewriter.replaceOpWithNewOp<TF::SliceOp>(
1180         op, RankedTensorType::get(slice_sizes, element_ty), reshaped_permuted,
1181         rewriter.create<ConstOp>(op.getLoc(),
1182                                  rewriter.getI64TensorAttr(start_indices)),
1183         rewriter.create<ConstOp>(op.getLoc(),
1184                                  rewriter.getI64TensorAttr(slice_sizes)));
1185     return success();
1186   }
1187 };
1188 
1189 // Lowers `SparseMatMulOp` to `MatMulOp`, ignoring the sparseness hints,
1190 // since we currently don't have an implementation that can use this
1191 // information. Adds appropriate casts where necessary to align element types
1192 // of operands and result for `MatMulOp`.
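//
// For example (illustrative):
//
//   %0 = "tf.SparseMatMul"(%a, %b) {a_is_sparse = true, b_is_sparse = false}
//     : (tensor<3x4xbf16>, tensor<4x5xf32>) -> tensor<3x5xf32>
//
// becomes
//
//   %a_f32 = "tf.Cast"(%a) : (tensor<3x4xbf16>) -> tensor<3x4xf32>
//   %0 = "tf.MatMul"(%a_f32, %b)
//     : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>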
1193 class LowerSparseMatMulOp : public RewritePattern {
1194  public:
1195   explicit LowerSparseMatMulOp(MLIRContext *context)
1196       : RewritePattern(
1197             SparseMatMulOp::getOperationName(),
1198             {CastOp::getOperationName(), MatMulOp::getOperationName()}, 1,
1199             context) {}
1200 
1201   LogicalResult matchAndRewrite(Operation *src_op,
1202                                 PatternRewriter &rewriter) const override {
1203     auto op = cast<SparseMatMulOp>(src_op);
1204 
1205     // Result type must be f32 for applying the pattern (currently this is
1206     // required by the op anyway but this might change).
1207     if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
1208       return failure();
1209     }
1210     MLIRContext *context = rewriter.getContext();
1211     llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
1212     for (Value &operand : operands) {
1213       TensorType tensor_type = operand.getType().cast<TensorType>();
1214       Type element_type = tensor_type.getElementType();
1215       if (element_type.isF32()) continue;
1216       // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
1217       // must be bf16 here.
1218       assert(element_type.isBF16());
1219       Type tensor_type_f32;
1220       if (tensor_type.hasRank()) {
1221         tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
1222                                                 FloatType::getF32(context));
1223       } else {
1224         tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
1225       }
1226       // Add cast to f32 to conform with element type of result.
1227       operand = rewriter.create<CastOp>(op.getLoc(), tensor_type_f32, operand);
1228     }
1229     Value result = rewriter.create<MatMulOp>(
1230         op.getLoc(), op.product().getType(), operands[0], operands[1],
1231         op.transpose_a(), op.transpose_b());
1232 
1233     rewriter.replaceOp(op, {result});
1234     return success();
1235   }
1236 };
1237 
// Lowers _UnaryOpsComposition op as a series of original TensorFlow ops that
// were fused together.
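//
// For example (illustrative), a composition with op_names = ["Abs", "Log"]
// applied to %x expands to %0 = "tf.Abs"(%x) followed by %1 = "tf.Log"(%0),
// with each generated op using the result type of the original composed op.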
class Lower_UnaryOpsComposition
    : public OpRewritePattern<_UnaryOpsCompositionOp> {
 public:
  using OpRewritePattern<_UnaryOpsCompositionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(_UnaryOpsCompositionOp op,
                                PatternRewriter &rewriter) const override {
    Value result = op.x();
    for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
      std::string full_name = "tf." + op_name.str();
      // All ops in the sequence have the same result type as the original
      // result type.
      OperationState state(op.getLoc(), full_name, /*operands=*/{result},
                           /*types=*/{op.getType()}, /*attributes=*/{});
      Operation *unary_op = rewriter.createOperation(state);
      result = unary_op->getResult(0);
    }
    rewriter.replaceOp(op, {result});
    return success();
  }
};

// Lowers ResizeNearestNeighbor to an indices computation with a gather along
// the combined spatial dimensions. Separate index vectors along the height
// and width dimensions could be used to gather along the H and W dimensions
// of the input image individually. To reduce this to a single gather, these
// indices are combined, so one gather can be performed along the combined
// spatial dimension.
//
// Images must take the shape [b, h, w, c] and size is a rank-1, length-2
// tensor containing the height and width values for the output tensor. This
// lowering should work with a dynamic images array.
//
// For example, scaling an image of shape [1, 3, 3, 1] to size [2, 2] with
// unaligned corners generates a [0, 1] lookup along both the x and y
// directions. When combined to form the 1-D spatial index, the values are
// [0, 1, 3, 4], which gather along the image tensor reshaped to [1, 9, 1];
// the result is then reshaped to the final [1, 2, 2, 1].
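//
// In general, the source element gathered for output pixel (y, x) sits at the
// flattened spatial index
//   floor(y * in_h / out_h) * in_w + floor(x * in_w / out_w),
// computed in f32 below before being cast back to the size element type.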
class LowerResizeNearestNeighbor : public RewritePattern {
 public:
  explicit LowerResizeNearestNeighbor(MLIRContext *context)
      : RewritePattern(ResizeNearestNeighborOp::getOperationName(),
                       {
                           BroadcastToOp::getOperationName(),
                           ConstOp::getOperationName(),
                           DivOp::getOperationName(),
                           PackOp::getOperationName(),
                           RangeOp::getOperationName(),
                           ReshapeOp::getOperationName(),
                           ShapeOp::getOperationName(),
                           SplitOp::getOperationName(),
                           TransposeOp::getOperationName(),
                       },
                       1, context) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<ResizeNearestNeighborOp>(src_op);
    auto loc = op.getLoc();
    auto result_ty = op.getType().cast<ShapedType>();

    auto input = op.images();
    auto input_ty = input.getType().cast<ShapedType>();
    auto input_element_ty = input_ty.getElementType();
    auto out_size = op.size();
    auto out_size_ty = out_size.getType().cast<ShapedType>();
    auto out_size_element_ty = out_size_ty.getElementType();

    // Input should be rank 4.
    if (!input_ty.hasRank() || input_ty.getRank() != 4) {
      return failure();
    }

    // Check that out_size is a rank-1, length-2 tensor. Otherwise the size is
    // not legal.
    if (!out_size_ty.hasRank() || out_size_ty.getRank() != 1 ||
        out_size_ty.getShape()[0] != 2) {
      return failure();
    }

    // Extract the output width / height dim sizes.
    int out_height_constant = -1;
    int out_width_constant = -1;
    DenseIntElementsAttr out_size_cst;
    if (matchPattern(out_size, m_Constant(&out_size_cst))) {
      llvm::SmallVector<int64_t, 2> cst_size;
      for (auto val : out_size_cst.getIntValues()) {
        cst_size.push_back(val.getSExtValue());
      }

      out_height_constant = cst_size[0];
      out_width_constant = cst_size[1];

      if (out_height_constant < 0 || out_width_constant < 0) return failure();
    }

    int out_spatial_cst = out_height_constant < 0 || out_width_constant < 0
                              ? -1
                              : out_height_constant * out_width_constant;

    // Input rank should be 4. Might be able to drop this requirement entirely
    // as it's an input requirement.
    if (!input_ty.hasRank() || input_ty.getRank() != 4) {
      return failure();
    }

    int batch_cst = input_ty.getShape()[0];
    int channels_cst = input_ty.getShape()[3];

    int in_y_cst = input_ty.getShape()[1];
    int in_x_cst = input_ty.getShape()[2];
    int in_spatial_cst =
        in_y_cst < 0 || in_x_cst < 0 ? -1 : in_y_cst * in_x_cst;

    // TODO(suderman): Add support for these optional parameters.
    if (op.align_corners() || op.half_pixel_centers()) {
      return failure();
    }

    auto one =
        rewriter.create<ConstOp>(loc, GetScalarOfType(out_size_element_ty, 1));

    // Extract the image shape.
    Value input_shape = rewriter.create<ShapeOp>(
        loc, RankedTensorType::get({4}, rewriter.getI64Type()), input);
    input_shape = rewriter.create<CastOp>(
        loc, RankedTensorType::get({4}, out_size_element_ty), input_shape);

    auto scalar_dim_ty = RankedTensorType::get({}, out_size_element_ty);
    auto split_image_shape = rewriter.create<UnpackOp>(
        loc,
        TypeRange({scalar_dim_ty, scalar_dim_ty, scalar_dim_ty, scalar_dim_ty}),
        input_shape);

    // Extract the separate components from the input shape.
    auto batch = split_image_shape.getResult(0);
    auto in_y = split_image_shape.getResult(1);
    auto in_x = split_image_shape.getResult(2);
    auto channels = split_image_shape.getResult(3);

    auto in_count = rewriter.create<MulOp>(
        loc, RankedTensorType::get({}, out_size_element_ty), in_y, in_x);

    // Unpack and separate the out width/height.
    auto split_out_size = rewriter.create<UnpackOp>(
        loc, TypeRange({scalar_dim_ty, scalar_dim_ty}), out_size);

    auto out_y = split_out_size.getResult(0);
    auto out_x = split_out_size.getResult(1);

    auto out_count = rewriter.create<MulOp>(
        loc, RankedTensorType::get({}, out_size_element_ty), out_y, out_x);

    // Generate what the final output shape will look like.
    auto out_shape = rewriter.create<PackOp>(
        loc, RankedTensorType::get({4}, out_size_element_ty),
        ValueRange({batch, out_y, out_x, channels}));

    // Compute the indices along the vertical dimension.
    auto in_y_f32 = rewriter.create<CastOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y);
    auto out_y_f32 = rewriter.create<CastOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), out_y);

    Value y_scale = rewriter.create<DivOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y_f32,
        out_y_f32);

    Value zero_f32 = rewriter.create<ConstOp>(
        loc, GetScalarOfType(rewriter.getF32Type(), 0.0));
    Value one_f32 = rewriter.create<ConstOp>(
        loc, GetScalarOfType(rewriter.getF32Type(), 1.0));

    Value y_range = rewriter.create<RangeOp>(
        loc,
        RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
        zero_f32, out_y_f32, one_f32);

    y_range = rewriter.create<MulOp>(
        loc,
        RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
        y_range, y_scale);

    y_range = rewriter.create<CastOp>(
        loc, RankedTensorType::get({out_height_constant}, out_size_element_ty),
        y_range);
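    // After this cast, y_range[r] is floor(r * in_h / out_h), i.e. the source
    // row gathered for output row r (align_corners and half_pixel_centers are
    // rejected above, so no extra offsets apply).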

    y_range = rewriter.create<ReshapeOp>(
        loc,
        RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
        y_range,
        rewriter.create<PackOp>(loc,
                                RankedTensorType::get({2}, out_size_element_ty),
                                ValueRange({out_y, one})));

    Value y_indices = rewriter.create<MulOp>(
        loc,
        RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
        y_range, in_x);
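    // y_indices now holds y * in_w for each output row (shape [out_h, 1]);
    // adding the per-column x offsets below yields the flattened spatial
    // index y * in_w + x.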

    // Compute the indices for the nearest neighbor lookup across the width
    // dim.
    auto in_x_f32 = rewriter.create<CastOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x);
    auto out_x_f32 = rewriter.create<CastOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), out_x);

    Value x_scale = rewriter.create<DivOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x_f32,
        out_x_f32);

    Value x_range = rewriter.create<RangeOp>(
        loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
        zero_f32, out_x_f32, one_f32);

    x_range = rewriter.create<MulOp>(
        loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
        x_range, x_scale);

    x_range = rewriter.create<CastOp>(
        loc, RankedTensorType::get({out_width_constant}, out_size_element_ty),
        x_range);

    Value x_indices = rewriter.create<ReshapeOp>(
        loc,
        RankedTensorType::get({1, out_width_constant}, out_size_element_ty),
        x_range,
        rewriter.create<PackOp>(loc,
                                RankedTensorType::get({2}, out_size_element_ty),
                                ValueRange({one, out_x})));
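    // y_indices has shape [out_h, 1] and x_indices has shape [1, out_w], so
    // the AddV2 below broadcasts them into the full [out_h, out_w] index
    // grid.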

    // Generate the combined index array, reshape to be 1-D.
    Value indices = rewriter.create<AddV2Op>(
        loc,
        RankedTensorType::get({out_height_constant, out_width_constant},
                              out_size_element_ty),
        y_indices, x_indices);

    indices = rewriter.create<ReshapeOp>(
        loc, RankedTensorType::get({out_spatial_cst}, out_size_element_ty),
        indices,
        rewriter.create<ReshapeOp>(
            loc, RankedTensorType::get({1}, out_size_element_ty), out_count,
            rewriter.create<ConstOp>(loc, rewriter.getI64TensorAttr({1}))));

    // Group the spatial indices and gather along that combined index.
    Value input_collapsed_spatial = rewriter.create<ReshapeOp>(
        loc,
        RankedTensorType::get({batch_cst, in_spatial_cst, channels_cst},
                              input_element_ty),
        input,
        rewriter.create<PackOp>(loc,
                                RankedTensorType::get({3}, out_size_element_ty),
                                ValueRange({batch, in_count, channels})));

    Value gathered_values = rewriter.create<GatherV2Op>(
        loc,
        RankedTensorType::get({batch_cst, out_spatial_cst, channels_cst},
                              input_element_ty),
        input_collapsed_spatial, indices, /*axis=*/one);

    gathered_values =
        rewriter.create<ReshapeOp>(loc, result_ty, gathered_values, out_shape);

    rewriter.replaceOp(op, gathered_values);
    return success();
  }
};

}  // namespace

void PopulateLoweringTFPatterns(MLIRContext *context,
                                OwningRewritePatternList *patterns) {
  patterns->insert<LowerAddNOp, ConvertFakeQuantWithMinMaxVarsOp,
                   LowerDynamicStitchOp<DynamicStitchOp>,
                   LowerDynamicStitchOp<ParallelDynamicStitchOp>,
                   LowerInvertPermutationOp, LowerLgammaOp, LowerPackOp,
                   LowerBatchToSpaceND, LowerSpaceToBatchNDOp,
                   LowerResizeNearestNeighbor, LowerSparseMatMulOp,
                   Lower_UnaryOpsComposition>(context);
  populateWithGenerated(context, *patterns);
}
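
// Usage sketch (illustrative; the actual pass wiring lives elsewhere): a
// function pass would typically collect these patterns and hand them to the
// greedy rewrite driver, roughly:
//   OwningRewritePatternList patterns;
//   PopulateLoweringTFPatterns(&getContext(), &patterns);
//   applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));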

void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
                                         OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<
      ConvertFakeQuantWithMinMaxVarsOp,
      LowerAddNOp,
      LowerBatchToSpaceND,
      LowerDynamicStitchOp<DynamicStitchOp>,
      LowerDynamicStitchOp<ParallelDynamicStitchOp>,
      LowerInvertPermutationOp,
      LowerPackOp,
      LowerResizeNearestNeighbor,
      LowerSpaceToBatchNDOp,
      LowerSparseMatMulOp,
      Lower_UnaryOpsComposition>(context);
  // clang-format on

  // Populate the relevant generated patterns.
  // clang-format off
  patterns->insert<
      LowerBiasAddGradOp,
      LowerDivNoNanOp,
      LowerEmptyOp,
      LowerFakeQuantWithMinMaxArgs,
      LowerFillOp,
      LowerIsNanOp,
      LowerL2LossOp,
      LowerMulNoNanOp,
      LowerOnesLikeOp,
      LowerPadOp,
      LowerReciprocal,
      LowerRintOp,
      LowerRoundOpOnFloatTensor,
      LowerRoundOpOnIntTensor,
      LowerRsqrtGradOp,
      LowerScatterNdOp,
      LowerSizeOp,
      LowerSoftmaxCrossEntropyWithLogitsOp,
      LowerSparseSoftmaxCrossEntropyWithLogitsOp,
      LowerSqrtGradOp,
      LowerSquareOp,
      LowerSquaredDifferenceOpOnRealTensors,
      LowerSquaredDifferenceOpOneComplexTensors,
      LowerTanhGradOp,
      LowerXdivyOp,
      LowerXlog1pyOp,
      LowerXlogyOp,
      LowerZerosLikeOp>(context);
  // clang-format on
}

}  // namespace TF
}  // namespace mlir