1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
17
18 #include <numeric>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
24 #include "mlir/IR/Diagnostics.h" // from @llvm-project
25 #include "mlir/IR/MLIRContext.h" // from @llvm-project
26 #include "mlir/IR/Matchers.h" // from @llvm-project
27 #include "mlir/IR/PatternMatch.h" // from @llvm-project
28 #include "mlir/IR/TypeRange.h" // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
35 #include "tensorflow/core/util/tensor_format.h"
36
37 namespace mlir {
38 namespace TF {
39 namespace {
40
41 // Returns 1D 64-bit dense elements attribute with the given values.
42 static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
43 Builder *builder) {
44 RankedTensorType ty = RankedTensorType::get(
45 {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
46 return DenseIntElementsAttr::get(ty, values);
47 }
48
49 // Returns a 1-d i64 elements attribute populated with numbers from `start`
50 // to `end` (exclusive).
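// For example, GetI64ElementsAttrForSeq(2, 5, builder) produces the attribute
// dense<[2, 3, 4]> : tensor<3xi64>.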
51 static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
52 Builder *builder) {
53 int size = end - start;
54
55 SmallVector<int64_t, 4> vals;
56 vals.resize(size);
57 std::iota(vals.begin(), vals.end(), start);
58
59 TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
60 return DenseIntElementsAttr::get(ty, vals);
61 }
62
63 static APFloat ConvertToAPFloat(double val, Type type) {
64 if (type.getIntOrFloatBitWidth() == 32) {
65 return APFloat(static_cast<float>(val));
66 }
67
68 return APFloat(val);
69 }
70
71 // Returns an int, float, or complex DenseElementsAttr of scalar shape with
72 // the given element type and value.
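// For example, GetScalarOfType(builder.getF32Type(), 1.0) produces
// dense<1.0> : tensor<f32>.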
73 template <typename T>
74 static DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
75 RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
76 if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
77 FloatAttr attr = FloatAttr::get(float_ty, raw_value);
78 return DenseElementsAttr::get(scalar_ty, attr);
79 } else if (auto int_ty = ty.dyn_cast_or_null<IntegerType>()) {
80 IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
81 return DenseElementsAttr::get(scalar_ty, attr);
82 } else if (auto complex_ty = ty.dyn_cast_or_null<ComplexType>()) {
83 Type complex_element_ty = complex_ty.getElementType();
84 if (complex_element_ty.isF32()) {
85 return DenseElementsAttr::get(
86 scalar_ty, static_cast<std::complex<float>>(raw_value));
87 } else if (complex_element_ty.isF64()) {
88 return DenseElementsAttr::get(
89 scalar_ty, static_cast<std::complex<double>>(raw_value));
90 }
91 }
92 llvm_unreachable("unsupported type");
93 }
94
95 // Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
96 // op.
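// For example, for rank = 4 and data_format = "NHWC" the feature dimension
// index is 3, so the returned reduction indices are [0, 1, 2].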
97 DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
98 StringAttr data_format,
99 Builder *builder) {
100 tensorflow::TensorFormat format;
101 if (!FormatFromString(data_format.getValue().str(), &format)) return {};
102
103 // Reduce along all dimensions except the feature dimension.
104 int64_t feature_dim = GetTensorFeatureDimIndex(rank, format);
105 llvm::SmallVector<int64_t, 4> dims_to_reduce(rank - 1);
106 std::iota(dims_to_reduce.begin(), dims_to_reduce.begin() + feature_dim, 0);
107 std::iota(dims_to_reduce.begin() + feature_dim, dims_to_reduce.end(),
108 feature_dim + 1);
109 return GetI64ElementsAttr(dims_to_reduce, builder);
110 }
111
112 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_lower_tf.inc"
113
114 // Infers ExpandDims op output type for the given input type `ty` and dimension
115 // to expand at the given `axis`.
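// For example, expanding tensor<2x3xf32> at axis = 1 (or equivalently
// axis = -2) yields tensor<2x1x3xf32>; unranked inputs are returned as is.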
116 Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
117 auto ranked_ty = ty.dyn_cast<RankedTensorType>();
118
119 // Unranked type.
120 if (!ranked_ty) return ty;
121
122 auto shape = llvm::to_vector<4>(ranked_ty.getShape());
123 if (axis < 0) axis += ranked_ty.getRank() + 1;
124
125 shape.insert(shape.begin() + axis, 1);
126 return RankedTensorType::get(shape, ranked_ty.getElementType());
127 }
128
129 // Converts individual Values to a tensor of rank 1. Each input Value has rank 1
130 // and size 1.
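// For example, three tensor<1xi64> values are concatenated along axis 0 into
// a single tensor<3xi64>.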
131 Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype,
132 ArrayRef<Value> vals) {
133 int64_t length = vals.size();
134 auto type = RankedTensorType::get({length}, dtype);
135 auto axis = rewriter.create<ConstOp>(
136 loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
137 return rewriter.create<ConcatV2Op>(loc, type, ValueRange(vals), axis);
138 }
139
140 // Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
141 //
142 // Note that to improve the parallelism, AddN op uses tree-based reduction.
143 // For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows:
144 //
145 // 0 1 2 3 4
146 // | | | | |
147 // ------- ------- |
148 // | | |
149 // 5 6 |
150 // | | |
151 // ------------- |
152 // | |
153 // 7 |
154 // | |
155 // ----------------
156 // |
157 // 8
158 //
159 // Example:
160 //
161 // %result = "tf.AddN"(%0, %1, %2)
162 //
163 // is lowered to:
164 //
165 // %sum0 = "tf.AddV2"(%0, %1)
166 // %result = "tf.AddV2"(%sum0, %2)
167 //
168 // While
169 //
170 // %result = "tf.AddN"(%0, %1, %2, %3, %4)
171 //
172 // is lowered to:
173 //
174 // %sum0 = "tf.AddV2"(%0, %1)
175 // %sum1 = "tf.AddV2"(%2, %3)
176 // %sum2 = "tf.AddV2"(%sum0, %sum1)
177 // %result = "tf.AddV2"(%sum2, %4)
178 //
179 class LowerAddNOp : public RewritePattern {
180 public:
181 explicit LowerAddNOp(MLIRContext *context)
182 : RewritePattern(AddNOp::getOperationName(),
183 {AddV2Op::getOperationName()}, 1, context) {}
184
185 LogicalResult matchAndRewrite(Operation *op,
186 PatternRewriter &rewriter) const override {
187 auto addn_op = cast<AddNOp>(op);
188
189 // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't
190 // support variant type so variant types require special handling.
191 if (getElementTypeOrSelf(addn_op.getType()).isa<VariantType>())
192 return failure();
193 llvm::SmallVector<Value, 4> operands(addn_op.inputs().begin(),
194 addn_op.inputs().end());
195
196 int64_t n = operands.size();
197 // Keep doing tree-based reduction while there is more than one operand.
198 while (n > 1) {
199 for (int64_t i = 0; i < n; i += 2) {
200 // Add two adjacent operands if applicable.
201 operands[i / 2] =
202 (i + 1 < n) ? rewriter.create<AddV2Op>(addn_op.getLoc(),
203 operands[i], operands[i + 1])
204 : operands[i];
205 }
206 n = (n + 1) / 2;
207 }
208
209 rewriter.replaceOp(addn_op, operands[0]);
210 return success();
211 }
212 };
213
214 // Lowers DynamicStitch op with constant indices and with static input and
215 // output shapes using Reshape, Unpack and Pack ops.
216 //
217 // %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
218 //   %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]>
219 //     : tensor<2x2xi32>}
220 //   %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1)
221 //     : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>)
222 //     -> tensor<5x2xf32>
223 //
224 // is lowered to
225 //
226 // %shape = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>}
227 // %inp0 = "tf.Reshape"(%arg0, %shape)
228 // : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
229 // %inp1 = "tf.Reshape"(%arg1, %shape)
230 // : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
231 //   %items0 = "tf.Unpack"(%inp0) {axis = 0 : i64}
232 // : (tensor<1x2xf32>) -> tensor<2xf32>
233 //   %items1:4 = "tf.Unpack"(%inp1) {axis = 0 : i64}
234 // : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
235 // tensor<2xf32>)
236 // %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
237 // %0 = "tf.Pack"(items1#3, items1#2, items1#1, items1#0, %items0, %axis)
238 // : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
239 // tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
240 //
241 template <typename OpT>
242 class LowerDynamicStitchOp : public RewritePattern {
243 public:
244 explicit LowerDynamicStitchOp(MLIRContext *context)
245 : RewritePattern(
246 OpT::getOperationName(),
247 {ConstOp::getOperationName(), ReshapeOp::getOperationName(),
248 UnpackOp::getOperationName(), PackOp::getOperationName()},
249 1, context) {}
250
251 LogicalResult matchAndRewrite(Operation *src_op,
252 PatternRewriter &rewriter) const override {
253 auto op = cast<OpT>(src_op);
254
255 // Static output type is used to compute intermediate values. Note that the
256 // output type doesn't have to be static but if input types and indices are
257 // constant, then the output type can be statically determined.
258 RankedTensorType out_ty =
259 op.getType().template dyn_cast<RankedTensorType>();
260 if (!out_ty || !out_ty.hasStaticShape()) return failure();
261
262 // Extract all of the constant indices' attributes and verify that the
263 // data operands have static shapes.
264 SmallVector<DenseIntElementsAttr, 4> indices;
265 indices.reserve(op.N());
266 for (auto it : llvm::zip(op.indices(), op.data())) {
267 Value index = std::get<0>(it);
268 Value data = std::get<1>(it);
269
270 DenseIntElementsAttr index_attr;
271 if (!matchPattern(index, m_Constant(&index_attr))) return failure();
272 indices.push_back(index_attr);
273
274 RankedTensorType data_ty =
275 data.getType().template dyn_cast<RankedTensorType>();
276 if (!data_ty || !data_ty.hasStaticShape()) return failure();
277 }
278
279 // Compute the type of each item and the shape to use when reshaping inputs
280 // so that they can be unpacked into individual items.
281 ArrayRef<int64_t> item_shape = out_ty.getShape().drop_front(1);
282 auto item_ty = RankedTensorType::get(item_shape, out_ty.getElementType());
283
284 SmallVector<int64_t, 4> packed_shape;
285 packed_shape.push_back(-1);
286 packed_shape.append(item_shape.begin(), item_shape.end());
287 Location loc = op.getLoc();
288 auto packed_shape_val = rewriter.create<ConstOp>(
289 loc, GetI64ElementsAttr(packed_shape, &rewriter));
290
291 // Prepare each output item by unpacking the data and then placing it at the
292 // specified index.
293 SmallVector<Value, 8> values(out_ty.getDimSize(0));
294 for (auto it : llvm::zip(indices, op.data())) {
295 DenseIntElementsAttr index_attr = std::get<0>(it);
296 Value data = std::get<1>(it);
297
298 auto reshaped_data =
299 rewriter.create<ReshapeOp>(loc, data, packed_shape_val);
300 auto num_items = reshaped_data.getType()
301 .template cast<RankedTensorType>()
302 .getShape()[0];
303 auto items = rewriter.create<UnpackOp>(
304 loc, SmallVector<Type, 4>(num_items, item_ty), reshaped_data,
305 /*axis=*/0);
306 for (auto index_item : llvm::zip(index_attr, items.getResults())) {
307 int64_t output_index = std::get<0>(index_item).getSExtValue();
308 Value item = std::get<1>(index_item);
309 values[output_index] = item;
310 }
311 }
312
313 rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values);
314 return success();
315 }
316 };
317
318 // This pattern performs a manual lowering of FakeQuantWithMinMaxVars,
319 // converting between floating point and quantized space. It is designed to
320 // reproduce TF's implementation, mirroring the previous XLA implementation:
321 //
322 // 1. Compute proper quantized bounds. This involves nudging the input bounds.
323 // 2. Convert the input bounds to quantized space, rounding values.
324 // 3. Convert back into floating point space.
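//
// As a worked example: with num_bits = 8 and narrow_range = false the
// quantized range is [0, 255]; for min = -1.0 and max = 1.0 the scale
// (quant_to_float) is 2.0 / 255, the zero point is nudged to 128, and the
// nudged float range becomes approximately [-1.0039, 0.9961].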
325 class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
326 public:
327 explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
328 : RewritePattern(
329 FakeQuantWithMinMaxVarsOp::getOperationName(),
330 {AddV2Op::getOperationName(), SubOp::getOperationName(),
331 ConstOp::getOperationName(), MulOp::getOperationName(),
332 FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
333 DivOp::getOperationName(), RoundOp::getOperationName()},
334 1, context) {}
335
336 LogicalResult matchAndRewrite(Operation *src_op,
337 PatternRewriter &rewriter) const override {
338 auto op = cast<FakeQuantWithMinMaxVarsOp>(src_op);
339
340 auto input = op.inputs();
341 auto input_ty = input.getType().cast<ShapedType>();
342 auto element_ty = input_ty.getElementType();
343 auto scalar_ty = RankedTensorType::get({}, element_ty);
344
345 auto num_bits = op.num_bits();
346 auto narrow_range = op.narrow_range();
347 const double bits_min = narrow_range ? 1 : 0;
348 const double bits_max = (1 << num_bits) - 1;
349
350 auto float_min = op.min();
351 auto float_max = op.max();
352
353 auto float_diff = rewriter.create<SubOp>(op.getLoc(), float_max, float_min);
354
355 // Compute the range when quantized.
356 auto quant_min = rewriter.create<ConstOp>(
357 op.getLoc(), DenseElementsAttr::get(
358 scalar_ty, ConvertToAPFloat(bits_min, element_ty)));
359
360 auto quant_max = rewriter.create<ConstOp>(
361 op.getLoc(), DenseElementsAttr::get(
362 scalar_ty, ConvertToAPFloat(bits_max, element_ty)));
363
364 auto quant_diff = rewriter.create<ConstOp>(
365 op.getLoc(),
366 DenseElementsAttr::get(
367 scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty)));
368
369 auto quant_to_float =
370 rewriter.create<DivOp>(op.getLoc(), float_diff, quant_diff);
371
372 auto float_to_quant =
373 rewriter.create<DivOp>(op.getLoc(), quant_diff, float_diff);
374
375 // During quantization, the quantized min/max values may not line up
376 // perfectly with the specified min/max. Nudge them into the right range.
377 auto min_scaled =
378 rewriter.create<DivOp>(op.getLoc(), float_min, quant_to_float);
379 auto min_scaled_sub =
380 rewriter.create<SubOp>(op.getLoc(), quant_min, min_scaled);
381
382 auto mid_rounded =
383 rewriter.create<RoundOp>(op.getLoc(), scalar_ty, min_scaled_sub);
384
385 auto nudged_zero_point_val = rewriter.create<ClipByValueOp>(
386 op.getLoc(), scalar_ty, mid_rounded, quant_min, quant_max);
387
388 auto quant_min_sub =
389 rewriter.create<SubOp>(op.getLoc(), quant_min, nudged_zero_point_val);
390 auto quant_max_sub =
391 rewriter.create<SubOp>(op.getLoc(), quant_max, nudged_zero_point_val);
392
393 auto nudged_float_min =
394 rewriter.create<MulOp>(op.getLoc(), quant_min_sub, quant_to_float);
395
396 auto nudged_float_max =
397 rewriter.create<MulOp>(op.getLoc(), quant_max_sub, quant_to_float);
398
399 // Now quantize the input value with the approximated min/max values.
400
401 // Move the input value into quantized space.
402 Value quantized_input = rewriter.create<ClipByValueOp>(
403 op.getLoc(), input_ty, input, nudged_float_min, nudged_float_max);
404
405 quantized_input = rewriter.create<SubOp>(op.getLoc(), input_ty,
406 quantized_input, nudged_float_min);
407
408 quantized_input = rewriter.create<MulOp>(op.getLoc(), input_ty,
409 quantized_input, float_to_quant);
410
411 // Round the quantized input to nearest, breaking ties toward positive (add 0.5, then floor).
412 auto half_val = rewriter.create<ConstOp>(
413 op.getLoc(),
414 DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
415
416 quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
417 quantized_input, half_val);
418
419 quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
420
421 // Convert back into floating point space.
422 Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
423 quantized_input, quant_to_float);
424
425 output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
426 nudged_float_min);
427
428 rewriter.replaceOp(op, {output});
429 return success();
430 }
431 };
432
433 // Lowers InvertPermutation op to TensorScatterUpdate op.
434 //
435 // Example:
436 //
437 // %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
438 // "tf.InvertPermutation"(%x) : (tensor<5xi32>) -> tensor<5xi32>
439 //
440 // is lowered to
441 //
442 // %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
443 // %start = "tf.Const"() {value = dense<0> : tensor<i32>}
444 // %limit = "tf.Const"() {value = dense<5> : tensor<i32>}
445 // %delta = "tf.Const"() {value = dense<1> : tensor<i32>}
446 // %updates = "tf.Range"(%start, %limit, %delta) :
447 // (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xi32>
448 // %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>}
449 // %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32, tensor<2xi32) ->
450 // tensor<5x1xi32>
451 // "tf.TensorScatterUpdate"(%x, %indices, %updates) :
452 // (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
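//
// Scattering the sequence [0, 1, ..., n-1] into the positions given by %x
// yields the inverse permutation, i.e. output[x[i]] = i.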
453 //
454 class LowerInvertPermutationOp : public RewritePattern {
455 public:
456 explicit LowerInvertPermutationOp(MLIRContext *context)
457 : RewritePattern(
458 InvertPermutationOp::getOperationName(),
459 {ConstOp::getOperationName(), RangeOp::getOperationName(),
460 ReshapeOp::getOperationName(),
461 TensorScatterUpdateOp::getOperationName()},
462 1, context) {}
463
464 LogicalResult matchAndRewrite(Operation *src_op,
465 PatternRewriter &rewriter) const override {
466 auto op = cast<InvertPermutationOp>(src_op);
467
468 Location loc = op.getLoc();
469 auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
470 // x input must have static shape.
471 if (!x_type || !x_type.hasStaticShape()) {
472 return failure();
473 }
474 Type int_type = x_type.getElementType(); // Could be i32 or i64.
475
476 auto result_type = x_type;
477 auto start = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 0));
478 Value limit = rewriter.create<ConstOp>(
479 loc, GetScalarOfType(int_type, x_type.getShape()[0]));
480 auto delta = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 1));
481 // Construct a sequence of numbers [0, 1, ... len(x)-1].
482 auto updates =
483 rewriter.create<RangeOp>(loc, result_type, start, limit, delta);
484
485 auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32));
486 auto shape = rewriter.create<ConstOp>(
487 loc, DenseElementsAttr::get(
488 shape_type, {static_cast<int>(x_type.getDimSize(0)), 1}));
489 auto indices = rewriter.create<ReshapeOp>(loc, op.x(), shape);
490
491 rewriter.replaceOpWithNewOp<TensorScatterUpdateOp>(op, result_type, op.x(),
492 indices, updates);
493 return success();
494 }
495 };
496
497 // Approximates lgamma using Lanczos' approximation from
498 // "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical
499 // Analysis series B. Vol. 1:
500 // lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
501 // t(z) = z + kLanczosGamma + 1/2
502 // A(z) = kBaseLanczosCoeff
503 //   + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
504 //
505 // Coefficients for the Lanczos approximation of the gamma function. The
506 // coefficients are uniquely determined by the choice of g and n
507 // (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below
508 // correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were
509 // evaluated and [7, 9] seemed to be the least sensitive to the quality of the
510 // log function. In particular, [5, 7] is the only choice where -1.5e-5 <=
511 // lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
512 static constexpr double kLanczosGamma = 7; // aka g
513 static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
514 static constexpr std::array<double, 8> kLanczosCoefficients = {
515 676.520368121885098567009190444019, -1259.13921672240287047156078755283,
516 771.3234287776530788486528258894, -176.61502916214059906584551354,
517 12.507343278686904814458936853, -0.13857109526572011689554707,
518 9.984369578019570859563e-6, 1.50563273514931155834e-7};
519
520 class LowerLgammaOp : public RewritePattern {
521 public:
522 explicit LowerLgammaOp(MLIRContext *context)
523 : RewritePattern(LgammaOp::getOperationName(),
524 {
525 CastOp::getOperationName(),
526 ConstOp::getOperationName(),
527 NegOp::getOperationName(),
528 SubOp::getOperationName(),
529 SelectV2Op::getOperationName(),
530 LessOp::getOperationName(),
531 AddV2Op::getOperationName(),
532 DivOp::getOperationName(),
533 SubOp::getOperationName(),
534 LogOp::getOperationName(),
535 Log1pOp::getOperationName(),
536 IsInfOp::getOperationName(),
537 MulOp::getOperationName(),
538 FloorOp::getOperationName(),
539 AbsOp::getOperationName(),
540 GreaterOp::getOperationName(),
541 SinOp::getOperationName(),
542 IsFiniteOp::getOperationName(),
543 },
544 1, context) {}
545
546 LogicalResult matchAndRewrite(Operation *src_op,
547 PatternRewriter &rewriter) const override {
548 auto op = cast<LgammaOp>(src_op);
549
550 Location loc = op.getLoc();
551 Value input = op.x();
552 TensorType original_tensor_type = op.x().getType().cast<TensorType>();
553
554 // The approximation is not precise enough for float16. Do the computation
555 // in float32 for that case.
556 TensorType tensor_type = original_tensor_type;
557 FloatType float_type = tensor_type.getElementType().cast<FloatType>();
558 bool needs_cast = float_type.getWidth() < 32;
559 if (needs_cast) {
560 MLIRContext *context = rewriter.getContext();
561 float_type = FloatType::getF32(context);
562 if (original_tensor_type.hasRank()) {
563 tensor_type =
564 RankedTensorType::get(original_tensor_type.getShape(), float_type);
565 } else {
566 tensor_type = UnrankedTensorType::get(float_type);
567 }
568 input = rewriter.create<CastOp>(loc, tensor_type, input);
569 }
570
571 // Helper lambda function for creating a ConstOp for a tensor filled with
572 // the given constant float value.
573 auto create_const_op = [&rewriter, loc, tensor_type,
574 float_type](double value) {
575 return rewriter.create<ConstOp>(
576 loc, DenseElementsAttr::get(tensor_type,
577 FloatAttr::get(float_type, value)));
578 };
579
580 Value one_half = create_const_op(0.5);
581 Value one = create_const_op(1.0);
582 Value infinity = create_const_op(std::numeric_limits<double>::infinity());
583 Value pi = create_const_op(M_PI);
584 Value log_pi = create_const_op(std::log(M_PI));
585 Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2);
586 Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5);
587 Value log_lanczos_gamma_plus_one_half =
588 create_const_op(std::log(kLanczosGamma + 0.5));
589 Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff);
590
591 Value minus_input = rewriter.create<NegOp>(loc, input);
592 Value input_minus_one = rewriter.create<SubOp>(loc, input, one);
593
594 // If the input is less than 0.5 use Euler's reflection formula:
595 // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
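    // Taking logs of the absolute values, this becomes
    //   lgamma(x) = log(pi) - log(|sin(pi * x)|) - lgamma(1 - x),
    // which is what the reflection branch further below evaluates.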
596 Value need_to_reflect = rewriter.create<LessOp>(loc, input, one_half);
597 Type tensor_bool_type = need_to_reflect.getType();
598 Value z = rewriter.create<SelectV2Op>(loc, need_to_reflect, minus_input,
599 input_minus_one);
600
601 Value x = base_lanczos_coeff;
602 for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
603 Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]);
604 Value index = create_const_op(static_cast<double>(i));
605 Value z_plus_index = rewriter.create<AddV2Op>(loc, z, index);
606 Value z_plus_index_plus_one =
607 rewriter.create<AddV2Op>(loc, z_plus_index, one);
608 Value incr = rewriter.create<DivOp>(loc, lanczos_coefficient,
609 z_plus_index_plus_one);
610 x = rewriter.create<AddV2Op>(loc, x, incr);
611 }
612
613 // To improve accuracy on platforms with less-precise log implementations,
614 // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
615 // the device.
616 // log(t) = log(kLanczosGamma + 0.5 + z)
617 // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
618 Value t = rewriter.create<AddV2Op>(loc, lanczos_gamma_plus_one_half, z);
619 Value z_div_lanczos_gamma_plus_one_half =
620 rewriter.create<DivOp>(loc, z, lanczos_gamma_plus_one_half);
621 Value log1p_z_div_lanczos_gamma_plus_one_half =
622 rewriter.create<Log1pOp>(loc, z_div_lanczos_gamma_plus_one_half);
623 Value log_t =
624 rewriter.create<AddV2Op>(loc, log_lanczos_gamma_plus_one_half,
625 log1p_z_div_lanczos_gamma_plus_one_half);
626
627 // Compute the final result (modulo reflection). t(z) may be large, and we
628 // need to be careful not to overflow to infinity in the first term of
629 //
630 // (z + 1/2) * log(t(z)) - t(z).
631 //
632 // Therefore we compute this as
633 //
634 // (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
635 //
636 // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
637 Value t_div_log_t = rewriter.create<DivOp>(loc, t, log_t);
638 Value one_half_minus_t_div_log_t =
639 rewriter.create<SubOp>(loc, one_half, t_div_log_t);
640 Value z_plus_one_half_minus_t_div_log_t =
641 rewriter.create<AddV2Op>(loc, z, one_half_minus_t_div_log_t);
642 Value z_plus_one_half_minus_t_div_log_t_mul_log_t =
643 rewriter.create<MulOp>(loc, z_plus_one_half_minus_t_div_log_t, log_t);
644 Value log_x = rewriter.create<LogOp>(loc, x);
645 Value log_y_rhs = rewriter.create<AddV2Op>(
646 loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x);
647 Value log_y = rewriter.create<AddV2Op>(loc, log_sqrt_two_pi, log_y_rhs);
648
649 // Compute the reflected value, used when x < 0.5:
650 //
651 // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
652 //
653 // (The abs is because lgamma is the log of the absolute value of the gamma
654 // function.)
655 //
656 // We have to be careful when computing the final term above. gamma(x) goes
657 // to +/-inf at every integer x < 0, and this is controlled by the
658 // sin(pi * x) term. The slope is large, so precision is particularly
659 // important.
660 //
661 // Because abs(sin(pi * x)) has period 1, we can equivalently use
662 // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
663 // is more numerically accurate: It doesn't overflow to inf like pi * x can,
664 // and if x is an integer, it evaluates to 0 exactly, which is significant
665 // because we then take the log of this value, and log(0) is inf.
666 //
667 // We don't have a frac(x) primitive in XLA and computing it is tricky, but
668 // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
669 // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
670 //
671 // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
672 // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
673 // [0, 1] is symmetric across the line Y=0.5.
674 Value abs_input = rewriter.create<AbsOp>(loc, input);
675 Value abs_input_floor = rewriter.create<FloorOp>(loc, abs_input);
676 Value abs_frac_input =
677 rewriter.create<SubOp>(loc, abs_input, abs_input_floor);
678
679 // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
680 // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
681 Value one_minus_abs_frac_input =
682 rewriter.create<SubOp>(loc, one, abs_frac_input);
683 Value abs_frac_input_gt_one_half =
684 rewriter.create<GreaterOp>(loc, abs_frac_input, one_half);
685 Value reduced_frac_input =
686 rewriter.create<SelectV2Op>(loc, abs_frac_input_gt_one_half,
687 one_minus_abs_frac_input, abs_frac_input);
688 Value pi_mul_reduced_frac_input =
689 rewriter.create<MulOp>(loc, pi, reduced_frac_input);
690 Value sin_pi_mul_reduced_frac_input =
691 rewriter.create<SinOp>(loc, pi_mul_reduced_frac_input);
692 Value reflection_denom =
693 rewriter.create<LogOp>(loc, sin_pi_mul_reduced_frac_input);
694
695 // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
696 // then it "wins" and the result is +/-inf.
697 Value is_finite =
698 rewriter.create<IsFiniteOp>(loc, tensor_bool_type, reflection_denom);
699 Value neg_reflection_denom = rewriter.create<NegOp>(loc, reflection_denom);
700 Value log_pi_minus_reflection_denom =
701 rewriter.create<SubOp>(loc, log_pi, reflection_denom);
702 Value reflection_if_finite =
703 rewriter.create<SubOp>(loc, log_pi_minus_reflection_denom, log_y);
704 Value reflection = rewriter.create<SelectV2Op>(
705 loc, is_finite, reflection_if_finite, neg_reflection_denom);
706
707 Value result =
708 rewriter.create<SelectV2Op>(loc, need_to_reflect, reflection, log_y);
709
710 // lgamma(+/-inf) = +inf.
711 Value is_inf = rewriter.create<IsInfOp>(loc, tensor_bool_type, input);
712 result = rewriter.create<SelectV2Op>(loc, is_inf, infinity, result);
713
714 if (needs_cast) {
715 result = rewriter.create<CastOp>(loc, original_tensor_type, result);
716 }
717
718 rewriter.replaceOp(op, result);
719 return success();
720 }
721 };
722
723 // Lowers Pack op to ConcatV2 op after changing shape of the inputs with
724 // ExpandDims op.
725 //
726 // Sample result with 2 inputs to pack:
727 //
728 // %axis = "tf.Const"() {value = dense<1> : tensor<i64>}
729 // %inp0 = "tf.ExpandDims"(%operand0, %axis): tensor<2xf32> -> tensor<2x1xf32>
730 // %inp1 = "tf.ExpandDims"(%operand1, %axis): tensor<2xf32> -> tensor<2x1xf32>
731 // %result = "tf.ConcatV2"(%operand0, %operand1, %axis) { N = 2 : i64 }:
732 //
733 class LowerPackOp : public RewritePattern {
734 public:
735 explicit LowerPackOp(MLIRContext *context)
736 : RewritePattern(
737 PackOp::getOperationName(),
738 {ConstOp::getOperationName(), ConcatV2Op::getOperationName(),
739 ExpandDimsOp::getOperationName()},
740 1, context) {}
741
742 LogicalResult matchAndRewrite(Operation *src_op,
743 PatternRewriter &rewriter) const override {
744 auto op = cast<PackOp>(src_op);
745
746 Location loc = op.getLoc();
747 auto axis_value = rewriter.create<ConstOp>(
748 loc,
749 DenseElementsAttr::get(
750 RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis()));
751 int64_t axis = op.axis();
752
753 Type prev_input_ty, inferred_ty;
754 SmallVector<Value, 4> expanded_inputs;
755 expanded_inputs.reserve(op.N());
756 for (Value input : op.values()) {
757 // If input type is different than the previous input type, infer the
758 // output type. Otherwise, use the already inferred output type from the
759 // previous iteration.
760 Type input_ty = input.getType();
761 if (input_ty != prev_input_ty) {
762 inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
763 prev_input_ty = input_ty;
764 }
765 expanded_inputs.push_back(
766 rewriter.create<ExpandDimsOp>(loc, inferred_ty, input, axis_value));
767 }
768
769 rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), expanded_inputs,
770 axis_value);
771 return success();
772 }
773 };
774
775 // Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))).
776 //
777 // Before rewrite:
778 // output = SpaceToBatchND(input, block_shape, paddings)
779 // Let:
780 // [batch] + spatial_shape + remaining_shape = input.shape
781 // M = spatial_shape.rank
782 // After rewrite:
783 // padded = zero-pad input with paddings
784 //     The spatial_shape component of input.shape is padded with paddings[*, 0]
785 //     before each dimension and with paddings[*, 1] after each dimension.
786 // reshaped = reshape padded to:
787 // [batch]
788 // + [padded.shape[1]/block_shape[0], block_shape[0], ...,
789 // padded.shape[M]/block_shape[M-1], block_shape[M-1]]
790 // + remaining_shape
791 // permuted = transpose reshaped to:
792 // block_shape
793 // + [batch]
794 // + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
795 // + remaining_shape
796 // result = reshape permuted to:
797 // [batch * product(block_shape)]
798 // + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
799 // + remaining_shape
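//
// For example (illustrative): an input of shape [1, 4, 4, 1] with
// block_shape = [2, 2] and zero paddings is padded to [1, 4, 4, 1], reshaped
// to [1, 2, 2, 2, 2, 1], transposed to [2, 2, 1, 2, 2, 1], and finally
// reshaped to the output shape [4, 2, 2, 1].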
800 class LowerSpaceToBatchNDOp : public RewritePattern {
801 public:
802 explicit LowerSpaceToBatchNDOp(MLIRContext *context)
803 : RewritePattern(SpaceToBatchNDOp::getOperationName(),
804 {
805 CastOp::getOperationName(),
806 ConstOp::getOperationName(),
807 ConcatV2Op::getOperationName(),
808 AddV2Op::getOperationName(),
809 PadOp::getOperationName(),
810 SplitOp::getOperationName(),
811 UnpackOp::getOperationName(),
812 DivOp::getOperationName(),
813 MulOp::getOperationName(),
814 ReshapeOp::getOperationName(),
815 TransposeOp::getOperationName(),
816 },
817 1, context) {}
818
819 LogicalResult matchAndRewrite(Operation *src_op,
820 PatternRewriter &rewriter) const override {
821 auto op = cast<SpaceToBatchNDOp>(src_op);
822
823 Location loc = op.getLoc();
824 auto input_type = op.input().getType().cast<TensorType>();
825 if (!input_type.hasStaticShape()) {
826 return failure();
827 }
828 ArrayRef<int64_t> input_shape = input_type.getShape();
829 auto block_shape_type = op.block_shape().getType().cast<TensorType>();
830 if (!block_shape_type.hasStaticShape()) {
831 return failure();
832 }
833 auto paddings_type = op.paddings().getType().cast<ShapedType>();
834 if (!paddings_type.hasRank()) {
835 return failure();
836 }
837
838 int64_t input_rank = input_type.getRank();
839 int64_t block_rank = block_shape_type.getNumElements();
840 int64_t remaining_rank = input_rank - 1 - block_rank;
841 if (remaining_rank < 0) {
842 // TODO(b/157475606): Move this check to ::Verify
843 return failure();
844 }
845
846 auto block_shape_i64_type = RankedTensorType::get(
847 block_shape_type.getShape(), rewriter.getIntegerType(64));
848 auto block_shape_i64 =
849 rewriter.create<CastOp>(loc, block_shape_i64_type, op.block_shape());
850
851 auto paddings_i64_type = RankedTensorType::get(paddings_type.getShape(),
852 rewriter.getIntegerType(64));
853 auto paddings_i64 =
854 rewriter.create<CastOp>(loc, paddings_i64_type, op.paddings());
855
856 auto pad00 = rewriter.create<ConstOp>(
857 loc, DenseElementsAttr::get<int64_t>(
858 RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)),
859 {0, 0}));
860 SmallVector<Value, 4> full_paddings_list{pad00, paddings_i64};
861 full_paddings_list.append(remaining_rank, pad00);
862 auto full_paddings_type =
863 RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64));
864 auto zero_i64 = rewriter.create<ConstOp>(
865 loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
866 // Extends paddings to all dimensions of input by adding 0s to non-block
867 // dimensions.
868 auto full_paddings = rewriter.create<ConcatV2Op>(
869 loc, full_paddings_type, full_paddings_list, zero_i64);
870
871 // Compute the result type here instead of using shape inference because the
872 // full_paddings won't be available as a constant for shape inference.
873 ElementsAttr block_shape;
874 ElementsAttr paddings;
875 llvm::SmallVector<int64_t, 4> block_shape_ints;
876 auto padded_shape = llvm::to_vector<4>(input_shape);
877 if (matchPattern(op.block_shape(), m_Constant(&block_shape)) &&
878 matchPattern(op.paddings(), m_Constant(&paddings))) {
879 for (uint64_t i = 0; i < block_rank; i++) {
880 int64_t paddings_sum =
881 paddings.getValue({i, 0}).cast<IntegerAttr>().getInt() +
882 paddings.getValue({i, 1}).cast<IntegerAttr>().getInt();
883 int64_t block_shape_i =
884 block_shape.getValue({i}).cast<IntegerAttr>().getInt();
885 padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]);
886 block_shape_ints.push_back(block_shape_i);
887 }
888 } else {
889 for (int i = 0; i < block_rank; i++) {
890 padded_shape[i + 1] = ShapedType::kDynamicSize;
891 }
892 block_shape_ints.resize(block_shape_type.getNumElements(), -1);
893 }
894
895 auto padded_type =
896 RankedTensorType::get(padded_shape, input_type.getElementType());
897 // padded = pad(input, full_paddings)
898 auto padded =
899 rewriter.create<PadOp>(loc, padded_type, op.input(), full_paddings);
900
901 auto paddings_sum_type =
902 RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
903 // paddings_sum = paddings[*,0] + paddings[*,1]
904 auto paddings_split = rewriter.create<UnpackOp>(
905 loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
906 rewriter.getI64IntegerAttr(1));
907 auto paddings_sum = rewriter.create<AddV2Op>(
908 loc, paddings_split.getResult(0), paddings_split.getResult(1));
909
910 auto input_shape_tensor = rewriter.create<ConstOp>(
911 loc,
912 DenseElementsAttr::get(
913 RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)),
914 input_shape));
915
916 // padded_shape_tensor is the shape of padded.
917 auto padded_shape_tensor =
918 rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
919
920 auto zero_i32 = rewriter.create<ConstOp>(
921 loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
922 SmallVector<Type, 4> padded_shape_splits_types(
923 input_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
924 SmallVector<Value, 4> padded_shape_splits(
925 rewriter
926 .create<SplitOp>(loc, padded_shape_splits_types, zero_i32,
927 padded_shape_tensor)
928 .output());
929
930 SmallVector<Type, 4> block_shape_splits_types(
931 block_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
932 SmallVector<Value, 4> block_shape_splits(
933 rewriter
934 .create<SplitOp>(loc, block_shape_splits_types, zero_i32,
935 block_shape_i64)
936 .output());
937
938 SmallVector<int64_t, 4> outer_shape_ints;
939 SmallVector<Value, 4> outer_shape_vals;
940 for (int64_t i = 0; i < block_rank; ++i) {
941 // TODO(b/157475606): Insert tf.Assert that the following division has
942 // remainder 0.
943 outer_shape_vals.push_back(rewriter.create<DivOp>(
944 loc, padded_shape_splits[1 + i], block_shape_splits[i]));
945
946 auto padded_shape_i = padded_shape[1 + i];
947 auto block_shape_ints_i = block_shape_ints[i];
948
949 // Compute the outer_shape constant values to infer the reshape.
950 if (padded_shape_i == -1 || block_shape_ints_i == -1) {
951 outer_shape_ints.push_back(-1);
952 } else {
953 outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i);
954 }
955 }
956
957 SmallVector<Value, 6> reshaped_shape_vals{padded_shape_splits[0]};
958 SmallVector<int64_t, 6> reshaped_shape_ints{padded_shape[0]};
959 for (int64_t i = 0; i < block_rank; ++i) {
960 reshaped_shape_vals.push_back(outer_shape_vals[i]);
961 reshaped_shape_vals.push_back(block_shape_splits[i]);
962
963 reshaped_shape_ints.push_back(outer_shape_ints[i]);
964 reshaped_shape_ints.push_back(block_shape_ints[i]);
965 }
966 for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
967 reshaped_shape_vals.push_back(padded_shape_splits[i]);
968 reshaped_shape_ints.push_back(padded_shape[i]);
969 }
970 auto reshaped_shape = ValuesToRank1(
971 rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);
972
973 auto reshaped = rewriter.create<ReshapeOp>(
974 loc,
975 RankedTensorType::get(reshaped_shape_ints, input_type.getElementType()),
976 padded, reshaped_shape);
977
978 SmallVector<int64_t, 6> permutation_vals;
979 for (int64_t i = 0; i < block_rank; ++i) {
980 permutation_vals.push_back(2 + 2 * i);
981 }
982 permutation_vals.push_back(0);
983 for (int64_t i = 0; i < block_rank; ++i) {
984 permutation_vals.push_back(1 + 2 * i);
985 }
986 for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
987 permutation_vals.push_back(block_rank + i);
988 }
989 auto permutation = rewriter.create<ConstOp>(
990 loc, GetI64ElementsAttr(permutation_vals, &rewriter));
991
992 auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
993 auto output_batch = padded_shape_splits[0];
994 for (int64_t i = 0; i < block_rank; ++i) {
995 output_batch =
996 rewriter.create<MulOp>(loc, output_batch, block_shape_splits[i]);
997 }
998 SmallVector<Value, 4> output_shape_vals{output_batch};
999 for (int64_t i = 0; i < block_rank; ++i) {
1000 output_shape_vals.push_back(outer_shape_vals[i]);
1001 }
1002 for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1003 output_shape_vals.push_back(padded_shape_splits[i]);
1004 }
1005 auto output_shape = ValuesToRank1(
1006 rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);
1007
1008 // Sometimes the result type is more specific than what the reshape builder
1009 // can infer.
1010 auto result_type = op.getResult().getType();
1011 rewriter.replaceOpWithNewOp<ReshapeOp>(op, result_type, permuted,
1012 output_shape);
1013
1014 return success();
1015 }
1016 };
1017
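// Lowers BatchToSpaceND by reducing to
// slice(reshape(transpose(reshape(input)))), i.e. the inverse of
// SpaceToBatchND. Requires a statically shaped input and constant block_shape
// and crops operands.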
1018 class LowerBatchToSpaceND : public RewritePattern {
1019 public:
1020 explicit LowerBatchToSpaceND(MLIRContext *context)
1021 : RewritePattern(BatchToSpaceNDOp::getOperationName(),
1022 {
1023 ConstOp::getOperationName(),
1024 ReshapeOp::getOperationName(),
1025 SliceOp::getOperationName(),
1026 TransposeOp::getOperationName(),
1027 },
1028 1, context) {}
1029
1030 LogicalResult matchAndRewrite(Operation *src_op,
1031 PatternRewriter &rewriter) const override {
1032 auto op = cast<BatchToSpaceNDOp>(src_op);
1033 auto input = op.input();
1034 auto input_ty = input.getType().cast<ShapedType>();
1035 auto element_ty = input_ty.getElementType();
1036 if (!input_ty.hasStaticShape()) {
1037 return failure();
1038 }
1039
1040 const int input_rank = input_ty.getRank();
1041 auto input_shape = input_ty.getShape();
1042
1043 DenseIntElementsAttr block_shape;
1044 DenseIntElementsAttr crops;
1045 if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) ||
1046 !matchPattern(op.crops(), m_Constant(&crops))) {
1047 return failure();
1048 }
1049
1050 auto block_shape_ty = block_shape.getType();
1051 if (!block_shape_ty.hasRank() || block_shape_ty.getRank() != 1) {
1052 return failure();
1053 }
1054
1055 const int block_rank = block_shape_ty.getShape().front();
1056 auto remainder_shape = input_shape.drop_front(1 + block_rank);
1057
1058 const int64_t batch_size = input_shape[0];
1059
1060 // Compute the product of the block_shape values.
1061 int64_t block_num_elems = 1;
1062
1063 for (auto val : block_shape.getIntValues()) {
1064 block_num_elems *= val.getSExtValue();
1065 }
1066
1067 if (block_num_elems <= 0) {
1068 op.emitOpError()
1069 << "The product of the block dimensions must be positive";
1070 return failure();
1071 }
1072
1073 // 1. Reshape `input` to `reshaped` of shape:
1074 // [block_shape[0], ..., block_shape[M-1],
1075 // batch / prod(block_shape),
1076 // input_shape[1], ..., input_shape[N-1]]
1077 std::vector<int64_t> reshaped_shape;
1078 for (auto val : block_shape) {
1079 reshaped_shape.push_back(val.getSExtValue());
1080 }
1081 reshaped_shape.resize(input_rank + block_rank);
1082
1083 reshaped_shape[block_rank] = batch_size / block_num_elems;
1084 std::copy(input_shape.begin() + 1, input_shape.end(),
1085 reshaped_shape.begin() + block_rank + 1);
1086
1087 auto reshaped = rewriter.create<TF::ReshapeOp>(
1088 op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input,
1089 rewriter.create<ConstOp>(op.getLoc(),
1090 rewriter.getI64TensorAttr(reshaped_shape)));
1091
1092 // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
1093 // [batch / prod(block_shape),
1094 //
1095 // input_shape[1], block_shape[0],
1096 // ...,
1097 // input_shape[M], block_shape[M-1],
1098 //
1099 // input_shape[M+1], ..., input_shape[N-1]]
1100 std::vector<int64_t> permutation(reshaped_shape.size());
1101 permutation[0] = block_rank;
1102 for (int i = 0; i < block_rank; ++i) {
1103 permutation[1 + 2 * i] = block_rank + 1 + i;
1104 permutation[1 + 2 * i + 1] = i;
1105 }
1106 std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1107 1 + block_rank * 2);
1108
1109 std::vector<int64_t> transpose_shape(permutation.size());
1110 for (auto it : llvm::enumerate(permutation)) {
1111 transpose_shape[it.index()] = reshaped_shape[it.value()];
1112 }
1113
1114 auto permuted = rewriter.create<TF::TransposeOp>(
1115 op.getLoc(), RankedTensorType::get(transpose_shape, element_ty),
1116 reshaped,
1117 rewriter.create<ConstOp>(op.getLoc(),
1118 rewriter.getI64TensorAttr(permutation)));
1119
1120 // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
1121 // [batch / prod(block_shape),
1122 //
1123 // input_shape[1] * block_shape[0],
1124 // ...,
1125 // input_shape[M] * block_shape[M-1],
1126 //
1127 // input_shape[M+1],
1128 // ...,
1129 // input_shape[N-1]]
1130 std::vector<int64_t> reshaped_permuted_shape(input_rank);
1131 auto block_shape_values = llvm::to_vector<4>(block_shape.getIntValues());
1132 reshaped_permuted_shape[0] = batch_size / block_num_elems;
1133 for (int i = 0; i < block_rank; ++i) {
1134 reshaped_permuted_shape[1 + i] =
1135 block_shape_values[i].getSExtValue() * input_shape[1 + i];
1136 }
1137 std::copy(remainder_shape.begin(), remainder_shape.end(),
1138 reshaped_permuted_shape.begin() + 1 + block_rank);
1139
1140 auto reshaped_permuted = rewriter.create<TF::ReshapeOp>(
1141 op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty),
1142 permuted,
1143 rewriter.create<ConstOp>(
1144 op.getLoc(), rewriter.getI64TensorAttr(reshaped_permuted_shape)));
1145
1146 // 4. Crop the start and end of dimensions `[1, ..., M]` of
1147 // `reshaped_permuted` according to `crops` to produce the output of
1148 // shape:
1149 // [batch / prod(block_shape),
1150 //
1151 // input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
1152 // ...,
1153 // input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
1154 //
1155 // input_shape[M+1], ..., input_shape[N-1]]
1156 std::vector<int64_t> start_indices(input_rank, 0);
1157 std::vector<int64_t> slice_sizes = reshaped_permuted_shape;
1158 std::vector<int64_t> strides(input_rank, 1);
1159 auto crop_values = llvm::to_vector<4>(crops.getIntValues());
1160 for (int i = 0; i < block_rank; ++i) {
1161 int64_t crop_start = crop_values[i * 2].getSExtValue();
1162 int64_t crop_end = crop_values[i * 2 + 1].getSExtValue();
1163
1164 if (crop_start < 0 || crop_end < 0) {
1165 op.emitOpError() << "Crops must be non-negative";
1166 return failure();
1167 }
1168
1169 start_indices[i + 1] = crop_start;
1170 slice_sizes[i + 1] -= crop_start + crop_end;
1171
1172 if (slice_sizes[i + 1] < 0) {
1173 op.emitOpError() << "Cropped size must be non-negative: start: "
1174 << crop_start << " end: " << crop_end << " size "
1175 << reshaped_permuted_shape[1 + i];
1176 }
1177 }
1178
1179 rewriter.replaceOpWithNewOp<TF::SliceOp>(
1180 op, RankedTensorType::get(slice_sizes, element_ty), reshaped_permuted,
1181 rewriter.create<ConstOp>(op.getLoc(),
1182 rewriter.getI64TensorAttr(start_indices)),
1183 rewriter.create<ConstOp>(op.getLoc(),
1184 rewriter.getI64TensorAttr(slice_sizes)));
1185 return success();
1186 }
1187 };
1188
1189 // Lowers `SparseMatMulOp` to `MatMulOp`, ignoring the sparseness hints,
1190 // since we currently don't have an implementation that can use this
1191 // information. Adds appropriate casts where necessary to align element types
1192 // of operands and result for `MatMulOp`.
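// For example, a tf.SparseMatMul with bf16 operands and an f32 result is
// rewritten as tf.Cast ops (bf16 to f32) on the operands followed by a
// tf.MatMul.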
1193 class LowerSparseMatMulOp : public RewritePattern {
1194 public:
1195 explicit LowerSparseMatMulOp(MLIRContext *context)
1196 : RewritePattern(
1197 SparseMatMulOp::getOperationName(),
1198 {CastOp::getOperationName(), MatMulOp::getOperationName()}, 1,
1199 context) {}
1200
1201 LogicalResult matchAndRewrite(Operation *src_op,
1202 PatternRewriter &rewriter) const override {
1203 auto op = cast<SparseMatMulOp>(src_op);
1204
1205 // Result type must be f32 for applying the pattern (currently this is
1206 // required by the op anyway but this might change).
1207 if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
1208 return failure();
1209 }
1210 MLIRContext *context = rewriter.getContext();
1211 llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
1212 for (Value &operand : operands) {
1213 TensorType tensor_type = operand.getType().cast<TensorType>();
1214 Type element_type = tensor_type.getElementType();
1215 if (element_type.isF32()) continue;
1216 // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
1217 // must be bf16 here.
1218 assert(element_type.isBF16());
1219 Type tensor_type_f32;
1220 if (tensor_type.hasRank()) {
1221 tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
1222 FloatType::getF32(context));
1223 } else {
1224 tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
1225 }
1226 // Add cast to f32 to conform with element type of result.
1227 operand = rewriter.create<CastOp>(op.getLoc(), tensor_type_f32, operand);
1228 }
1229 Value result = rewriter.create<MatMulOp>(
1230 op.getLoc(), op.product().getType(), operands[0], operands[1],
1231 op.transpose_a(), op.transpose_b());
1232
1233 rewriter.replaceOp(op, {result});
1234 return success();
1235 }
1236 };
1237
1238 // Lowers _UnaryOpsComposition op as a series of original TensorFlow ops that
1239 // were fused together.
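// For example, op_names = ["Abs", "Sqrt"] expands to a tf.Abs whose result
// feeds a tf.Sqrt, each with the composition's result type.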
1240 class Lower_UnaryOpsComposition
1241 : public OpRewritePattern<_UnaryOpsCompositionOp> {
1242 public:
1243 using OpRewritePattern<_UnaryOpsCompositionOp>::OpRewritePattern;
1244
1245 LogicalResult matchAndRewrite(_UnaryOpsCompositionOp op,
1246 PatternRewriter &rewriter) const override {
1247 Value result = op.x();
1248 for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
1249 std::string full_name = "tf." + op_name.str();
1250 // All ops in the sequences have the same result type as the original
1251 // result type.
1252 OperationState state(op.getLoc(), full_name, /*operands=*/{result},
1253 /*types=*/{op.getType()}, /*attributes=*/{});
1254 Operation *op = rewriter.createOperation(state);
1255 result = op->getResult(0);
1256 }
1257 rewriter.replaceOp(op, {result});
1258 return success();
1259 }
1260 };
1261
1262 // Lowers ResizeNearestNeighbor to an indices computation with a gather along
1263 // the combined spatial dimensions. Generating the indices along the
1264 // width/height index could be used to gather along each of W and H dimension
1265 // of the input image array. To reduce to a single gather, these indices are
1266 // combined, so a single gather can be performed along the combined spatial
1267 // dimensions.
1268 //
1269 // Images must take the shape [b, h, w, c] and size is a rank-1 length-2 tensor
1270 // containing the height and width values for the output tensor. This lowering
1271 // should work with a dynamic images array.
1272 //
1273 // For example, a scaling with image shape [1, 3, 3, 1] to [2, 2] and unaligned
1274 // corners would generate a [0, 1] lookup along both the x and y direction.
1275 // Then when combined to form the 1-D spatial index the values would be
1276 // [0, 1, 3, 4] which would gather along the reshaped image tensor of shape
1277 // [1, 9, 1], reshaped to the final [1, 3, 3, 1].
1278 class LowerResizeNearestNeighbor : public RewritePattern {
1279 public:
1280 explicit LowerResizeNearestNeighbor(MLIRContext *context)
1281 : RewritePattern(ResizeNearestNeighborOp::getOperationName(),
1282 {
1283 BroadcastToOp::getOperationName(),
1284 ConstOp::getOperationName(),
1285 DivOp::getOperationName(),
1286 PackOp::getOperationName(),
1287 RangeOp::getOperationName(),
1288 ReshapeOp::getOperationName(),
1289 ShapeOp::getOperationName(),
1290 SplitOp::getOperationName(),
1291 TransposeOp::getOperationName(),
1292 },
1293 1, context) {}
1294
1295 LogicalResult matchAndRewrite(Operation *src_op,
1296 PatternRewriter &rewriter) const override {
1297 auto op = cast<ResizeNearestNeighborOp>(src_op);
1298 auto loc = op.getLoc();
1299 auto result_ty = op.getType().cast<ShapedType>();
1300
1301 auto input = op.images();
1302 auto input_ty = input.getType().cast<ShapedType>();
1303 auto input_element_ty = input_ty.getElementType();
1304 auto out_size = op.size();
1305 auto out_size_ty = out_size.getType().cast<ShapedType>();
1306 auto out_size_element_ty = out_size_ty.getElementType();
1307
1308 // Input should be rank 4.
1309 if (!input_ty.hasRank() || input_ty.getRank() != 4) {
1310 return failure();
1311 }
1312
1313 // Check that out_size is rank-1, length-2. Otherwise the size is not legal.
1314 if (!out_size_ty.hasRank() || out_size_ty.getRank() != 1 ||
1315 out_size_ty.getShape()[0] != 2) {
1316 return failure();
1317 }
1318
1319 // Extract the output width / height dim size.
1320 int out_height_constant = -1;
1321 int out_width_constant = -1;
1322 DenseIntElementsAttr out_size_cst;
1323 if (matchPattern(out_size, m_Constant(&out_size_cst))) {
1324 llvm::SmallVector<int64_t, 2> cst_size;
1325 for (auto val : out_size_cst.getIntValues()) {
1326 cst_size.push_back(val.getSExtValue());
1327 }
1328
1329 out_height_constant = cst_size[0];
1330 out_width_constant = cst_size[1];
1331
1332 if (out_height_constant < 0 || out_width_constant < 0) return failure();
1333 }
1334
1335 int out_spatial_cst = out_height_constant < 0 || out_width_constant < 0
1336 ? -1
1337 : out_height_constant * out_width_constant;
1338
1339 // Input rank should be 4. Might be able to drop this requirement entirely
1340 // as it's an input requirement.
1341 if (!input_ty.hasRank() || input_ty.getRank() != 4) {
1342 return failure();
1343 }
1344
1345 int batch_cst = input_ty.getShape()[0];
1346 int channels_cst = input_ty.getShape()[3];
1347
1348 int in_y_cst = input_ty.getShape()[1];
1349 int in_x_cst = input_ty.getShape()[2];
1350 int in_spatial_cst =
1351 in_y_cst < 0 || in_x_cst < 0 ? -1 : in_y_cst * in_x_cst;
1352
1353 // TODO(suderman): Add support for these optional parameters.
1354 if (op.align_corners() == true || op.half_pixel_centers() == true) {
1355 return failure();
1356 }
1357
1358 auto one =
1359 rewriter.create<ConstOp>(loc, GetScalarOfType(out_size_element_ty, 1));
1360
1361 // Extract the image shape.
1362 Value input_shape = rewriter.create<ShapeOp>(
1363 loc, RankedTensorType::get({4}, rewriter.getI64Type()), input);
1364 input_shape = rewriter.create<CastOp>(
1365 loc, RankedTensorType::get({4}, out_size_element_ty), input_shape);
1366
1367 auto scalar_dim_ty = RankedTensorType::get({}, out_size_element_ty);
1368 auto split_image_shape = rewriter.create<UnpackOp>(
1369 loc,
1370 TypeRange({scalar_dim_ty, scalar_dim_ty, scalar_dim_ty, scalar_dim_ty}),
1371 input_shape);
1372
1373 // Extract the separate components from the input shape.
1374 auto batch = split_image_shape.getResult(0);
1375 auto in_y = split_image_shape.getResult(1);
1376 auto in_x = split_image_shape.getResult(2);
1377 auto channels = split_image_shape.getResult(3);
1378
1379 auto in_count = rewriter.create<MulOp>(
1380 loc, RankedTensorType::get({}, out_size_element_ty), in_y, in_x);
1381
1382 // Unpack and separate the out width/height.
1383 auto split_out_size = rewriter.create<UnpackOp>(
1384 loc, TypeRange({scalar_dim_ty, scalar_dim_ty}), out_size);
1385
1386 auto out_y = split_out_size.getResult(0);
1387 auto out_x = split_out_size.getResult(1);
1388
1389 auto out_count = rewriter.create<MulOp>(
1390 loc, RankedTensorType::get({}, out_size_element_ty), out_y, out_x);
1391
1392 // Generate what the final output shape will look like.
1393 auto out_shape = rewriter.create<PackOp>(
1394 loc, RankedTensorType::get({4}, out_size_element_ty),
1395 ValueRange({batch, out_y, out_x, channels}));
1396
1397 // Compute the indices along the vertical dimension.
1398 auto in_y_f32 = rewriter.create<CastOp>(
1399 loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y);
1400 auto out_w_f32 = rewriter.create<CastOp>(
1401 loc, RankedTensorType::get({}, rewriter.getF32Type()), out_y);
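    // Note: despite its name, out_w_f32 holds the output *height* (out_y) cast
    // to f32, and out_h_f32 below holds the output width (out_x); the values
    // are used consistently with their contents.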
1402
1403 Value y_scale = rewriter.create<DivOp>(
1404 loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y_f32,
1405 out_w_f32);
1406
1407 Value zero_f32 = rewriter.create<ConstOp>(
1408 loc, GetScalarOfType(rewriter.getF32Type(), 0.0));
1409 Value one_f32 = rewriter.create<ConstOp>(
1410 loc, GetScalarOfType(rewriter.getF32Type(), 1.0));
1411
1412 Value y_range = rewriter.create<RangeOp>(
1413 loc,
1414 RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
1415 zero_f32, out_w_f32, one_f32);
1416
1417 y_range = rewriter.create<MulOp>(
1418 loc,
1419 RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
1420 y_range, y_scale);
1421
1422 y_range = rewriter.create<CastOp>(
1423 loc, RankedTensorType::get({out_height_constant}, out_size_element_ty),
1424 y_range);
1425
1426 y_range = rewriter.create<ReshapeOp>(
1427 loc,
1428 RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
1429 y_range,
1430 rewriter.create<PackOp>(loc,
1431 RankedTensorType::get({2}, out_size_element_ty),
1432 ValueRange({out_y, one})));
1433
1434 Value y_indices = rewriter.create<MulOp>(
1435 loc,
1436 RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
1437 y_range, in_x);
1438
1439 // Compute the indices for the nearest neighbour lookup across the width
1440 // dim.
1441 auto in_x_f32 = rewriter.create<CastOp>(
1442 loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x);
1443 auto out_h_f32 = rewriter.create<CastOp>(
1444 loc, RankedTensorType::get({}, rewriter.getF32Type()), out_x);
1445
1446 Value x_scale = rewriter.create<DivOp>(
1447 loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x_f32,
1448 out_h_f32);
1449
1450 Value x_range = rewriter.create<RangeOp>(
1451 loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
1452 zero_f32, out_h_f32, one_f32);
1453
1454 x_range = rewriter.create<MulOp>(
1455 loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
1456 x_range, x_scale);
1457
1458 x_range = rewriter.create<CastOp>(
1459 loc, RankedTensorType::get({out_width_constant}, out_size_element_ty),
1460 x_range);
1461
1462 Value x_indices = rewriter.create<ReshapeOp>(
1463 loc,
1464 RankedTensorType::get({1, out_width_constant}, out_size_element_ty),
1465 x_range,
1466 rewriter.create<PackOp>(loc,
1467 RankedTensorType::get({2}, out_size_element_ty),
1468 ValueRange({one, out_x})));
1469
1470 // Generate the combined index array, reshape to be 1-D.
1471 Value indices = rewriter.create<AddV2Op>(
1472 loc,
1473 RankedTensorType::get({out_height_constant, out_width_constant},
1474 out_size_element_ty),
1475 y_indices, x_indices);
1476
1477 indices = rewriter.create<ReshapeOp>(
1478 loc, RankedTensorType::get({out_spatial_cst}, out_size_element_ty),
1479 indices,
1480 rewriter.create<ReshapeOp>(
1481 loc, RankedTensorType::get({1}, out_size_element_ty), out_count,
1482 rewriter.create<ConstOp>(loc, rewriter.getI64TensorAttr({1}))));
1483
1484 // Group the spatial indices and gather along that combined index.
1485 Value input_collapsed_spatial = rewriter.create<ReshapeOp>(
1486 loc,
1487 RankedTensorType::get({batch_cst, in_spatial_cst, channels_cst},
1488 input_element_ty),
1489 input,
1490 rewriter.create<PackOp>(loc,
1491 RankedTensorType::get({3}, out_size_element_ty),
1492 ValueRange({batch, in_count, channels})));
1493
1494 Value gathered_values = rewriter.create<GatherV2Op>(
1495 loc,
1496 RankedTensorType::get({batch_cst, out_spatial_cst, channels_cst},
1497 input_element_ty),
1498 input_collapsed_spatial, indices, /*axis=*/one);
1499
1500 gathered_values =
1501 rewriter.create<ReshapeOp>(loc, result_ty, gathered_values, out_shape);
1502
1503 rewriter.replaceOp(op, gathered_values);
1504 return success();
1505 }
1506 };
1507
1508 } // namespace
1509
1510 void PopulateLoweringTFPatterns(MLIRContext *context,
1511 OwningRewritePatternList *patterns) {
1512 patterns->insert<LowerAddNOp, ConvertFakeQuantWithMinMaxVarsOp,
1513 LowerDynamicStitchOp<DynamicStitchOp>,
1514 LowerDynamicStitchOp<ParallelDynamicStitchOp>,
1515 LowerInvertPermutationOp, LowerLgammaOp, LowerPackOp,
1516 LowerBatchToSpaceND, LowerSpaceToBatchNDOp,
1517 LowerResizeNearestNeighbor, LowerSparseMatMulOp,
1518 Lower_UnaryOpsComposition>(context);
1519 populateWithGenerated(context, *patterns);
1520 }
1521
1522 void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
1523 OwningRewritePatternList *patterns) {
1524 // clang-format off
1525 patterns->insert<
1526 ConvertFakeQuantWithMinMaxVarsOp,
1527 LowerAddNOp,
1528 LowerBatchToSpaceND,
1529 LowerDynamicStitchOp<DynamicStitchOp>,
1530 LowerDynamicStitchOp<ParallelDynamicStitchOp>,
1531 LowerInvertPermutationOp,
1532 LowerPackOp,
1533 LowerResizeNearestNeighbor,
1534 LowerSpaceToBatchNDOp,
1535 LowerSparseMatMulOp,
1536 Lower_UnaryOpsComposition>(context);
1537 // clang-format on
1538
1539 // Populate the relevant generated patterns.
1540 // clang-format off
1541 patterns->insert<
1542 LowerBiasAddGradOp,
1543 LowerDivNoNanOp,
1544 LowerEmptyOp,
1545 LowerFakeQuantWithMinMaxArgs,
1546 LowerFillOp,
1547 LowerIsNanOp,
1548 LowerL2LossOp,
1549 LowerMulNoNanOp,
1550 LowerOnesLikeOp,
1551 LowerPadOp,
1552 LowerReciprocal,
1553 LowerRintOp,
1554 LowerRoundOpOnFloatTensor,
1555 LowerRoundOpOnIntTensor,
1556 LowerRsqrtGradOp,
1557 LowerScatterNdOp,
1558 LowerSizeOp,
1559 LowerSoftmaxCrossEntropyWithLogitsOp,
1560 LowerSparseSoftmaxCrossEntropyWithLogitsOp,
1561 LowerSqrtGradOp,
1562 LowerSquareOp,
1563 LowerSquaredDifferenceOpOnRealTensors,
1564 LowerSquaredDifferenceOpOneComplexTensors,
1565 LowerTanhGradOp,
1566 LowerXdivyOp,
1567 LowerXlog1pyOp,
1568 LowerXlogyOp,
1569 LowerZerosLikeOp>(context);
1570 // clang-format on
1571 }
1572
1573 } // namespace TF
1574 } // namespace mlir
1575