1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass converts operations in TensorFlow dialect into
17 // operations that are legal in the TensorFlow Lite dialect. Operations that
18 // can be legalized to TensorFlow Lite dialect with simple replacements are part
19 // of this pass and other operations that may create extra ops should be part of
20 // the PrepareTF pass which should be run before this pass. That way any
21 // constant folding opportunities from the extra ops can be exploited by the
22 // constant folding support for the TensorFlow ops.
23
#include <climits>
#include <complex>
#include <cstdint>
#include <limits>
27
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/StringSwitch.h"
31 #include "llvm/Support/Threading.h"
32 #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
33 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
34 #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
35 #include "mlir/IR/Attributes.h" // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
37 #include "mlir/IR/Diagnostics.h" // from @llvm-project
38 #include "mlir/IR/MLIRContext.h" // from @llvm-project
39 #include "mlir/IR/Operation.h" // from @llvm-project
40 #include "mlir/IR/PatternMatch.h" // from @llvm-project
41 #include "mlir/Pass/Pass.h" // from @llvm-project
42 #include "mlir/Support/LLVM.h" // from @llvm-project
43 #include "mlir/Support/LogicalResult.h" // from @llvm-project
44 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
45 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
46 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
47 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
48 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
49 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
50 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
53 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
54 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
55 #include "tensorflow/compiler/xla/status.h"
56 #include "tensorflow/compiler/xla/statusor.h"
57 #include "tensorflow/core/framework/tensor.pb.h"
58 #include "tensorflow/core/framework/tensor_shape.pb.h"
59 #include "tensorflow/core/framework/types.pb.h"
60 #include "tensorflow/core/lib/random/philox_random.h"
61 #include "tensorflow/core/lib/random/random_distributions.h"
62 #include "tensorflow/core/protobuf/error_codes.pb.h"
63
64 namespace mlir {
65 namespace TFL {
66
67 //===----------------------------------------------------------------------===//
68 // The actual LegalizeTF Pass.
69 namespace {
70
using xla::StatusOr;

// Operation names of the nodes produced by the OpHint python converter; these
// are the root op names of the LegalizeUnidirectionalSequence{Lstm,Rnn}
// patterns in this file.
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
// Name of the ArrayAttr on those nodes listing the TFL input slots that the
// node's operands correspond to.
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
76
77 // Legalize operations in functions.
78 class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
getDependentDialects(DialectRegistry & registry) const79 void getDependentDialects(DialectRegistry& registry) const override {
80 registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>();
81 }
82
83 public:
84 LegalizeTF() = default;
LegalizeTF(const LegalizeTF &)85 LegalizeTF(const LegalizeTF&) {}
LegalizeTF(bool run_tfl_runtime_verification)86 explicit LegalizeTF(bool run_tfl_runtime_verification) {
87 run_tfl_runtime_verification_ = run_tfl_runtime_verification;
88 }
89
90 /// Performs the lowering to TFLite dialect.
91 void runOnFunction() override;
92
93 private:
94 Option<bool> run_tfl_runtime_verification_{
95 *this, "run-tfl-runtime-verification",
96 llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
97 };
98
99 // Returns true if all tensor value in `values` has static shape and same shape.
HasSameStaticShapes(Operation * op)100 bool HasSameStaticShapes(Operation* op) {
101 auto values = op->getOperands();
102 int index = 0;
103 ArrayRef<int64_t> shape;
104 for (Value value : values) {
105 auto shaped_type = value.getType().dyn_cast<ShapedType>();
106 if (!shaped_type || !shaped_type.hasStaticShape()) {
107 return false;
108 }
109 if (index == 0) {
110 shape = shaped_type.getShape();
111 } else {
112 if (shape != shaped_type.getShape()) {
113 return false;
114 }
115 }
116 ++index;
117 }
118 return true;
119 }
120
121 // Util that casts 'val' to Int32 by adding a cast Op.
CreateCastToInt32(Value val,Location loc,PatternRewriter & rewriter)122 Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
123 auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
124 IntegerType new_ele_type = rewriter.getIntegerType(32);
125 ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
126 return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
127 rewriter.getBoolAttr(false));
128 }
129
130 #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
131
// Declares `ConvertTF<tf_op>Op`, a RewritePattern with benefit 1 whose root is
// `TF::<tf_op>Op`. Only matchAndRewrite is declared by the macro; each
// pattern's definition is written out separately in this file.
#define DECL_CONVERT_OP(tf_op)                                             \
  struct ConvertTF##tf_op##Op : public RewritePattern {                    \
    explicit ConvertTF##tf_op##Op(MLIRContext* context)                    \
        : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \
    LogicalResult matchAndRewrite(Operation* op,                           \
                                  PatternRewriter& rewriter) const override; \
  }

// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
// operands are properly supported in declarative rewrite rule specification.

DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
DECL_CONVERT_OP(MatrixDiagV3);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(RandomUniform);
DECL_CONVERT_OP(Conv3D);

// The macro is only needed for the declarations above.
#undef DECL_CONVERT_OP
156
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const157 LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
158 Operation* op, PatternRewriter& rewriter) const {
159 auto random_uniform_op = cast<TF::RandomUniformOp>(op);
160 if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
161 return failure();
162 }
163 if (!random_uniform_op.dtype().isF32()) {
164 return failure();
165 }
166 typedef tensorflow::random::UniformDistribution<
167 tensorflow::random::PhiloxRandom, float>
168 Distribution;
169
170 tensorflow::random::PhiloxRandom generator(random_uniform_op.seed(),
171 random_uniform_op.seed2());
172 Distribution dist;
173 size_t num_elements = 0;
174 if (auto output_type =
175 random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
176 if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
177 if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
178 return failure();
179 }
180 num_elements = output_type.getNumElements();
181 size_t offset = 0;
182 size_t num_samples = Distribution::kResultElementCount;
183 llvm::SmallVector<float, 32> data;
184 data.resize(num_elements);
185 while (offset < num_elements) {
186 const typename Distribution::ResultType samples = dist(&generator);
187 std::copy(&samples[0],
188 &samples[0] + std::min(num_samples, data.size() - offset),
189 &data[0] + offset);
190 offset += num_samples;
191 }
192 auto output_data = DenseFPElementsAttr::get(output_type, data);
193 rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
194 return success();
195 }
196 }
197 return failure();
198 }
199
200 // Converts any IntegerAttr to an IntegerAttr of an i32 type.
201 // The value won't change in the new attribute, but if the value is out of
202 // the bound of i32, the function returns a failure.
ConvertToI32Attr(IntegerAttr attr,IntegerAttr * attr_i32)203 LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
204 if (attr.getType().isInteger(/*width=*/32)) {
205 *attr_i32 = attr;
206 return success();
207 }
208
209 int64_t value = attr.getInt();
210 if (value > std::numeric_limits<int>::max() ||
211 value < std::numeric_limits<int>::min()) {
212 return failure();
213 }
214
215 *attr_i32 = IntegerAttr::get(
216 IntegerType::get(attr.getContext(), /*width=*/32), value);
217 return success();
218 }
219
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const220 LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
221 Operation* op, PatternRewriter& rewriter) const {
222 auto tf_concat_op = cast<TF::ConcatV2Op>(op);
223
224 auto values = tf_concat_op.values();
225 auto output_type = tf_concat_op.output().getType();
226 // Extract axis attribute from constant axis tensor
227 ElementsAttr axis;
228 if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
229 IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);
230
231 // "axis" operand could be a i64 tensor. Resolve it here.
232 IntegerAttr axis_i32;
233 if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();
234
235 StringAttr fused_activation_function =
236 StringAttr::get(rewriter.getContext(), "NONE");
237 rewriter.replaceOpWithNewOp<ConcatenationOp>(
238 op, output_type, values, axis_i32, fused_activation_function);
239 return success();
240 }
241
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const242 LogicalResult ConvertTFMatMulOp::matchAndRewrite(
243 Operation* op, PatternRewriter& rewriter) const {
244 auto tf_matmul_op = cast<TF::MatMulOp>(op);
245 auto lhs = op->getOperand(0);
246 auto rhs = op->getOperand(1);
247 auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
248 RankedTensorType type =
249 input.getType().dyn_cast_or_null<RankedTensorType>();
250 if (!type || type.getRank() != 2) return {failure(), nullptr};
251
252 auto permute_attr = DenseIntElementsAttr::get(
253 RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
254 auto permute = rewriter.create<ConstantOp>(
255 op->getLoc(), permute_attr.getType(), permute_attr);
256 llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
257 type.getShape()[0]};
258 auto output = rewriter.create<TFL::TransposeOp>(
259 op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
260 input, permute);
261 return {success(), output};
262 };
263
264 // TODO(jpienaar): Remove once handled via dailect conversion.
265 if (tf_matmul_op.transpose_a()) {
266 LogicalResult result = success();
267 std::tie(result, lhs) = transpose(lhs);
268 if (failed(result)) return failure();
269 }
270 if (!tf_matmul_op.transpose_b()) {
271 LogicalResult result = success();
272 std::tie(result, rhs) = transpose(rhs);
273 if (failed(result)) return failure();
274 }
275
276 Type output_type = tf_matmul_op.getResult().getType();
277 auto no_input = rewriter.create<ConstantOp>(
278 op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
279 auto fc_op = rewriter.create<FullyConnectedOp>(
280 op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
281 rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
282 rewriter.getBoolAttr(false));
283 rewriter.replaceOp(op, {fc_op.getResult(0)});
284 return success();
285 }
286
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const287 LogicalResult ConvertTFPackOp::matchAndRewrite(
288 Operation* op, PatternRewriter& rewriter) const {
289 auto tf_pack_op = cast<TF::PackOp>(op);
290
291 SmallVector<Value, 4> values(tf_pack_op.values());
292 auto output_type = tf_pack_op.output().getType();
293 auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
294 // Axis can be negative.
295 auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis());
296
297 rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count,
298 axis);
299 return success();
300 }
301
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const302 LogicalResult ConvertTFSplitOp::matchAndRewrite(
303 Operation* op, PatternRewriter& rewriter) const {
304 auto tf_split_op = cast<TF::SplitOp>(op);
305
306 // Number of splits cannot be negative.
307 auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
308
309 rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, tf_split_op.output().getTypes(),
310 tf_split_op.split_dim(),
311 tf_split_op.value(), num_split);
312 return success();
313 }
314
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const315 LogicalResult ConvertTFSplitVOp::matchAndRewrite(
316 Operation* op, PatternRewriter& rewriter) const {
317 auto tf_splitv_op = cast<TF::SplitVOp>(op);
318
319 // Number of splits cannot be negative.
320 auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
321
322 rewriter.replaceOpWithNewOp<TFL::SplitVOp>(
323 op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(),
324 tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split);
325 return success();
326 }
327
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const328 LogicalResult ConvertTFUnpackOp::matchAndRewrite(
329 Operation* op, PatternRewriter& rewriter) const {
330 auto tf_unpack_op = cast<TF::UnpackOp>(op);
331
332 auto input = tf_unpack_op.value();
333 auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
334 // Axis can be negative.
335 auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis());
336
337 rewriter.replaceOpWithNewOp<UnpackOp>(op, tf_unpack_op.output().getTypes(),
338 input, num, axis);
339 return success();
340 }
341
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const342 LogicalResult ConvertTFConv3DOp::matchAndRewrite(
343 Operation* op, PatternRewriter& rewriter) const {
344 if (!TFDataFormatIsNDHWC(op)) return failure();
345
346 auto tf_op = cast<TF::Conv3DOp>(op);
347
348 IntegerAttr stride_depth, stride_height, stride_width;
349 if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height,
350 &stride_width))
351 return failure();
352
353 IntegerAttr dilation_depth_factor, dilation_height_factor,
354 dilation_width_factor;
355 if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor,
356 &dilation_height_factor, &dilation_width_factor)) {
357 // If the 'dilations' attribute is missing, we use the default value (1)
358 // for all dilation depth, height and width factor.
359 dilation_depth_factor = rewriter.getI32IntegerAttr(1);
360 dilation_height_factor = rewriter.getI32IntegerAttr(1);
361 dilation_width_factor = rewriter.getI32IntegerAttr(1);
362 }
363
364 StringAttr padding;
365 if (!TFPaddingIsSameOrValid(op, &padding)) return failure();
366
367 // TensorFlow Conv3D has no bias, optimization patterns will fuse Conv3D
368 // with other ops can fill the bias.
369 Value none = rewriter.create<mlir::ConstantOp>(
370 op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
371
372 rewriter.replaceOpWithNewOp<TFL::Conv3DOp>(
373 op, tf_op.getType(), tf_op.input(), tf_op.filter(),
374 /*bias=*/none, dilation_depth_factor, dilation_height_factor,
375 dilation_width_factor,
376 /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding,
377 stride_depth, stride_height, stride_width);
378
379 return success();
380 }
381
382 // MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute
383 // only has effects when processing multiple diagonals. Since TFLite converts
384 // MatrixDiagV{2,3} to MatrixDiag, which only takes single-diagonal inputs, we
385 // can safely ignore this V3 attribute.
386 // We can't pass `rewriter` by reference because clang-tidy will want it to be
387 // constant (`const PatternRewriter& rewriter`). If we do that, we won't be able
388 // to call `rewriter::replaceOpWihNewOp`, which is not a const member function.
389 template <typename MatrixDiagV2OrV3Op>
ConvertTFMatrixDiagV2orV3(Operation * op,PatternRewriter * rewriter)390 bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
391 auto tf_matrix_diag_v2_or_v3_op = cast<MatrixDiagV2OrV3Op>(op);
392
393 if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;
394
395 auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
396 auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();
397
398 // Extract k constant tensor and check value = 0.
399 ElementsAttr k;
400 if (!matchPattern(tf_matrix_diag_v2_or_v3_op.k(), m_Constant(&k)))
401 return false;
402 if (ExtractSingleElementAsInteger(k).getInt() != 0) return false;
403
404 // Extract num_rows constant tensor and check value = -1.
405 ElementsAttr num_rows;
406 if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_rows(),
407 m_Constant(&num_rows)))
408 return false;
409 if (ExtractSingleElementAsInteger(num_rows).getInt() != -1) return false;
410
411 // Extract num_cols constant tensor and check value = -1.
412 ElementsAttr num_cols;
413 if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_cols(),
414 m_Constant(&num_cols)))
415 return false;
416 if (ExtractSingleElementAsInteger(num_cols).getInt() != -1) return false;
417
418 // Verify padding_value is an integer tensor with all 0s.
419 ElementsAttr padding_value;
420 if (!matchPattern(tf_matrix_diag_v2_or_v3_op.padding_value(),
421 m_Constant(&padding_value)))
422 return false;
423 for (const auto& value : padding_value.getValues<APInt>()) {
424 if (value != 0) return false;
425 }
426
427 rewriter->replaceOpWithNewOp<MatrixDiagOp>(op, output_type, input);
428 return true;
429 }
430
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const431 LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite(
432 Operation* op, PatternRewriter& rewriter) const {
433 if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter))
434 return success();
435 return failure();
436 }
437
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const438 LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
439 Operation* op, PatternRewriter& rewriter) const {
440 if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV3Op>(op, &rewriter))
441 return success();
442 return failure();
443 }
444
// TF Lite doesn't support Assert, so the op is erased from the graph
// unconditionally. The pattern always succeeds, and the runtime check the
// assert performed is not present in the converted model.
LogicalResult ConvertTFAssertOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  rewriter.eraseOp(op);
  return success();
}
451
452 // Legalize unidirectional sequence lstm.
453 struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
LegalizeUnidirectionalSequenceLstmmlir::TFL::__anon721dcf970111::LegalizeUnidirectionalSequenceLstm454 explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
455 : RewritePattern(kUnidirectionalSequenceLstm, 1, context) {}
456
matchAndRewritemlir::TFL::__anon721dcf970111::LegalizeUnidirectionalSequenceLstm457 LogicalResult matchAndRewrite(Operation* op,
458 PatternRewriter& rewriter) const override {
459 auto tflite_indices_attr =
460 op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
461 if (!tflite_indices_attr) return failure();
462
463 SmallVector<int64_t, 20> tflite_indices;
464 for (auto index_attr : tflite_indices_attr.getValue()) {
465 IntegerAttr index = index_attr.cast<IntegerAttr>();
466 tflite_indices.push_back(index.getInt());
467 }
468
469 // Optional input placeholder.
470 Value none = rewriter.create<mlir::ConstantOp>(
471 op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
472
473 // Populate inputs.
474 // UnidirectionalSequenceLstm is expected to have 24 inputs.
475 SmallVector<Value, 24> inputs;
476 int count = 0;
477 int total_ophint_converted_inputs = tflite_indices.size();
478 for (int i = 0; i < 24; ++i) {
479 if (count < total_ophint_converted_inputs && tflite_indices[count] == i) {
480 // specified input.
481 inputs.push_back(op->getOperand(i));
482 count++;
483 } else {
484 // Non specified input.
485 inputs.push_back(none);
486 }
487 }
488
489 // Populate outputs.
490 // UnidirectionalSequenceLstm should only have 1 output, and that is the
491 // original ophint converted node's 3rd output.
492 SmallVector<Type, 4> result_types;
493 result_types.push_back(op->getOpResult(2).getType());
494
495 // Populate attributes.
496 SmallVector<NamedAttribute, 4> attributes;
497 // Activation will always be tanh.
498 attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
499 rewriter.getStringAttr("TANH")));
500 // cell_clip.
501 attributes.push_back(
502 rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(0.0)));
503 // proj_clip.
504 attributes.push_back(
505 rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0)));
506 // will always be time_majored.
507 attributes.push_back(
508 rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
509
510 Value lstm_result = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
511 op->getLoc(), result_types, inputs, attributes);
512
513 // Rewire the output.
514 rewriter.replaceOp(op, {nullptr, nullptr, lstm_result});
515 return success();
516 }
517 };
518
519 // Legalize unidirectional seqeucen rnn.
520 struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
LegalizeUnidirectionalSequenceRnnmlir::TFL::__anon721dcf970111::LegalizeUnidirectionalSequenceRnn521 explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context)
522 : RewritePattern(kUnidirectionalSequenceRnn, 1, context) {}
523
matchAndRewritemlir::TFL::__anon721dcf970111::LegalizeUnidirectionalSequenceRnn524 LogicalResult matchAndRewrite(Operation* op,
525 PatternRewriter& rewriter) const override {
526 auto tflite_indices_attr =
527 op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
528 if (!tflite_indices_attr) return failure();
529
530 if (op->getNumOperands() != 5) {
531 op->emitError()
532 << "We're expecting 5 inputs for UnidirectionalSequenceRNN, only "
533 << op->getNumOperands() << " provided";
534 return failure();
535 }
536
537 if (op->getNumResults() != 2) {
538 op->emitError()
539 << "We're expecting 2 inputs for UnidirectionalSequenceRNN, only "
540 << op->getNumResults() << " found";
541 return failure();
542 }
543
544 // Populate inputs.
545 // UnidirectionalSequenceRnn is expected to have 5 inputs, and none of them
546 // are optional inputs.
547 SmallVector<Value, 5> inputs;
548 for (int i = 0; i < 5; ++i) {
549 inputs.push_back(op->getOperand(i));
550 }
551
552 // Populate outputs.
553 // UnidirectionalSequenceRnn should only have 1 output, and that is the
554 // original ophint converted node's 2nd output.
555 SmallVector<Type, 4> result_types;
556 result_types.push_back(op->getOpResult(1).getType());
557
558 // Populate attributes.
559 SmallVector<NamedAttribute, 2> attributes;
560 // Activation will always be tanh.
561 attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
562 rewriter.getStringAttr("TANH")));
563
564 // will always be time_majored.
565 attributes.push_back(
566 rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
567
568 Value rnn_result = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
569 op->getLoc(), result_types, inputs, attributes);
570
571 // Rewire the output.
572 rewriter.replaceOp(op, {nullptr, rnn_result});
573
574 return success();
575 }
576 };
577
578 // Put two TFL BroadcastTo ops in front of the given TF binary broadcast op to
579 // to make binary broadcast-able op conversion always successful and does not
580 // require flex delegate.
581 template <typename SourceOp>
582 class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
583 public:
584 using OpRewritePattern<SourceOp>::OpRewritePattern;
585
matchAndRewrite(SourceOp src_op,PatternRewriter & rewriter) const586 LogicalResult matchAndRewrite(SourceOp src_op,
587 PatternRewriter& rewriter) const override {
588 Operation* op = static_cast<Operation*>(src_op);
589 auto lhs = op->getOperand(0);
590 auto rhs = op->getOperand(1);
591
592 // Should have static shapes to calculate the broadcasted shape.
593 if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
594 !rhs.getType().cast<ShapedType>().hasStaticShape()) {
595 return failure();
596 }
597
598 auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
599 auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
600
601 if (lhs_shape == rhs_shape) {
602 return failure();
603 }
604
605 // Calculate the broadcasted shape.
606 SmallVector<int64_t, 4> result_shape;
607 if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
608 result_shape)) {
609 return failure();
610 }
611
612 RankedTensorType result_type = RankedTensorType::get(
613 result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
614
615 // Create a const op, that stores the above broadcasted shape.
616 auto new_shape_attr = mlir::DenseIntElementsAttr::get(
617 RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
618 result_shape);
619 auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
620
621 // Apply BroadcastTo ops to each input.
622 auto broadcast_type = RankedTensorType::get(
623 result_shape, getElementTypeOrSelf(lhs.getType()));
624
625 if (result_type.getShape() != lhs_shape) {
626 lhs = rewriter
627 .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
628 new_shape)
629 .output();
630 }
631 if (result_type.getShape() != rhs_shape) {
632 rhs = rewriter
633 .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
634 new_shape)
635 .output();
636 }
637
638 // Recreate an op with the above Broadcast op results.
639 rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
640 return success();
641 }
642 };
643
644 // This specialization is for TF SelectV2 op. SelectV2 op have three inputs and
645 // they should have broadcastable shapes.
646 template <>
647 class ApplyExplicitBroadcasting<TF::SelectV2Op>
648 : public OpRewritePattern<TF::SelectV2Op> {
649 public:
650 using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
651
matchAndRewrite(TF::SelectV2Op src_op,PatternRewriter & rewriter) const652 LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
653 PatternRewriter& rewriter) const override {
654 Operation* op = static_cast<Operation*>(src_op);
655 auto cond = op->getOperand(0);
656 auto lhs = op->getOperand(1);
657 auto rhs = op->getOperand(2);
658
659 // Should have static shapes to calculate the broadcasted shape.
660 if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
661 !rhs.getType().cast<ShapedType>().hasStaticShape() ||
662 !cond.getType().cast<ShapedType>().hasStaticShape()) {
663 return failure();
664 }
665
666 auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
667 auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
668 auto cond_shape = cond.getType().cast<ShapedType>().getShape();
669
670 if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
671 return failure();
672 }
673
674 // Calculate the broadcasted shape.
675 SmallVector<int64_t, 4> broadcasted_shape;
676 if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
677 broadcasted_shape)) {
678 return failure();
679 }
680
681 SmallVector<int64_t, 4> result_shape;
682 if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
683 result_shape)) {
684 return failure();
685 }
686
687 // Create a const op, that stores the above broadcasted shape.
688 auto shape_type =
689 RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
690 auto new_shape_attr =
691 mlir::DenseIntElementsAttr::get(shape_type, result_shape);
692 auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
693
694 // Apply BroadcastTo ops to each input.
695 auto cond_result_type =
696 RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
697 auto result_type = RankedTensorType::get(
698 result_shape, getElementTypeOrSelf(lhs.getType()));
699
700 if (result_shape != cond_shape) {
701 cond = rewriter
702 .create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
703 cond, new_shape)
704 .output();
705 }
706 if (result_shape != lhs_shape) {
707 lhs = rewriter
708 .create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
709 new_shape)
710 .output();
711 }
712 if (result_shape != rhs_shape) {
713 rhs = rewriter
714 .create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
715 new_shape)
716 .output();
717 }
718
719 // Recreate an op with the above Broadcast op results.
720 rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
721 rhs);
722 return success();
723 }
724 };
725
addPatterns(MLIRContext * context,OwningRewritePatternList & patterns)726 void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
727 // Add TF->TF lowering patterns.
728 TF::PopulateLoweringTFPatterns(context, &patterns);
729
730 // Add the generated patterns to the list.
731 populateWithGenerated(context, patterns);
732 patterns
733 .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
734 ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
735 ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
736 ConvertTFRandomUniformOp, ConvertTFConv3DOp>(context);
737
738 // Ophint python converter converted tf node pattern.
739 patterns.insert<LegalizeUnidirectionalSequenceLstm,
740 LegalizeUnidirectionalSequenceRnn>(context);
741 }
742
applyPatterns(FuncOp func,ConversionTarget & target,FrozenRewritePatternList & frozenPatterns)743 void applyPatterns(FuncOp func, ConversionTarget& target,
744 FrozenRewritePatternList& frozenPatterns) {
745 // Keep trying to convert.
746 // TODO(karimnosseir): This is similar to what apply greedy patterns does.
747 // Look if there is a function that tries until it converge.
748 // Currently unit-test doesn't do multiple tries, so we need this.
749 const int max_iterations = 15;
750 for (int i = 0; i < max_iterations; ++i) {
751 if (failed(applyPartialConversion(func, target, frozenPatterns))) {
752 return;
753 }
754 }
755 }
756
runOnFunction()757 void LegalizeTF::runOnFunction() {
758 auto* context = &getContext();
759 auto func = getFunction();
760
761 ConversionTarget target(*context);
762 // It is legal to have TF ops in the graph still which can be
763 // used later or in the case of SELECT were we allow TF ops in the final
764 // graph.
765 target.addLegalOp<mlir::ConstantOp>();
766 target.addLegalOp<ConstOp>();
767 if (run_tfl_runtime_verification_) {
768 target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
769 Optional<ConversionTarget::DynamicLegalityCallbackFn>(
770 [](Operation* op) {
771 auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
772 if (!tfl_op) return false;
773 return succeeded(tfl_op.VerifyTflRuntimeConstraints(op));
774 }));
775 } else {
776 target.addLegalDialect<TensorFlowLiteDialect>();
777 }
778
779 // Ignore transient errors by registering an no-op handler.
780 // Applying legalization patterns will emit unwanted, transient errors when
781 // the replaced TFLite ops do not meet the sanity checks. In order to ignore
782 // the transient errors, the following lines override a diagnostic handler
783 // with an no-op handler only while this pass runs.
784 uint64_t current_thread_id = llvm::get_threadid();
785 ScopedDiagnosticHandler scoped_diag_handler(
786 context, [¤t_thread_id](Diagnostic&) -> LogicalResult {
787 // Consume only errors that are coming from the same thread in order not
788 // to ignore errors from other passes that are running. Things running
789 // in the pass manager can be multi-threaded.
790 return success(current_thread_id == llvm::get_threadid());
791 });
792
793 OwningRewritePatternList stage1Patterns;
794
795 addPatterns(context, stage1Patterns);
796
797 FrozenRewritePatternList stage1FrozenPatterns(std::move(stage1Patterns));
798 applyPatterns(func, target, stage1FrozenPatterns);
799
800 // Explict BroadcastTo addition for left-over broadcast-able ops.
801 // The following pattern matchings should be done after the other legalization
802 // rules in order not to add unnecessary BroadcastTo ops.
803 OwningRewritePatternList stage2Patterns;
804
805 addPatterns(context, stage2Patterns);
806
807 stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
808 ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
809 ApplyExplicitBroadcasting<TF::NotEqualOp>,
810 ApplyExplicitBroadcasting<TF::GreaterOp>,
811 ApplyExplicitBroadcasting<TF::LessOp>,
812 ApplyExplicitBroadcasting<TF::EqualOp>,
813 ApplyExplicitBroadcasting<TF::AddOp>,
814 ApplyExplicitBroadcasting<TF::AddV2Op>,
815 ApplyExplicitBroadcasting<TF::MulOp>,
816 ApplyExplicitBroadcasting<TF::DivOp>,
817 ApplyExplicitBroadcasting<TF::RealDivOp>,
818 ApplyExplicitBroadcasting<TF::SubOp>,
819 ApplyExplicitBroadcasting<TF::FloorDivOp>,
820 ApplyExplicitBroadcasting<TF::FloorModOp>,
821 ApplyExplicitBroadcasting<TF::PowOp>,
822 ApplyExplicitBroadcasting<TF::MaximumOp>,
823 ApplyExplicitBroadcasting<TF::MinimumOp>,
824 ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
825 ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
826
827 FrozenRewritePatternList stage2FrozenPatterns(std::move(stage2Patterns));
828 applyPatterns(func, target, stage2FrozenPatterns);
829 }
830
831 } // namespace
832
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
// `run_tfl_runtime_verification`: when true, TFL ops are considered legal only
// if they also satisfy their TFL runtime verification constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
    bool run_tfl_runtime_verification) {
  return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
}

// Registers the pass with the pass registry under the argument
// "tfl-legalize-tf".
static PassRegistration<LegalizeTF> pass(
    "tfl-legalize-tf", "Legalize from TensorFlow to TensorFlow Lite dialect");
841
842 } // namespace TFL
843 } // namespace mlir
844