/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "third_party/eigen3/Eigen/Core"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
namespace TFL {

// Returns true when the given operand arguments have the same shape or a
// broadcastable shape within the given rank. If any of the given shapes is
// non-static and the maximum rank is within the given rank, this method
// returns true.
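// For example, with max_bcast_rank = 4, operands of shapes [2, 3] and [1, 3]
// are accepted (same or broadcastable within rank 4), while statically shaped
// operands with differing shapes are rejected once their maximum rank exceeds
// max_bcast_rank.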
bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
    Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
  if (indices.empty()) return true;

  // First, check whether there are any inputs that have an unknown rank.
  bool has_unknown_shape_input = false;
  bool has_same_shape = true;
  bool reach_first_known_shape = false;
  int64_t max_rank = -1;

  ArrayRef<int64_t> pivot_shape;
  SmallVector<int64_t, 4> current_shape;
  SmallVector<int64_t, 4> result_shape;

  for (unsigned index : indices) {
    ShapedType shaped_type =
        op->getOperand(index).getType().dyn_cast<ShapedType>();
    if (!shaped_type || !shaped_type.hasRank()) {
      // Marks that we have an unknown-rank input.
      has_unknown_shape_input = true;
      continue;
    }
    max_rank = std::max(max_rank, shaped_type.getRank());
    if (!shaped_type.hasStaticShape()) {
      // Marks that we have an unknown-shape input.
      has_unknown_shape_input = true;
      continue;
    }

    ArrayRef<int64_t> shape = shaped_type.getShape();
    if (!reach_first_known_shape) {
      pivot_shape = shape;
      current_shape.assign(shape.begin(), shape.end());
      reach_first_known_shape = true;
      continue;
    }

    if (!pivot_shape.equals(shape)) {
      has_same_shape = false;
    }
    // Checks if all the inputs are broadcastable, since they do not all have
    // the same shape.
    if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
                                            result_shape)) {
      return false;
    }
    current_shape = result_shape;
  }

  // Treats the unknown-shape inputs as acceptable for model compatibility,
  // unless there is a known rank that is bigger than the allowed maximum
  // broadcast rank.
  if (has_unknown_shape_input) return max_rank <= max_bcast_rank;

  // If all the shapes are known and identical, CPU kernels are able to handle
  // inputs regardless of dimension size.
  return has_same_shape || max_rank <= max_bcast_rank;
}

// Returns true when the given element_type is QI8.
bool IsQI8Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 8 &&
         quantized_type.isSigned();
}

// Returns true when the given element_type is QUI8.
bool IsQUI8Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 8 &&
         !quantized_type.isSigned();
}

// Returns true when the given element_type is QI16.
bool IsQI16Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 16 &&
         quantized_type.isSigned();
}

// Returns true when the given element_type is I32.
bool IsI32Type(Type element_type) {
  return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}

// Returns true when the given element_type is I64.
bool IsI64Type(Type element_type) {
  return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}

// Returns true if the value is a splat tensor constant zero.
bool EqualsZero(Value value) {
  DenseElementsAttr constant;
  if (!matchPattern(value, m_Constant(&constant)) || !constant.isSplat()) {
    return false;
  }

  Type element_type = value.getType().cast<ShapedType>().getElementType();
  if (element_type.isa<FloatType>()) {
    return constant.getSplatValue<APFloat>().isZero();
  } else {
    return false;
  }
}

// Replaces the bias operand with a "none" type value if the bias value is
// constant zero.
// `ConcreteOpType` must be a concrete MLIR op class that has an optional
// bias operand named 'bias'.
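// For example, a fully connected op whose 'bias' operand is a splat zero
// constant, schematically
//   %bias = "tfl.pseudo_const"() {value = dense<0.0> : tensor<4xf32>}
//   %fc = "tfl.fully_connected"(%input, %filter, %bias) ...
// is rewritten so that 'bias' refers to a unit ("none") constant instead:
//   %none = constant unit
//   %fc = "tfl.fully_connected"(%input, %filter, %none) ...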
template <typename ConcreteOpType>
struct RemoveOptionalZeroBias : public OpRewritePattern<ConcreteOpType> {
  using OpRewritePattern<ConcreteOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcreteOpType op,
                                PatternRewriter &rewriter) const override {
    if (EqualsZero(op.bias())) {
      auto none_value = rewriter.create<mlir::ConstantOp>(
          rewriter.getUnknownLoc(), rewriter.getUnitAttr());
      op.biasMutable().assign(none_value);
    }

    return success();
  }
};

// Returns true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
  // which are broadcastable shapes up to four dimensions or have the same
  // shape.
  if (element_type.isF32() || IsQI8Type(element_type) ||
      IsQUI8Type(element_type) || IsI32Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }

  // Allows QI16 output when operands have the same shape.
  if (IsQI16Type(element_type)) {
    return succeeded(
        mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
  }
  return false;
}

// Returns true if the given Sub operation has the CPU kernel supported shapes.
bool VerifySubOpShapeConstraints(SubOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, I32, I64, QUI8 and QI16 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to five dimensions or have the
  // same shape.
  if (element_type.isF32() || IsI32Type(element_type) ||
      IsI64Type(element_type) || IsQUI8Type(element_type) ||
      IsQI16Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/5);
  }

  // Allows QI8 output when the operands have valid shapes, which are
  // broadcastable shapes up to four dimensions or have the same shape.
  if (IsQI8Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

// Returns true if the given Mul operation has the CPU kernel supported shapes.
bool VerifyMulOpShapeConstraints(MulOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows QI8 and QUI8 outputs with broadcasting up to four dimensions,
  // unless the lhs operand's element type is QI16, in which case only
  // same-shaped operands are allowed.
  if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
    if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
      return succeeded(
          mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
    }
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }

  // Allows I32, QI16 and F32 outputs when the operands have valid shapes,
  // which are broadcastable shapes up to four dimensions or have the same
  // shape.
  if (IsI32Type(element_type) || IsQI16Type(element_type) ||
      element_type.isF32()) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//

struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    // No TFLite op restricts inlining today, revise as needed in the future.
    return true;
  }
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    return isa<WhileOp>(dest->getParentOp());
  }
};

struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  // Registered hook to check if the given region, which is attached to an
  // operation that is *not* isolated from above (i.e. no internal regions
  // reference values defined in an enclosing region), should be used when
  // materializing constants.
  // In the TFLite dialect we materialize inside a while region, as it is
  // slightly more efficient computationally.
  bool shouldMaterializeInto(Region *region) const final {
    return isa<WhileOp>(region->getParentOp());
  }
};

TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
    : Dialect(/*name=*/"tfl", context, TypeID::get<TensorFlowLiteDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
      >();
  addInterfaces<TensorFlowLiteInlinerInterface,
                TensorFlowLiteDialectFoldInterface>();
}

//===----------------------------------------------------------------------===//
// Common support logic
//===----------------------------------------------------------------------===//

namespace {

// Returns true if the dimensions in `a` are a suffix of the ones in `b`.
// For example, dimensions {2}, {1, 2}, and {3, 1, 2} are all suffixes to
// {5, 4, 3, 1, 2}, while {1}, {5, 4}, and {1, 3, 2} are all not.
inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  if (a.size() > b.size()) return false;

  return std::equal(a.rbegin(), a.rend(), b.rbegin());
}

// Returns true if it is a shaped type of f32 elements.
inline bool IsF32ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isF32();
  }
  return false;
}

// Returns true if it is a shaped type of bf16 elements.
inline bool IsBF16ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isBF16();
  }
  return false;
}

// Returns a new shape with rank 'new_dims', padded with ones on the
// left if needed.
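// For example, GetPaddedShape({2, 3}, 4) returns {1, 1, 2, 3}.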
inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
                                           int new_dims) {
  std::vector<int64_t> new_shape(new_dims, 1);
  std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
  return new_shape;
}

// Helper method that, given a 'current_index' representing an index in the
// broadcasted tensor, gets the index in the flat original tensor.
// 'shape' is the original shape padded with ones to match the result shape.
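// For example, with padded original shape {1, 3} and current_index {2, 1} in
// a broadcasted {4, 3} result, the returned flat index is 1.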
int64_t GetElementIndex(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &current_index) {
  int64_t ind = 0;
  int64_t mul = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    ind += (current_index[i] % shape[i]) * mul;
    mul *= shape[i];
  }
  return ind;
}

// Helper method that increments the index represented in 'current_index_ptr'
// within the shape of 'result_shape'.
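// For example, with result_shape {2, 3}, index {0, 2} is incremented to
// {1, 0}, and {1, 2} wraps around to {0, 0}.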
void IncrementIndex(ArrayRef<int64_t> result_shape,
                    std::vector<int64_t> *current_index_ptr) {
  std::vector<int64_t> &current_index = *current_index_ptr;
  for (int i = result_shape.size() - 1; i >= 0; --i) {
    current_index[i]++;
    if (current_index[i] == result_shape[i]) {
      current_index[i] = 0;
    } else {
      break;
    }
  }
}

/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `operand1` and `operand2` and returns the result if possible.
/// This function assumes both operands are verified to have value
/// attributes of broadcastable types.
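/// For example, folding an addition of dense<[[1, 2]]> (shape 1x2) and
/// dense<[[3], [4]]> (shape 2x1) broadcasts both operands to shape 2x2 and
/// produces dense<[[4, 5], [5, 6]]>.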
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
                                      DenseElementsAttr rhs,
                                      const CalculationT &calculate) {
  auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
                  .dyn_cast_or_null<ShapedType>();
  if (!type) {
    return {};
  }

  const bool rhs_is_splat = rhs.isSplat();
  const bool lhs_is_splat = lhs.isSplat();

  // If both of them are splat, compute and return.
  if (lhs_is_splat && rhs_is_splat) {
    auto element_result = AttrElementT::get(
        type.getElementType(), calculate(lhs.getSplatValue<ElementValueT>(),
                                         rhs.getSplatValue<ElementValueT>()));
    if (!element_result) return {};

    return DenseElementsAttr::get(type, element_result);
  }

  auto num_elements = type.getNumElements();

  SmallVector<ElementValueT, 16> new_values;
  new_values.reserve(num_elements);
  const auto result_shape = type.getShape();
  std::vector<int64_t> current_index(type.getRank(), 0);
  // Create the new shapes with ones padded to the left.
  const std::vector<int64_t> lhs_new_shape =
      GetPaddedShape(lhs.getType().getShape(), type.getRank());
  const std::vector<int64_t> rhs_new_shape =
      GetPaddedShape(rhs.getType().getShape(), type.getRank());

  auto lhs_old_values = lhs.getValues<ElementValueT>();
  auto rhs_old_values = rhs.getValues<ElementValueT>();

  // Combines each pair of the corresponding values in the dense elements
  // attributes.
  for (int64_t i = 0; i < num_elements; ++i) {
    // 'current_index' represents the index in the N-dimensional result
    // tensor. GetElementIndex returns the index in the flat representation
    // of the original tensor to use.
    const int64_t lhs_index =
        lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
    const int64_t rhs_index =
        rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);

    new_values.push_back(calculate(*(lhs_old_values.begin() + lhs_index),
                                   *(rhs_old_values.begin() + rhs_index)));
    IncrementIndex(result_shape, &current_index);
  }
  return DenseElementsAttr::get(type, ArrayRef<ElementValueT>(new_values));
}

/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `operand1` and `operand2` and returns the result if possible.
/// This function assumes the two operands are verified to have value
/// attributes of broadcastable types.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
                            Attribute operand2, const CalculationT &calculate) {
  if (operand1.dyn_cast_or_null<DenseElementsAttr>() &&
      operand2.dyn_cast_or_null<DenseElementsAttr>()) {
    return ConstFoldBinaryOpDenseDense<AttrElementT, ElementValueT>(
        result_type, operand1.cast<DenseElementsAttr>(),
        operand2.cast<DenseElementsAttr>(), calculate);
  }

  // TODO: support other attribute kinds

  return {};
}

/// Performs const folding with broadcast behavior on the two attributes in
/// `operands` and returns the result if possible.
/// Depending on the given `resultType`, either `floatCalculate` or
/// `intCalculate` is chosen to conduct the calculation.
Attribute ConstFoldBinaryOp(
    Type result_type, ArrayRef<Attribute> operands,
    llvm::function_ref<APFloat(APFloat, APFloat)> float_calculate,
    llvm::function_ref<APInt(APInt, APInt)> int_calculate) {
  // Note: All types are wrapped in tensor types in TFLite. E.g., f32 is
  // represented as tensor<f32>. So we are only handling tensor types here.
  auto type = result_type.dyn_cast<ShapedType>();
  if (!type) return {};

  auto elemType = type.getElementType();

  if (elemType.isa<FloatType>())
    return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
                                        float_calculate);

  if (elemType.isSignlessInteger())
    return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0],
                                          operands[1], int_calculate);

  return {};
}

/// Performs const folding on the attribute `operand` and returns the result
/// if possible.
/// The function currently asserts that the `result_type` is an f32 or bf16
/// tensor type.
/// TODO: Extend this function to handle integral tensors for ops like
/// "tfl.logical_not".
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
                           llvm::function_ref<APFloat(APFloat)> calculate) {
  assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
  auto result_shape_type = result_type.cast<ShapedType>();

  if (!result_shape_type.hasStaticShape()) return {};

  if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
    SmallVector<APFloat, 16> new_values;
    const int num_elements = result_shape_type.getNumElements();
    new_values.reserve(num_elements);

    for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
      new_values.push_back(calculate(old_value));
    }

    return DenseElementsAttr::get(result_shape_type, new_values);
  }

  return {};
}

void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
                          Value rhs) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
  if (!result_type)
    emitError(result.location)
        << "non-broadcastable operands: " << lhs.getType() << " and "
        << rhs.getType();
  result.addOperands({lhs, rhs});
  // Comparison binary ops always return i1 tensor.
  if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
    auto result_shape = shaped_type.getShape();
    result.types.push_back(
        RankedTensorType::get(result_shape, builder->getI1Type()));
  } else {
    result.types.push_back(UnrankedTensorType::get(builder->getI1Type()));
  }
}

void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
                                  Value lhs, Value rhs,
                                  StringAttr fused_activation_function) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());

  if (!result_type)
    emitError(result.location)
        << "non-broadcastable operands: " << lhs.getType() << " and "
        << rhs.getType();

  result.addOperands({lhs, rhs});
  result.addAttribute("fused_activation_function", fused_activation_function);
  result.types.push_back(result_type);
}

}  // end anonymous namespace

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a + b; },
      [](APInt a, APInt b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// ConcatenationOp
//===----------------------------------------------------------------------===//
// TODO(ashwinm): Implement shape inference for Concatenation

namespace {

int64_t GetConcatenationOpAxis(ConcatenationOp op) {
  auto output_type = op.output().getType().cast<RankedTensorType>();
  int32_t axis = op.axis();
  if (axis < 0) axis += output_type.getRank();
  return axis;
}

// Verify operand types and the result type:
//
// 1. Operand type ranks must be equal to the output type rank.
//
// 2. Operand dimension sizes (except dimension `axis`) must be equal to
//    previously seen dimension sizes of the same dimension.
//
// 3. Sum of operand dimension sizes of the `axis` dimension must be equal to
//    the dimension size of the `axis` dimension of output.
//
// Note: If an operand has unranked tensor type or has dynamic dimension size,
// those dimensions will be skipped.
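// For example, concatenating tensor<2x3xf32> and tensor<2x5xf32> along
// axis = 1 must produce a tensor<2x8xf32> output.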
LogicalResult VerifyConcatenationOpTypes(Operation *op,
                                         RankedTensorType output_type,
                                         ArrayRef<TensorType> operand_types,
                                         int64_t axis) {
  const int64_t output_rank = output_type.getRank();

  constexpr int64_t kDynamicSize = -1;
  SmallVector<int64_t, 4> result_dim_sizes_loc(output_rank, -1);
  SmallVector<int64_t, 4> result_dim_sizes(output_type.getShape().begin(),
                                           output_type.getShape().end());
  result_dim_sizes[axis] = 0;

  auto FormatLoc = [&result_dim_sizes_loc](int64_t dim) {
    const int64_t loc = result_dim_sizes_loc[dim];
    if (loc == -1) return std::string("output");
    return llvm::formatv("operand #{0}", loc).str();
  };

  for (auto operand : llvm::enumerate(operand_types)) {
    auto operand_type = operand.value().dyn_cast<RankedTensorType>();
    if (!operand_type) {
      result_dim_sizes[axis] = kDynamicSize;
      continue;
    }

    const int64_t operand_rank = operand_type.getRank();
    if (operand_rank != output_rank)
      return op->emitOpError() << "rank of operand #" << operand.index()
                               << " must be equal to rank of output, expected "
                               << output_rank << ", got " << operand_rank;

    for (int64_t dim = 0; dim < output_rank; ++dim) {
      const int64_t operand_dim_size = operand_type.getDimSize(dim);
      const int64_t result_dim_size = result_dim_sizes[dim];

      if (dim == axis) {
        if (RankedTensorType::isDynamic(operand_dim_size) ||
            RankedTensorType::isDynamic(result_dim_size))
          result_dim_sizes[axis] = kDynamicSize;
        else
          result_dim_sizes[axis] += operand_dim_size;
        continue;
      }

      if (RankedTensorType::isDynamic(operand_dim_size)) continue;

      if (RankedTensorType::isDynamic(result_dim_size)) {
        result_dim_sizes[dim] = operand_dim_size;
        result_dim_sizes_loc[dim] = operand.index();
        continue;
      }

      if (result_dim_size != operand_dim_size)
        return op->emitOpError()
               << "dimension size of dimension #" << dim << " of operand #"
               << operand.index() << " must be equal to "
               << "dimension size of dimension #" << dim << " of "
               << FormatLoc(dim) << ", expected " << result_dim_size
               << ", got " << operand_dim_size;
    }
  }

  const int64_t output_concated_dim_size = output_type.getDimSize(axis);
  if (!RankedTensorType::isDynamic(output_concated_dim_size) &&
      !RankedTensorType::isDynamic(result_dim_sizes[axis]) &&
      result_dim_sizes[axis] != output_concated_dim_size)
    return op->emitOpError()
           << "dimension size of dimension #" << axis << " of output "
           << "must be equal to the sum of dimension sizes of dimension #"
           << axis << ", expected " << result_dim_sizes[axis] << ", got "
           << output_concated_dim_size;

  return success();
}

LogicalResult Verify(ConcatenationOp op) {
  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();

  // If the output type is unranked, there is nothing else to be verified.
  if (!output_type) return success();

  const int64_t axis = GetConcatenationOpAxis(op);
  if (axis < 0 || axis >= output_type.getRank())
    return op.emitOpError("concatenation dimension must be in [-rank, rank)");

  SmallVector<TensorType, 4> operand_types;
  for (Value operand : op.values())
    operand_types.push_back(operand.getType().cast<TensorType>());

  return VerifyConcatenationOpTypes(op.getOperation(), output_type,
                                    operand_types, axis);
}

// Returns true when all operands are instances of DenseElementsAttr and the
// output type has a static shape.
bool IsConcatenationOpConstFoldable(ConcatenationOp op,
                                    ArrayRef<Attribute> operands,
                                    RankedTensorType output_type,
                                    int64_t axis) {
  if (operands.empty()) return false;
  if (!output_type.hasStaticShape()) return false;
  if (axis < 0) return false;

  return llvm::all_of(operands, [](Attribute operand) {
    return operand && operand.isa<DenseElementsAttr>();
  });
}

DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
                                              RankedTensorType output_type,
                                              int64_t axis) {
  const auto outer_dims = output_type.getShape().take_front(axis);
  const int64_t outer_size = std::accumulate(
      outer_dims.begin(), outer_dims.end(), 1, std::multiplies<int64_t>());

  const auto base_inner_dims = output_type.getShape().drop_front(axis + 1);
  const int64_t base_inner_size =
      std::accumulate(base_inner_dims.begin(), base_inner_dims.end(), 1,
                      std::multiplies<int64_t>());

  // Splits each input operand into outer_size pieces and combines them in
  // round-robin ordering.
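  // For example, concatenating [[1, 2], [3, 4]] and [[5], [6]] along axis 1
  // walks the operands row by row (outer_size == 2) and produces
  // [[1, 2, 5], [3, 4, 6]].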
  std::vector<Attribute> out_attrs(output_type.getNumElements());
  int64_t out = 0;
  for (int64_t outer = 0; outer < outer_size; ++outer) {
    for (auto op : operands) {
      const int64_t dim_size =
          op.getType().cast<RankedTensorType>().getDimSize(axis);
      const int64_t inner_size = dim_size * base_inner_size;

      auto input_attrs = op.cast<DenseElementsAttr>().getValues<Attribute>();
      auto input_iter = input_attrs.begin() + outer * inner_size;
      for (int64_t inner = 0; inner < inner_size; ++inner)
        out_attrs[out++] = *input_iter++;
    }
  }

  return DenseElementsAttr::get(output_type, out_attrs);
}

}  // end anonymous namespace

OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
  if (fused_activation_function() == "NONE") {
    if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
      const int64_t axis = GetConcatenationOpAxis(*this);
      if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
        return ConstFoldConcatenateOpDense(operands, output_type, axis);
    }
  }

  // Remove all empty values.
  SmallVector<Value, 4> non_empty_values;
  for (Value value : this->values()) {
    const auto shaped_type = value.getType().cast<ShapedType>();
    if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
      continue;
    }
    non_empty_values.push_back(value);
  }

  // If no operands were empty, do nothing.
  if (non_empty_values.size() == getNumOperands()) return nullptr;

  // If only one input is non-empty, just return it as the result of folding.
  if (non_empty_values.size() == 1) {
    return non_empty_values[0];
  }

  // Otherwise, build a new concatenation op with non-empty values.
  mlir::OpBuilder builder(getOperation());
  auto new_concat = builder.create<TFL::ConcatenationOp>(
      getLoc(), getType(), non_empty_values,
      builder.getIntegerAttr(builder.getIntegerType(32), axis()),
      builder.getStringAttr(fused_activation_function()));
  return new_concat.getResult();
}

//===----------------------------------------------------------------------===//
// CustomOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CustomOp op) {
  OpaqueElementsAttr opaque_attr =
      op.custom_option().cast<OpaqueElementsAttr>();
  if (!opaque_attr.getType().hasStaticShape())
    return op.emitOpError("custom_option should have a static shape.");
  const int attribute_size = opaque_attr.getValue().size();
  if (attribute_size != opaque_attr.getType().cast<ShapedType>().getDimSize(0))
    return op.emitOpError(
        "custom_option should have the same length of content with shape.");
  return success();
}

//===----------------------------------------------------------------------===//
// FullyConnectedOp
//===----------------------------------------------------------------------===//

LogicalResult Verify(FullyConnectedOp op) {
  ShapedType input_type = op.input().getType().cast<ShapedType>();
  ShapedType filter_type = op.filter().getType().cast<ShapedType>();
  if (filter_type.hasRank() && filter_type.getRank() != 2) {
    return op.emitOpError("expect 2d filter, got ") << filter_type;
  }

  if (!input_type.hasStaticShape() || !filter_type.hasStaticShape()) {
    return mlir::success();
  }

  // The input's number of elements must be a multiple of the filter's z_in
  // dimension.
  const int z_in = filter_type.getDimSize(1);
  const int num_input_elements = input_type.getNumElements();
  if (num_input_elements % z_in != 0) {
    return op.emitOpError(llvm::formatv(
               "expect 'input' num_elements % {0} == 0, got input type ", z_in))
           << input_type;
  }

  // TODO(jpienaar): Include more shape verification for SHUFFLED4x16INT8
  // format.
  if (op.weights_format() == "DEFAULT") {
    ShapedType output_type =
        (*op.output().begin()).getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) {
      return mlir::success();
    }

    const int num_output_elements = output_type.getNumElements();
    const int z_out = filter_type.getDimSize(0);
    if (num_output_elements % z_out != 0) {
      return op.emitOpError(llvm::formatv(
                 "expect 'output' num_elements % {0} == 0, got ", z_out))
             << output_type;
    }

    if (num_input_elements / z_in != num_output_elements / z_out) {
      return op.emitOpError(
          "num_input_elements / z_in != num_output_elements / z_out");
    }
  }

  return mlir::success();
}

LogicalResult FullyConnectedOp::fold(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  assert(operands.size() == 3);

  // Folding is not implemented for any fused activation function or any
  // weights format besides the default.
  if (fused_activation_function() != "NONE") return failure();
  if (weights_format() != "DEFAULT") return failure();

  // Bias tensor is optional.
  const bool has_bias = !(!bias() || bias().getType().isa<NoneType>());

  // Get the tensors.
  DenseElementsAttr input_tensor, weights_tensor, bias_tensor;
  if (!matchPattern(input(), m_Constant(&input_tensor)) ||
      !matchPattern(filter(), m_Constant(&weights_tensor)) ||
      (has_bias && !matchPattern(bias(), m_Constant(&bias_tensor)))) {
    return failure();
  }

  // Get the tensor types.
  const auto input_type = input_tensor.getType().cast<ShapedType>();
  const auto weights_type = weights_tensor.getType().cast<ShapedType>();
  const auto bias_type =
      has_bias ? bias_tensor.getType().cast<ShapedType>() : ShapedType{};

  const auto output_type = getType(0).cast<ShapedType>();

  // Folding is only implemented for float tensors.
  if (!input_type.getElementType().isF32() ||
      !weights_type.getElementType().isF32() ||
      !output_type.getElementType().isF32() ||
      (has_bias && !bias_type.getElementType().isF32())) {
    return failure();
  }

  // Folding is only implemented for static shapes.
  if (!input_type.hasStaticShape() || !weights_type.hasStaticShape() ||
      (has_bias && !bias_type.hasStaticShape())) {
    return failure();
  }

  // Folding is only implemented for 1-D input, 2-D weights and 1-D bias.
  if (input_type.getShape().size() != 1 ||
      weights_type.getShape().size() != 2 ||
      (has_bias && bias_type.getShape().size() != 1)) {
    return failure();
  }

  // Get the sizes.
  const auto input_size = input_type.getNumElements();
  const auto output_size = output_type.getNumElements();

  // Get iterators to the tensors.
  const auto input_values_it = input_tensor.getValues<float>().begin();
  const auto weights_values_ptr = weights_tensor.getValues<float>().begin();
  auto weights_row_it = weights_values_ptr;
  // The 'else' case could be nullptr, but the types don't match.
  auto bias_values_it =
      has_bias ? bias_tensor.getValues<float>().begin() : input_values_it;

  // Do the actual folding, one output at a time.
  std::vector<float> result_values;
  result_values.reserve(output_size);

  for (int i = 0; i < output_size; ++i) {
    // Dot product with Kahan/Neumaier summation to minimize numeric errors.
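    // Neumaier's variant accumulates the low-order bits lost by each addition
    // in 'compensation' and adds them back once at the end, keeping the folded
    // result close to the exact dot product.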
    float sum = has_bias ? *bias_values_it : 0.0f;
    float compensation = 0.0f;
    for (int j = 0; j < input_size; ++j) {
      const float addend = input_values_it[j] * weights_row_it[j];
      const float new_sum = sum + addend;
      // DO NOT enable -funsafe-math-optimizations here.
      // There is a test detecting unsafe optimizations.
      // Unsafe math optimizations can reorder float formulas, and set the
      // compensation to constant 0. The formula must be evaluated as written
      // for the algorithm to work.
      // (Note: -ffast-math is a superset of -funsafe-math-optimizations.)
      if (std::abs(sum) >= std::abs(addend)) {
        compensation += (sum - new_sum) + addend;
      } else {
        compensation += (addend - new_sum) + sum;
      }
      sum = new_sum;
    }
    result_values.push_back(sum + compensation);
    weights_row_it += input_size;
    bias_values_it++;
  }

  // Set the result tensor.
  const auto folded =
      DenseElementsAttr::get(output_type, ArrayRef<float>(result_values));
  results.assign({folded});

  return success();
}

void FullyConnectedOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveOptionalZeroBias<FullyConnectedOp>>(context);
}

//===----------------------------------------------------------------------===//
// Conv2DOp
//===----------------------------------------------------------------------===//

void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  // TODO(b/180121750): Enable the pattern after the integration tests are
  // fixed.
  // results.insert<RemoveOptionalZeroBias<Conv2DOp>>(context);
}

//===----------------------------------------------------------------------===//
// DepthwiseConv2DOp
//===----------------------------------------------------------------------===//

void DepthwiseConv2DOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // TODO(b/180121750): Enable the pattern after the integration tests are
  // fixed.
  // results.insert<RemoveOptionalZeroBias<DepthwiseConv2DOp>>(context);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

static void BuildGatherOp(OpBuilder *builder, OperationState &result,
                          Value params, Value indices, IntegerAttr axis) {
  auto params_type = params.getType().cast<TensorType>();
  auto indices_type = indices.getType().cast<TensorType>();

  // If params/indices is unranked, then output is unranked.
  if (!params_type.hasRank() || !indices_type.hasRank())
    return TFL::GatherOp::build(
        *builder, result, UnrankedTensorType::get(params_type.getElementType()),
        params, indices, axis);

  int64_t params_rank = params_type.getRank();
  int64_t indices_rank = indices_type.getRank();

  // params rank is guaranteed to be at least 1.
  // Produces an output tensor with shape:
  // params.shape[:axis] + indices.shape + params.shape[axis + 1:]
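  // For example, params of shape [4, 5, 6], indices of shape [2, 3], and
  // axis = 1 produce an output of shape [4, 2, 3, 6].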
  std::vector<int64_t> shape(params_type.getShape());
  int64_t axis_i = axis.getInt();

  // For neg axis values, we wrap around params, e.g. axis = -1 => params[:-1]
  if (axis_i < 0) {
    axis_i += params_rank;
  }

  // params must be at least rank axis + 1
  if (params_rank < axis_i + 1) {
    emitError(result.location, "params must be at least rank axis + 1");
  }

  if (indices_rank == 0) {
    // Scalar indices (output is rank(params) - 1).
    // Erase shape[axis]
    shape.erase(shape.begin() + axis_i);
  } else if (indices_rank == 1) {
    // Vector indices (output is rank(params)).
    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  } else {
    // Higher rank indices (output is rank(params) + rank(indices) - 1).
    shape.resize(params_rank + indices_rank - 1);
    // Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
    std::copy(std::begin(params_type.getShape()) + axis_i + 1,
              std::end(params_type.getShape()),
              std::begin(shape) + axis_i + indices_rank);

    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  }

  TFL::GatherOp::build(
      *builder, result,
      RankedTensorType::get(shape, params_type.getElementType()), params,
      indices, axis);
}

//===----------------------------------------------------------------------===//
// ScatterNdOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ScatterNdOp op) {
  auto indices = op.indices();
  auto updates = op.updates();
  auto shape = op.shape();
  auto output = op.output();

  auto updates_type = updates.getType().cast<ShapedType>();
  auto indices_type = indices.getType().cast<ShapedType>();

  if (!indices_type.hasStaticShape() || !updates_type.hasStaticShape()) {
    return success();
  }

  // Checks if the shape of `updates` is a tensor of shape
  // `indices.shape[:-1] + shape[indices.shape[-1]:]`, as described in
  // ScatterNd op description.
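  // For example, with indices of shape [4, 2] and shape = [8, 6, 5], the
  // updates tensor must have shape [4, 5] (indices.shape[:-1] == [4] and
  // shape[2:] == [5]).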

  auto outer_dims = indices_type.getRank() - 1;
  auto outermost_dim = indices_type.getDimSize(outer_dims);
  // Checks whether the first `outer_dims` dimensions of `indices` and
  // `updates` are equal.
  for (auto i = 0; i < outer_dims; i++) {
    if (indices_type.getDimSize(i) != updates_type.getDimSize(i)) {
      return op.emitOpError()
             << "indices.Dims(" << i << ") == " << indices_type.getDimSize(i)
             << ", but updates.Dims(" << i
             << ") == " << updates_type.getDimSize(i);
    }
  }

  auto output_type = output.getType().cast<ShapedType>();
  auto shape_type = shape.getType().cast<ShapedType>();
  if (shape_type.hasStaticShape()) {
    // Check the rank of `shape`.
    auto output_rank = outermost_dim + updates_type.getRank() - outer_dims;
    if (shape_type.getDimSize(0) != output_rank) {
      return op.emitOpError()
             << "shape must be a vector of length " << output_rank;
    }
    if (output_type.hasRank()) {
      if (output_type.getRank() != output_rank) {
        return op.emitOpError()
               << "output must have the same rank with the length of shape = "
               << output_rank;
      }
    }
  }

  DenseIntElementsAttr shape_value;
  if (matchPattern(shape, m_Constant(&shape_value))) {
    for (const auto shape_elem : shape_value) {
      if (shape_elem.getSExtValue() <= 0) {
        return op.emitOpError("all elements of shape must be > 0");
      }
    }

    // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)`
    // dimensions of `updates` and `shape` are equal.
    for (auto shape_it : llvm::enumerate(shape_value)) {
      int64_t i = shape_it.index();
      auto value = shape_it.value().getSExtValue();
      if (i >= outermost_dim) {
        auto corresponding_dim = i - outermost_dim + outer_dims;
        if (value != updates_type.getDimSize(corresponding_dim)) {
          return op.emitOpError()
                 << "updates.Dims(" << i
                 << ") == " << updates_type.getDimSize(corresponding_dim)
                 << ", but shape[" << i << "] == " << value;
        }
      }
    }

    // Checks if the output has the shape specified by `shape`.
    if (output_type.hasStaticShape()) {
      for (auto shape_it : llvm::enumerate(shape_value)) {
        int i = shape_it.index();
        auto value = shape_it.value().getSExtValue();
        if (output_type.getDimSize(i) != value) {
          return op.emitOpError()
                 << "output shape [" << output_type.getShape()
                 << "] must be equal to the value of shape " << shape_value;
        }
      }
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};

  // This function is performance critical for op fusion patterns, e.g.
  // FuseBinaryOpToPrecedingAffine and FuseMulOrDivWithConv2dOrDepthwiseConv2d.
  // So a few specializations are provided to evaluate the math operation
  // more efficiently.

  // Specialization for f32 type.
  if (getType().cast<ShapedType>().getElementType().isF32()) {
    return ConstFoldBinaryOp<FloatAttr, float>(
        getType(), operands[0], operands[1],
        [](float a, float b) { return a * b; });
  }

  // Specialization for bf16 type.
  if (getType().cast<ShapedType>().getElementType().isBF16()) {
    return ConstFoldBinaryOp<FloatAttr, Eigen::bfloat16>(
        getType(), operands[0], operands[1],
        [](Eigen::bfloat16 a, Eigen::bfloat16 b) { return a * b; });
  }

  // Generic fallback with APFloat.
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a * b; },
      [](APInt a, APInt b) { return a * b; });
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a / b; },
      [](APInt a, APInt b) { return a.sdiv(b); });
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

// TODO(b/133486129): Implement shape inference for pack

static LogicalResult Verify(PackOp op) {
  // TODO(antiagainst): Implement other checks as in
  // tensorflow/lite/kernels/pack.cc

  if (op.getOperation()->getNumOperands() != op.values_count())
    return op.emitOpError("input count should match 'values_count' attribute");

  Value operand0 = op.getOperand(0);
  auto input_type = operand0.getType().cast<ShapedType>();

  // Check axis bounds.
  if (input_type.hasRank()) {
    int32_t axis_value = op.axis();
    if (axis_value < 0) axis_value += input_type.getRank() + 1;
    if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
      return op.emitOpError()
             << "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
             << "got rank = " << input_type.getRank()
             << ", and axis = " << op.axis();
  }

  // Make sure all inputs have the same shape and element type.
  // TODO(b/135032063): Simplify once fixed.
  for (Type operand_type : op.getOperandTypes()) {
    if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
      return op.emitOpError("operands should be of the same type. got ")
             << input_type << ", " << operand_type;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PReluOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(PReluOp op) {
  auto input_type = op.input().getType().cast<ShapedType>();
  auto alpha_type = op.alpha().getType().cast<ShapedType>();
  auto output_type = op.output().getType().cast<ShapedType>();

  if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
    if (input_type.getRank() != alpha_type.getRank() + 1) {
      return op.emitOpError("'alpha' should have one less rank than 'input'.");
    }

    // Check if alpha is broadcastable.
    for (int i = 0; i < alpha_type.getRank(); i++) {
      if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) &&
          alpha_type.getDimSize(i) != 1) {
        return op.emitOpError(
            llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i));
      }
    }
  }

  if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
    if (input_type.getRank() != output_type.getRank()) {
      return op.emitOpError("'input' and 'output' should have the same rank.");
    }

    // Check if input and output shapes are the same.
    for (int i = 0; i < input_type.getRank(); i++) {
      if (input_type.getDimSize(i) != output_type.getDimSize(i)) {
        return op.emitOpError(
            "'input' and 'output' should have the same shape.");
      }
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches and merges a tfl.reshape under the following
// condition:
// * The input's defining op is another tfl.reshape.
// TODO(antiagainst): This pattern probably should be moved to the peephole
// category, after we have the infra for peephole passes.
struct RemoveAdjacentReshape : public RewritePattern {
  RemoveAdjacentReshape(MLIRContext *context)
      : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}

  LogicalResult match(Operation *op) const override {
    auto thisOp = cast<ReshapeOp>(op);
    auto prevOp = thisOp.getOperand(0).getDefiningOp();
    return isa_and_nonnull<ReshapeOp>(prevOp) ? success() : failure();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    auto thisOp = cast<ReshapeOp>(op);
    auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());

    // Replace
    //   %1 = "tfl.reshape"(%0, %shape0)
    //   %2 = "tfl.reshape"(%1, %shape1)
    // with
    //   %2 = "tfl.reshape"(%0, %shape1)
    rewriter.replaceOpWithNewOp<ReshapeOp>(
        op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
  }
};

// The kernel expects a 1-D tensor for the shape operand if it is present. If
// all the dimensions except the last one are 1, the shape constant is
// reshaped to a 1-D tensor.
// Note that this pattern doesn't check or change the content of the shape
// tensor.
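// For example, a shape constant of type tensor<1x1x2xi32> is rewritten to an
// equivalent tensor<2xi32> constant holding the same two values.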
struct ConvertShapeTo1D : public OpRewritePattern<ReshapeOp> {
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp reshape,
                                PatternRewriter &rewriter) const override {
    if (!reshape.shape().hasOneUse()) return failure();

    DenseIntElementsAttr shape;
    if (!matchPattern(reshape.shape(), m_Constant(&shape))) {
      return failure();
    }
    // It is already a 1-D constant, no change.
    auto old_shape = shape.getType().getShape();
    if (old_shape.size() == 1) {
      return failure();
    }
    // Verify all the leading dimensions are length one, except the last one.
    for (auto it = ++old_shape.rbegin(); it != old_shape.rend(); ++it) {
      if (*it != 1) {
        reshape->emitError(
            "Non-vector shape input is used, might cause runtime error");
        return failure();
      }
    }
    auto new_shape = shape.reshape(RankedTensorType::get(
        {*old_shape.rbegin()}, shape.getType().getElementType()));
    rewriter.replaceOpWithNewOp<TFL::ConstOp>(reshape.shape().getDefiningOp(),
                                              new_shape);
    return success();
  }
};

}  // end anonymous namespace

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  // Remove identity reshape with both static result and input shape.
  auto result_type = getType().cast<ShapedType>();
  auto input_type = getOperand(0).getType().cast<ShapedType>();
  if (result_type.hasStaticShape() && result_type == input_type) {
    return getOperand(0);
  }

  // Constant folding.
  if (auto dense_elements = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
    // If the result type isn't static, try to derive the result type from
    // the second operand (the shape).
    if (!result_type.hasStaticShape()) {
      auto shape_elements = operands[1].dyn_cast_or_null<DenseElementsAttr>();
      if (!shape_elements) return nullptr;

      SmallVector<int64_t, 4> shape_data;
      for (const auto &it : shape_elements.getValues<APInt>()) {
        shape_data.push_back(it.getSExtValue());
      }
      result_type =
          RankedTensorType::get(shape_data, input_type.getElementType());
    }
    return dense_elements.reshape(result_type);
  }

  return nullptr;
}

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<RemoveAdjacentReshape, ConvertShapeTo1D>(context);
}

using ReshapeErrorHandler =
    llvm::function_ref<LogicalResult(const llvm::Twine &)>;
1348
GetReshapeOutputType(Value input,Value shape,ReshapeErrorHandler error_handler,TensorType & output_ty)1349 LogicalResult GetReshapeOutputType(Value input, Value shape,
1350 ReshapeErrorHandler error_handler,
1351 TensorType &output_ty) {
1352 auto input_ty = input.getType().cast<TensorType>();
1353 auto element_ty = input_ty.getElementType();
1354 output_ty = UnrankedTensorType::get(element_ty);
1355
1356 auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
1357 if (!shape_ty) return success();
1358 if (shape_ty.getRank() != 1)
1359 return error_handler(llvm::formatv(
1360 "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));
1361
1362 DenseIntElementsAttr shape_attr;
1363 if (!matchPattern(shape, m_Constant(&shape_attr))) {
1364 // If only shape of `shape` is known, return ranked but dynamic output
1365 // shape.
1366 if (shape_ty.hasStaticShape()) {
1367 llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
1368 ShapedType::kDynamicSize);
1369 output_ty = RankedTensorType::get(dynamic_shape, element_ty);
1370 }
1371 return success();
1372 }
1373
1374 // Detect if reshape output shape is folded.
1375 bool shape_ty_zero_dim = false;
1376 int unknown_index = -1;
1377 // The product of constant shape argument excluding unknown dimension.
1378 int64_t shape_ty_size = 1;
1379 llvm::SmallVector<int64_t, 8> output_ty_shape;
1380 output_ty_shape.reserve(shape_attr.getNumElements());
1381 for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) {
1382 const int64_t size = dim.value().getSExtValue();
1383 if (size == ShapedType::kDynamicSize) {
1384 if (unknown_index != -1)
1385 return error_handler(llvm::formatv(
1386 "requires 'shape' to have at most one dynamic dimension, but got "
1387 "multiple dynamic dimensions at indices {0} and {1}. You need to "
1388 "set up the unspecified size(s) to avoid this problem, for example,"
1389 "setting batch size in keras model or setting unspecified input "
1390 "size(s) with fixed ones.",
1391 unknown_index, dim.index()));
1392
1393 unknown_index = dim.index();
1394 } else if (size == 0) {
1395 shape_ty_zero_dim = true;
1396 } else if (size > 0) {
1397 shape_ty_size *= size;
1398 } else {
1399 return error_handler(
1400 llvm::formatv("requires 'shape' to have dimensions greater than -1, "
1401 "but got {0} at index {1}",
1402 size, dim.index()));
1403 }
1404 output_ty_shape.push_back(size);
1405 }
1406
1407 if (!input_ty.hasStaticShape()) {
1408 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1409 return success();
1410 }
1411
1412 // Compute the value of the unknown dimension.
1413 if (unknown_index != -1) {
1414 // Compute number of elements in tensor shape.
1415 int64_t input_ty_size = 1;
1416 bool input_ty_zero_dim = false;
1417 for (const auto &dim : input_ty.getShape()) {
1418 if (dim > 0 || !shape_ty_zero_dim) {
1419 input_ty_size *= dim;
1420 } else {
1421 input_ty_zero_dim = true;
1422 }
1423 }
1424
1425 const int64_t missing_dim = input_ty_size / shape_ty_size;
1426 if (!input_ty_zero_dim && shape_ty_size * missing_dim != input_ty_size)
1427 return error_handler(
1428 llvm::formatv("requires 'input' number of elements be a multiple of "
1429 "{0}, but got {1}",
1430 shape_ty_size, input_ty_size));
1431
1432     // Set the unknown dimension such that the total number of elements remains
1433     // constant.
1434 output_ty_shape[unknown_index] = missing_dim;
1435 }
1436
1437 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1438
1439 return success();
1440 }
1441
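// Verifies a tfl.reshape op: the output type must be cast compatible with the
// type derived from the 'shape' operand, and, when both input and output have
// static shapes, their element counts must match.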
1442 static LogicalResult Verify(ReshapeOp op) {
1443 auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
1444 return op.emitOpError() << message;
1445 };
1446 TensorType expected_ty;
1447 if (failed(GetReshapeOutputType(op.input(), op.shape(), error_handler,
1448 expected_ty)))
1449 return failure();
1450
1451 auto output_ty = op.getType().dyn_cast<RankedTensorType>();
1452 if (!output_ty) return success();
1453 auto input_ty = op.input().getType().cast<TensorType>();
1454 if (output_ty.hasStaticShape() && input_ty.hasStaticShape()) {
1455 const int64_t output_ty_size = output_ty.getNumElements();
1456 const int64_t input_ty_size = input_ty.getNumElements();
1457 if (input_ty_size != output_ty_size)
1458 return op.emitOpError() << "requires 'output' number of elements to "
1459 "match 'input' number of elements, but got "
1460 << output_ty_size << " and " << input_ty_size;
1461 }
1462
1463 if (!TF::AreCastCompatible({output_ty, expected_ty}))
1464 return op.emitOpError()
1465 << "requires 'output' type " << output_ty
1466 << " to be cast compatible with expected type " << expected_ty;
1467
1468 return success();
1469 }
1470
1471 //===----------------------------------------------------------------------===//
1472 // PackOp
1473 //===----------------------------------------------------------------------===//
1474
1475 // Removes a redundant unpack/pack pair.
1476 // If an unpack op is followed by a pack op, we can remove the pack op; if the
1477 // unpack op is only consumed by the pack op, it will be removed as well.
1478 // An example illustration is:
1479 //       Unpack [5, 8, 9], axis = 1
1480 //         /       \
1481 // value  ...  value [5, 9], 8 values in total
1482 //         \       /
1483 //          pack, axis = 1
1484 //            |
1485 //      value [5, 8, 9]
1486 //
1487 // This can actually be simplified into just:
1488 //
1489 // => Value [5, 8, 9]
1490 // TODO(b/133341698): Move to tablegen when variadic is supported.
1491 struct RemoveRedundantUnpackPack : public RewritePattern {
1492   explicit RemoveRedundantUnpackPack(MLIRContext *context)
1493 : RewritePattern(PackOp::getOperationName(), 2, context) {}
1494
1495   LogicalResult matchAndRewrite(Operation *op,
1496 PatternRewriter &rewriter) const override {
1497 TFL::PackOp pack_op = cast<TFL::PackOp>(op);
1498 Operation *first_input = pack_op.getOperand(0).getDefiningOp();
1499 if (!first_input) return failure();
1500 auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
1501 if (!input_unpack_op) return failure();
1502
1503 // The unpack & pack should have the same axis & num inputs/outputs.
1504 if (pack_op.axis() != input_unpack_op.axis() ||
1505 pack_op.values_count() != input_unpack_op.num())
1506 return failure();
1507
1508 const int total_pack_inputs = pack_op.getNumOperands();
1509 const int num_results = input_unpack_op.getNumResults();
1510 if (total_pack_inputs != num_results) return failure();
1511 for (auto input_output :
1512 llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
1513 Value pack_input = std::get<0>(input_output);
1514 Value unpack_output = std::get<1>(input_output);
1515 // Make sure the ordering is the same for the pack op & unpack op.
1516 if (pack_input != unpack_output) return failure();
1517 }
1518
1519     // Replace the pack's output with the unpack's input.
1520 rewriter.replaceOp(pack_op, input_unpack_op.getOperand());
1521     // At this point we don't manually remove the redundant pack & unpack ops
1522     // (we cannot do so here), but trust the PatternRewriter to garbage collect
1523     // these two ops.
1524 return success();
1525 }
1526 };
1527
1528 void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1529 MLIRContext *context) {
1530 results.insert<RemoveRedundantUnpackPack>(context);
1531 }
1532
1533 //===----------------------------------------------------------------------===//
1534 // SliceOp
1535 //===----------------------------------------------------------------------===//
1536
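// Verifies a tfl.slice op: when shapes are static, 'begin' and 'size' must
// have as many elements as the input has dimensions; constant 'begin' values
// must be non-negative, constant 'size' values must be >= -1, and
// begin[i] + size[i] must not exceed the corresponding input dimension.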
1537 static LogicalResult Verify(SliceOp op) {
1538 auto input_type = op.input().getType().cast<ShapedType>();
1539 auto begin_type = op.begin().getType().cast<ShapedType>();
1540 auto size_type = op.size().getType().cast<ShapedType>();
1541 if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
1542 size_type.hasStaticShape()) {
1543 if (input_type.getRank() != begin_type.getNumElements()) {
1544 return op.emitError(
1545 "begin tensor elements size is not equal to input tensor rank");
1546 }
1547
1548 if (input_type.getRank() != size_type.getNumElements()) {
1549 return op.emitError(
1550 "size tensor elements size is not equal to input tensor rank");
1551 }
1552 }
1553
1554 DenseIntElementsAttr begin;
1555 if (matchPattern(op.begin(), m_Constant(&begin))) {
1556 int axis = 0;
1557 for (auto begin_i : llvm::enumerate(begin)) {
1558 if (begin_i.value().getSExtValue() < 0) {
1559 return op.emitError(
1560 llvm::formatv("begin[{0}] cannot be negative", axis));
1561 }
1562 axis++;
1563 }
1564 }
1565
1566 DenseIntElementsAttr size;
1567 if (matchPattern(op.size(), m_Constant(&size))) {
1568 int axis = 0;
1569 for (auto size_i : llvm::enumerate(size)) {
1570 if (size_i.value().getSExtValue() < -1) {
1571 return op.emitError(
1572 llvm::formatv("size[{0}] cannot be negative other than -1", axis));
1573 }
1574 axis++;
1575 }
1576 }
1577
1578 if (begin && size && input_type.hasStaticShape()) {
1579 for (uint64_t i = 0, end = begin.getNumElements(); i < end; i++) {
1580 int begin_i =
1581 begin.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
1582 int size_i =
1583 size.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
1584 int dim_i = input_type.getShape()[i];
1585 if (begin_i > dim_i) {
1586 return op.emitOpError(llvm::formatv(
1587 "begin[{0}] cannot exceed dimension length: {1}", i, dim_i));
1588 }
1589 if (size_i >= 0 && begin_i + size_i > dim_i) {
1590 return op.emitError(llvm::formatv(
1591 "begin[{0}] + size[{0}] cannot exceed dimension length: {1}", i,
1592 dim_i));
1593 }
1594 }
1595 }
1596
1597 return success();
1598 }
1599
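// If `input_op` is a constant, materializes a new TFL::ConstOp holding the
// same values narrowed from int64 to int32; returns nullptr otherwise.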
1600 TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op,
1601 RankedTensorType value_type,
1602 Location loc, OpBuilder *builder) {
1603 if (input_op == nullptr) return nullptr;
1604
1605 mlir::DenseIntElementsAttr attr;
1606 if (!matchPattern(input_op, m_Constant(&attr))) {
1607 return nullptr;
1608 }
1609
1610 auto value_shape_type = mlir::RankedTensorType::get(
1611 value_type.getShape(), builder->getIntegerType(32));
1612
1613 SmallVector<int32_t, 4> value_i32;
1614 value_i32.reserve(value_type.getRank());
1615 for (const auto &size : attr) {
1616 value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
1617 }
1618 auto new_value_i32_attr =
1619 mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);
1620
1621 return builder->create<TFL::ConstOp>(loc, new_value_i32_attr);
1622 }
1623
1624 // Casts down the int64 'begin' and 'size' values of a TFL slice op to int32.
1625 // This requires that 'begin' and 'size' be constants.
1626 struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern<TFL::SliceOp> {
1627 using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;
1628
1629   LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
1630 PatternRewriter &rewriter) const override {
1631 auto begin = slice_op.begin();
1632 auto size = slice_op.size();
1633 auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
1634 auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
1635 auto begin_op = begin.getDefiningOp();
1636 auto size_op = size.getDefiningOp();
1637
1638 if (begin_op == nullptr && size_op == nullptr) return failure();
1639
1640 if (begin_type == nullptr && size_type == nullptr) return failure();
1641
1642 // Handle begin.
1643 if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) {
1644 auto new_begin = NarrowDownInt64InputValuesForOp(
1645 begin_op, begin_type, slice_op.getLoc(), &rewriter);
1646 if (new_begin != nullptr) {
1647 slice_op.setOperand(1, new_begin);
1648 }
1649 }
1650
1651 // Handle size.
1652 if (size_op && size_type && size_type.getElementType().isInteger(64)) {
1653 auto new_size = NarrowDownInt64InputValuesForOp(
1654 size_op, size_type, slice_op.getLoc(), &rewriter);
1655 if (new_size != nullptr) {
1656 slice_op.setOperand(2, new_size);
1657 }
1658 }
1659
1660 return success();
1661 }
1662 };
1663
1664 void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1665 MLIRContext *context) {
1666 results.insert<CastDonwInt64BeginEndToInt32>(context);
1667 }
1668
1669 //===----------------------------------------------------------------------===//
1670 // SubOp
1671 //===----------------------------------------------------------------------===//
1672
1673 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
1674 // TODO(b/142478136): Handle fused ops.
1675 if (fused_activation_function() != "NONE") return {};
1676 return ConstFoldBinaryOp(
1677 getType(), operands, [](APFloat a, APFloat b) { return a - b; },
1678 [](APInt a, APInt b) { return a - b; });
1679 }
1680
1681 //===----------------------------------------------------------------------===//
1682 // TopKOp
1683 //===----------------------------------------------------------------------===//
1684
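// Builds a TFL::TopKV2Op. When `k` is a constant, the result shapes are
// input.shape[:-1] + [k]; for example, an input of tensor<3x5xf32> with k = 2
// gives values of type tensor<3x2xf32> and indices of type tensor<3x2xi32>.
// If `k` is not a constant, the last dimension is left dynamic, and an
// unranked input yields unranked results.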
1685 static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input,
1686 Value k) {
1687   // Output size is only known if k is a constant value. A negative dimension
1688   // is considered dynamic, so use -1 here if k is not a constant value.
1689 int const_k = -1;
1690 ElementsAttr cst;
1691 if (matchPattern(k, m_Constant(&cst)))
1692 // These casts should all be valid due to how Tensor constants are stored.
1693 // TODO(jpienaar): This should use a helper function.
1694 const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
1695
1696 auto val_type = input.getType().cast<TensorType>();
1697   // If the value is unranked, then so are the results.
1698 if (!val_type.hasRank())
1699 return TFL::TopKV2Op::build(
1700 *builder, result, UnrankedTensorType::get(val_type.getElementType()),
1701 UnrankedTensorType::get(builder->getIntegerType(32)), input, k);
1702
1703 // Resultant shape is value.shape[:-1] + [k]
1704 std::vector<int64_t> shape(val_type.getShape());
1705 shape[shape.size() - 1] = const_k;
1706 TFL::TopKV2Op::build(
1707 *builder, result, RankedTensorType::get(shape, val_type.getElementType()),
1708 RankedTensorType::get(shape, builder->getIntegerType(32)), input, k);
1709 }
1710
1711 //===----------------------------------------------------------------------===//
1712 // FakeQuantOp
1713 //===----------------------------------------------------------------------===//
1714
1715 // Returns true if the op has a "minmax" attribute with exactly two elements.
1716 static inline bool HasValidMinMaxAttribute(Operation *op) {
1717 auto minmax = op->getAttrOfType<ArrayAttr>("minmax");
1718 return minmax && minmax.getValue().size() == 2;
1719 }
1720
1721 namespace {
1722
1723 /// This pattern matches and removes a tfl.fake_quant if all the users of this
1724 /// op and the op itself have the "minmax" attribute set.
1725 struct DropFakeQuant : public RewritePattern {
1726   explicit DropFakeQuant(MLIRContext *context)
1727 : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}
1728
1729   LogicalResult match(Operation *op) const override {
1730 // We only match the op with valid "minmax" attribute.
1731 if (!HasValidMinMaxAttribute(op)) return failure();
1732
1733 // If all the users of this op have valid "minmax" attributes, it is matched
1734 // and can be removed.
1735 auto fakeQuantOp = cast<FakeQuantOp>(op);
1736 for (auto *operand : fakeQuantOp.getResult().getUsers())
1737 if (!HasValidMinMaxAttribute(operand)) return failure();
1738
1739 return success();
1740 }
1741
1742   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
1743 // Replace the matched FakeQuantOp by its primary operand.
1744 rewriter.replaceOp(op, op->getOperand(0));
1745 }
1746 };
1747 } // end anonymous namespace
1748
1749 void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1750 MLIRContext *context) {
1751 results.insert<DropFakeQuant>(context);
1752 }
1753
1754 //===----------------------------------------------------------------------===//
1755 // UnpackOp
1756 //===----------------------------------------------------------------------===//
1757
1758 // TODO(b/133486129): Implement shape inference for unpack
1759
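// Infers the result types of a tfl.unpack op by dropping the unpacked axis
// from the input shape. For example, unpacking a tensor<2x3x4xf32> along
// axis = 1 with num = 3 yields three results of type tensor<2x4xf32>.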
1760 LogicalResult UnpackOp::inferReturnTypes(
1761 MLIRContext *context, Optional<Location> loc, ValueRange operands,
1762 DictionaryAttr attributes, RegionRange regions,
1763 SmallVectorImpl<Type> &inferredReturnTypes) {
1764 UnpackOpAdaptor op(operands, attributes);
1765 // TODO(jpienaar): Refactor verify
1766 if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context))))
1767 return failure();
1768
1769 if (operands.size() != 1) {
1770 return emitOptionalError(loc, "input count should be equal to 1");
1771 }
1772
1773 const int64_t num_value = op.num().getInt();
1774 auto input_type = operands[0].getType().dyn_cast<ShapedType>();
1775 if (!input_type || !input_type.hasRank()) {
1776 // If input is unranked, then so is output.
1777 inferredReturnTypes.assign(
1778 num_value, UnrankedTensorType::get(input_type.getElementType()));
1779 return success();
1780 }
1781
1782 if (input_type.hasStaticShape() && input_type.getNumElements() <= 0) {
1783 return emitOptionalError(
1784 loc, "number of elements in input should be larger than 0");
1785 }
1786
1787 const int64_t rank = input_type.getRank();
1788 if (rank <= 0) {
1789 return emitOptionalError(loc, "input should be of rank larger than 0");
1790 }
1791
1792 int64_t axis_value = op.axis().getInt();
1793 if (axis_value < 0) {
1794 axis_value += rank;
1795 }
1796 if (axis_value < 0 || axis_value >= rank) {
1797 return emitOptionalError(
1798 loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
1799 op.axis().getInt(), ", and rank = ", rank);
1800 }
1801
1802 if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
1803 input_type.getDimSize(axis_value) != num_value) {
1804 return emitOptionalError(loc, "output count should match 'num' attribute");
1805 }
1806
1807 auto output_shape = llvm::to_vector<4>(input_type.getShape());
1808 output_shape.erase(output_shape.begin() + axis_value);
1809
1810 auto output_type =
1811 RankedTensorType::get(output_shape, input_type.getElementType());
1812 inferredReturnTypes.assign(num_value, output_type);
1813
1814 return success();
1815 }
1816
1817 bool UnpackOp::isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
1818 if (lhs.size() != rhs.size()) return false;
1819 for (auto pair : llvm::zip(lhs, rhs)) {
1820 if (failed(
1821 mlir::verifyCompatibleShape(std::get<0>(pair), std::get<1>(pair))))
1822 return false;
1823 }
1824 return true;
1825 }
1826
1827 //===----------------------------------------------------------------------===//
1828 // SplitOp
1829 //===----------------------------------------------------------------------===//
1830
1831 // Extracts and returns the signed integer constant in a 0-rank integer tensor
1832 // or 1-element 1-rank integer tensor if 'value' is a constant.
1833 static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
1834 ElementsAttr attr;
1835 if (!matchPattern(value, m_Constant(&attr))) return {};
1836 if (attr.getNumElements() != 1) return {};
1837 IntegerAttr int_attr = *attr.getValues<IntegerAttr>().begin();
1838 return int_attr.getValue().getSExtValue();
1839 }
1840
1841 // Returns a RankedTensorType which is similar to `input_type` but replaces the
1842 // dimension size of `dim` with `dim_size`. For example,
1843 // `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
1844 // `tensor<3x2xi32>`.
1845 static RankedTensorType SubstituteRankedTensorTypeDimSize(
1846 RankedTensorType input_type, int64_t dim, int64_t dim_size) {
1847 auto shape = input_type.getShape().vec();
1848 shape[dim] = dim_size;
1849 return RankedTensorType::get(shape, input_type.getElementType());
1850 }
1851
1852 // Verifies the output tensor types of SplitOp or SplitVOp.
1853 template <typename ExpectedOutputTypeGetter>
1854 static LogicalResult VerifySplitOpOutputTypes(
1855 Operation *op, int64_t num_splits,
1856 ExpectedOutputTypeGetter get_expected_output_type) {
1857 for (int64_t i = 0; i < num_splits; ++i) {
1858 auto expected_output_type = get_expected_output_type(i);
1859 Value output = op->getResult(i);
1860 if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
1861 return op->emitOpError()
1862 << "output #" << i << " should be " << expected_output_type
1863 << " instead got " << output.getType();
1864 }
1865 return success();
1866 }
1867
1868 static LogicalResult Verify(SplitOp op) {
1869 int64_t num_splits = op.num_splits();
1870 if (op.getNumResults() != num_splits)
1871 return op.emitOpError("output count should match 'num_splits' attribute");
1872
1873 // If 'split_dim' is not a constant, there are no other checks.
1874 llvm::Optional<int64_t> split_dim_opt =
1875 ExtractConstantIntFromTensor(op.split_dim());
1876 if (!split_dim_opt) return success();
1877
1878 // If 'input' is not a ranked tensor, there are no other checks.
1879 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
1880 if (!input_type) return success();
1881
1882 int64_t split_dim = split_dim_opt.getValue();
1883 const int64_t rank = input_type.getRank();
1884 if (split_dim < 0) split_dim += rank;
1885 if (split_dim < 0 || split_dim >= rank)
1886 return op.emitOpError("'split_dim' should be in [-rank, rank)");
1887
1888 // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
1889 // there are no other checks.
1890 const int64_t dim_size = input_type.getDimSize(split_dim);
1891 if (ShapedType::isDynamic(dim_size)) return success();
1892
1893 if (dim_size % num_splits != 0)
1894 return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");
1895
1896 // Verifies output tensor types.
1897 RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
1898 input_type, split_dim, dim_size / num_splits);
1899 return VerifySplitOpOutputTypes(
1900 op.getOperation(), num_splits,
1901 [expected_output_type](int64_t) { return expected_output_type; });
1902 }
1903
1904 static LogicalResult Verify(SplitVOp op) {
1905 int64_t num_splits = op.num_splits();
1906 if (op.getNumResults() != num_splits)
1907 return op.emitOpError("output count should match 'num_splits' attribute");
1908
1909 // If 'split_dim' is not a constant, there are no other checks.
1910 llvm::Optional<int64_t> split_dim_opt =
1911 ExtractConstantIntFromTensor(op.split_dim());
1912 if (!split_dim_opt) return success();
1913
1914 // If 'input' is not a ranked tensor, there are no other checks.
1915 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
1916 if (!input_type) return success();
1917
1918 int64_t split_dim = split_dim_opt.getValue();
1919 const int64_t rank = input_type.getRank();
1920 if (split_dim < 0) split_dim += rank;
1921 if (split_dim < 0 || split_dim >= rank)
1922 return op.emitOpError("'split_dim' should be in [-rank, rank)");
1923
1924 // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
1925 // there are no other checks.
1926 const int64_t dim_size = input_type.getDimSize(split_dim);
1927 if (ShapedType::isDynamic(dim_size)) return success();
1928
1929 // If 'size_splits' is not a constant, there are no other checks.
1930 ElementsAttr size_splits_attr;
1931 if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
1932 return success();
1933
1934 if (size_splits_attr.getNumElements() != num_splits) {
1935 auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
1936 RankedTensorType expected_size_splits_type =
1937 RankedTensorType::get({num_splits}, size_splits_type.getElementType());
1938 return op.emitOpError("'size_splits' should be ")
1939 << expected_size_splits_type;
1940 }
1941
1942 // Normalizes and verifies 'size_splits'.
1943   // Note: TensorFlow allows one -1 element in 'size_splits'. The -1 element
1944   // means the remainder of the 'split_dim' dimension size.
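  // For example, with a 'split_dim' dimension size of 10 and
  // size_splits = [2, -1, 3], the -1 entry is resolved to 10 - (2 + 3) = 5.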
1945 llvm::SmallVector<int64_t, 4> size_splits;
1946 size_splits.reserve(num_splits);
1947
1948 int64_t negative_size_split_loc = -1;
1949 int64_t total_size_splits = 0;
1950
1951 for (int64_t i = 0; i < num_splits; ++i) {
1952 auto size_split_attr = size_splits_attr.getValue<IntegerAttr>(i);
1953 int64_t size_split = size_split_attr.getValue().getSExtValue();
1954 size_splits.push_back(size_split);
1955 if (size_split >= 0) {
1956 total_size_splits += size_split;
1957 continue;
1958 }
1959 if (size_split < -1)
1960 return op.emitOpError(
1961 "elements of 'size_splits' should be greater than or equal to -1");
1962 if (negative_size_split_loc != -1)
1963 return op.emitOpError("'size_splits' can only have one -1");
1964 negative_size_split_loc = i;
1965 }
1966
1967 if (negative_size_split_loc != -1) {
1968 if (total_size_splits > dim_size)
1969 return op.emitOpError(
1970 "sum of non-negative elements of 'size_splits' is greater than the "
1971 "dimension size of 'split_dim' axis");
1972 size_splits[negative_size_split_loc] = dim_size - total_size_splits;
1973 total_size_splits = dim_size;
1974 }
1975
1976 if (total_size_splits != dim_size)
1977 return op.emitOpError(
1978 "sum of 'size_splits' should match the dimension size of 'split_dim' "
1979 "axis");
1980
1981 // Verifies result tensor types.
1982 auto get_expected_output_type = [input_type, split_dim,
1983 &size_splits](int64_t i) {
1984 return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
1985 size_splits[i]);
1986 };
1987 return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
1988 get_expected_output_type);
1989 }
1990
1991 //===----------------------------------------------------------------------===//
1992 // MeanOp
1993 //===----------------------------------------------------------------------===//
1994
1995 // TODO(b/133854225): Implement shape inference to Mean
1996
1997 //===----------------------------------------------------------------------===//
1998 // LSTMOp
1999 //===----------------------------------------------------------------------===//
2000
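// Verifies a tfl.lstm op: it must have exactly two stateful operands (the
// activation state and the cell state), and, when the relevant shapes are
// static, the input, weight, and state dimensions must be consistent with
// each other (see the n_input/n_cell/n_output checks below).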
2001 static LogicalResult Verify(LSTMOp op) {
2002 auto operands = op.GetStatefulOperands();
2003 if (operands.size() != 2 || operands[0] != 18 || operands[1] != 19) {
2004 return op.emitOpError("LSTMOp expected to have two stateful operands");
2005 }
2006
2007 const auto input_type = op.input().getType().cast<ShapedType>();
2008 // Since TFLite runtime generally supports dynamic shape/rank, if `input_type`
2009 // doesn't have static shape, we skip the shape check below.
2010 if (!input_type.hasStaticShape()) return success();
2011   // The input should be at least a 2-D tensor since it will go through a
2012   // fully connected layer.
2013   if (!input_type.hasRank() || input_type.getRank() < 2)
2014     return op.emitOpError(
2015         "the first input operand should have at least 2 dimensions.");
2016
2017 const auto activation_state =
2018 op.input_activation_state().getType().cast<ShapedType>();
2019 const auto cell_state = op.input_cell_state().getType().cast<ShapedType>();
2020 const auto input_to_output_weights =
2021 op.input_to_output_weights().getType().cast<ShapedType>();
2022 const auto recurrent_to_output_weights =
2023 op.recurrent_to_output_weights().getType().cast<ShapedType>();
2024 if (activation_state.hasStaticShape() && cell_state.hasStaticShape() &&
2025 input_to_output_weights.hasStaticShape() &&
2026 recurrent_to_output_weights.hasStaticShape()) {
2027 const int n_input = input_type.getDimSize(input_type.getRank() - 1);
2028 const int n_cell = input_to_output_weights.getDimSize(0);
2029 const int n_output = recurrent_to_output_weights.getDimSize(1);
2030 const int output_state_size = activation_state.getNumElements();
2031 const int n_batch = input_type.getRank() == 2 ? input_type.getDimSize(0)
2032 : input_type.getDimSize(1);
2033 const int state_size = cell_state.getNumElements();
2034
2035     // Check that the dimensions of the inputs match.
2036 if ((output_state_size != n_batch * n_output) ||
2037 (state_size != n_batch * n_cell) ||
2038 (input_to_output_weights.getDimSize(1) != n_input) ||
2039 (recurrent_to_output_weights.getRank() != 2) ||
2040 (recurrent_to_output_weights.getDimSize(0) != n_cell) ||
2041 (input_to_output_weights.getRank() != 2)) {
2042 return op.emitOpError("inputs don't match with the dimensions.");
2043 }
2044
2045 const bool is_layer_norm_lstm =
2046 !op.forget_layer_norm_coefficients().getType().isa<NoneType>();
2047 if (is_layer_norm_lstm) {
2048 const auto forget_layer_norm_coefficients =
2049 op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
2050 // If this lstm has layer normalization, this input value,
2051 // "forget_layer_norm_coefficients" should be a 1D tensor.
2052 if (!forget_layer_norm_coefficients.hasRank() ||
2053 forget_layer_norm_coefficients.getRank() != 1 ||
2054 forget_layer_norm_coefficients.getDimSize(0) != n_cell)
2055 return op.emitOpError(
2056 "coefficient inputs have more than 2 dimensions or "
2057 "don't match the dimension with input operand "
2058 "`input_to_output_weights`.");
2059 }
2060 }
2061
2062 return success();
2063 }
2064
2065 namespace {
2066
2067 // Replaces the optional bias operands with a "none" type value if the bias
2068 // values are constant zeros.
2069 struct RemoveLSTMOpZeroBias : public OpRewritePattern<LSTMOp> {
2070 using OpRewritePattern<LSTMOp>::OpRewritePattern;
2071
2072   LogicalResult matchAndRewrite(LSTMOp op,
2073 PatternRewriter &rewriter) const override {
2074 if (EqualsZero(op.input_gate_bias())) {
2075 auto none_value = rewriter.create<mlir::ConstantOp>(
2076 rewriter.getUnknownLoc(), rewriter.getUnitAttr());
2077 op.input_gate_biasMutable().assign(none_value);
2078 }
2079
2080 if (EqualsZero(op.projection_bias())) {
2081 auto none_value = rewriter.create<mlir::ConstantOp>(
2082 rewriter.getUnknownLoc(), rewriter.getUnitAttr());
2083 op.projection_biasMutable().assign(none_value);
2084 }
2085
2086 return success();
2087 }
2088 };
2089
2090 } // namespace
2091
2092 void LSTMOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2093 MLIRContext *context) {
2094 results.insert<RemoveLSTMOpZeroBias>(context);
2095 }
2096
2097 //===----------------------------------------------------------------------===//
2098 // UnidirectionalSequenceLSTMOp
2099 //===----------------------------------------------------------------------===//
2100
2101 static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
2102 auto operands = op.GetStatefulOperands();
2103 if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) {
2104 return success();
2105 }
2106 return op.emitError(
2107 "UnidirectionalSequenceLSTMOp expected to have two stateful operands");
2108 }
2109
2110 //===----------------------------------------------------------------------===//
2111 // BidirectionalSequenceLSTMOp
2112 //===----------------------------------------------------------------------===//
2113
2114 static LogicalResult Verify(BidirectionalSequenceLSTMOp op) {
2115 auto operands = op.GetStatefulOperands();
2116 if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 &&
2117 operands[2] == 37 && operands[3] == 38) {
2118 return success();
2119 }
2120 return op.emitError(
2121 "BidirectionalSequenceLSTMOp expected to have four stateful operands");
2122 }
2123
2124 //===----------------------------------------------------------------------===//
2125 // UnidirectionalSequenceRNNOp
2126 //===----------------------------------------------------------------------===//
2127
2128 static LogicalResult Verify(UnidirectionalSequenceRNNOp op) {
2129 auto operands = op.GetStatefulOperands();
2130 if (operands.size() == 1 && operands[0] == 4) {
2131 return success();
2132 }
2133 return op.emitError(
2134 "UnidirectionalSequenceRNNOp expected to have one stateful operand");
2135 }
2136
2137 //===----------------------------------------------------------------------===//
2138 // SvdfOp
2139 //===----------------------------------------------------------------------===//
2140
2141 static LogicalResult Verify(SVDFOp op) {
2142 auto operands = op.GetStatefulOperands();
2143 if (operands.size() == 1 && operands[0] == 4) {
2144 return success();
2145 }
2146 return op.emitError("SvdfOp expected to have one stateful operand");
2147 }
2148
2149 //===----------------------------------------------------------------------===//
2150 // AbsOp
2151 //===----------------------------------------------------------------------===//
2152
2153 OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
2154 Type result_type = getType();
2155   // Only constant folding for f32 tensors is implemented.
2156 if (!IsF32ShapedType(result_type)) return nullptr;
2157
2158 auto compute = [](APFloat value) -> APFloat { return llvm::abs(value); };
2159 return ConstFoldUnaryOp(result_type, operands[0], compute);
2160 }
2161
2162 //===----------------------------------------------------------------------===//
2163 // NegOp
2164 //===----------------------------------------------------------------------===//
2165
2166 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
2167 Type result_type = getType();
2168   // Only constant folding for f32 tensors is implemented.
2169 if (!IsF32ShapedType(result_type)) return nullptr;
2170
2171 auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
2172 return ConstFoldUnaryOp(result_type, operands[0], compute);
2173 }
2174
2175 //===----------------------------------------------------------------------===//
2176 // SinOp
2177 //===----------------------------------------------------------------------===//
2178
2179 OpFoldResult SinOp::fold(ArrayRef<Attribute> operands) {
2180 Type result_type = getType();
2181   // Only constant folding for f32 tensors is implemented.
2182 if (!IsF32ShapedType(result_type)) return nullptr;
2183
2184 auto compute = [](APFloat value) -> APFloat {
2185 float f = value.convertToFloat();
2186 float result = std::sin(f);
2187 return APFloat(result);
2188 };
2189 return ConstFoldUnaryOp(result_type, operands[0], compute);
2190 }
2191
2192 //===----------------------------------------------------------------------===//
2193 // CosOp
2194 //===----------------------------------------------------------------------===//
2195
2196 OpFoldResult CosOp::fold(ArrayRef<Attribute> operands) {
2197 Type result_type = getType();
2198   // Only constant folding for f32 tensors is implemented.
2199 if (!IsF32ShapedType(result_type)) return nullptr;
2200
2201 auto compute = [](APFloat value) -> APFloat {
2202 float f = value.convertToFloat();
2203 float result = std::cos(f);
2204 return APFloat(result);
2205 };
2206 return ConstFoldUnaryOp(result_type, operands[0], compute);
2207 }
2208
2209 //===----------------------------------------------------------------------===//
2210 // LogOp
2211 //===----------------------------------------------------------------------===//
2212
2213 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
2214 Type result_type = getType();
2215   // Only constant folding for f32 tensors is implemented.
2216 if (!IsF32ShapedType(result_type)) return nullptr;
2217
2218 auto compute = [](APFloat value) -> APFloat {
2219 float f = value.convertToFloat();
2220 float result = std::log(f);
2221 return APFloat(result);
2222 };
2223 return ConstFoldUnaryOp(result_type, operands[0], compute);
2224 }
2225
2226 //===----------------------------------------------------------------------===//
2227 // SqrtOp
2228 //===----------------------------------------------------------------------===//
2229
2230 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
2231 Type result_type = getType();
2232   // Only constant folding for f32 tensors is implemented.
2233 if (!IsF32ShapedType(result_type)) return nullptr;
2234
2235 auto compute = [](APFloat value) -> APFloat {
2236 float f = value.convertToFloat();
2237 float result = std::sqrt(f);
2238 return APFloat(result);
2239 };
2240 return ConstFoldUnaryOp(result_type, operands[0], compute);
2241 }
2242
2243 //===----------------------------------------------------------------------===//
2244 // RsqrtOp
2245 //===----------------------------------------------------------------------===//
2246
2247 OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
2248 Type result_type = getType();
2249   // Only constant folding for f32/bf16 tensors is implemented.
2250 if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
2251 return nullptr;
2252
2253 auto compute = [](APFloat value) -> APFloat {
2254 bool loseInfo;
2255 const llvm::fltSemantics &original_float_semantics = value.getSemantics();
2256 value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
2257 &loseInfo);
2258 float f = value.convertToFloat();
2259 APFloat result(1.f / std::sqrt(f));
2260 result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
2261 &loseInfo);
2262 return result;
2263 };
2264 return ConstFoldUnaryOp(result_type, operands[0], compute);
2265 }
2266
2267 //===----------------------------------------------------------------------===//
2268 // SquareOp
2269 //===----------------------------------------------------------------------===//
2270
2271 OpFoldResult SquareOp::fold(ArrayRef<Attribute> operands) {
2272 Type result_type = getType();
2273   // Only constant folding for f32 tensors is implemented.
2274 if (!IsF32ShapedType(result_type)) return nullptr;
2275
2276 auto compute = [](APFloat value) -> APFloat { return value * value; };
2277 return ConstFoldUnaryOp(result_type, operands[0], compute);
2278 }
2279
2280 //===----------------------------------------------------------------------===//
2281 // RankOp
2282 //===----------------------------------------------------------------------===//
2283
2284 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2285 assert(operands.size() == 1);
2286 auto result_type = getType().cast<ShapedType>();
2287 if (auto elements_attr = operands[0].dyn_cast_or_null<ElementsAttr>()) {
2288 auto rank = static_cast<int32_t>(elements_attr.getType().getRank());
2289 return DenseElementsAttr::get(result_type, {rank});
2290 }
2291
2292 // Also fold if `input` has a known rank.
2293 auto input_type = input().getType().cast<ShapedType>();
2294 // Do not fold if rank is zero because the TFLite converter doesn't
2295 // distinguish between unranked input and scalar input due to b/138865275.
2296 // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
2297 // predicate and fold the op when rank is zero.
2298 if (input_type.hasRank() && input_type.getRank() != 0) {
2299 auto rank = static_cast<int32_t>(input_type.getRank());
2300 return DenseElementsAttr::get(result_type, {rank});
2301 }
2302
2303 return nullptr;
2304 }
2305
2306 //===----------------------------------------------------------------------===//
2307 // ConstOp
2308 //===----------------------------------------------------------------------===//
2309
2310 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
2311 assert(operands.empty() && "constant has no operands");
2312
2313 // Return the held attribute value.
2314 return value();
2315 }
2316
2317 //===----------------------------------------------------------------------===//
2318 // CastOp
2319 //===----------------------------------------------------------------------===//
2320
2321 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
2322 assert(operands.size() == 1);
2323 if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
2324 return input();
2325 }
2326
2327   // For now, only casts between integer types are supported.
2328 auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2329 if (!elements_attr) {
2330 return nullptr;
2331 }
2332
2333 auto result_element_type =
2334 getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
2335 auto operand_element_type = input()
2336 .getType()
2337 .cast<ShapedType>()
2338 .getElementType()
2339 .dyn_cast<IntegerType>();
2340 // Returns nullptr if either result/operand element type is not integer.
2341 if (!result_element_type || !operand_element_type) {
2342 return nullptr;
2343 }
2344
2345 const bool is_unsigned = operand_element_type.isUnsigned();
2346 const bool involves_bool = operand_element_type.getWidth() == 1 ||
2347 result_element_type.getWidth() == 1;
2348 const int output_bitwidth = result_element_type.getWidth();
2349   // The integer cast op behaves like a C integer cast. Depending on the
2350   // operand type's signedness, we determine whether or not sign extension is
2351   // needed.
2352 auto cast = [&](APInt value) {
2353 if (involves_bool) {
2354       // Handle boolean inputs or outputs explicitly as they don't have the same
2355       // behavior as extension or truncation.
2356       // A true input should always be cast to 1, not to -1 as sign extension
2357       // would do for signed outputs. Similarly, non-zero inputs should be cast
2358       // to true. Truncating even numbers to one bit would result in `false`.
2359 return APInt(result_element_type.getWidth(), value != 0);
2360 }
2361 return is_unsigned ? value.zextOrTrunc(output_bitwidth)
2362 : value.sextOrTrunc(output_bitwidth);
2363 };
2364
2365 return elements_attr.mapValues(result_element_type, cast);
2366 }
2367
2368 //===----------------------------------------------------------------------===//
2369 // SelectV2Op
2370 //===----------------------------------------------------------------------===//
2371
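// Builds a TFL::SelectV2Op. The result type is the broadcast of the condition
// shape and the broadcasted x/y shape when all of them are static; otherwise
// the result is an unranked tensor of x's element type.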
2372 static void BuildSelectV2Op(Builder *builder, OperationState &result,
2373 Value cond, Value x, Value y) {
2374 auto operand_type =
2375 OpTrait::util::getBroadcastedType(x.getType(), y.getType());
2376
2377 if (!operand_type)
2378 emitError(result.location) << "non-broadcastable operands: " << x.getType()
2379 << " and " << y.getType();
2380
2381 bool has_static_cond_shape = false;
2382 bool has_static_operand_shape = false;
2383 ArrayRef<int64_t> cond_shape;
2384 ArrayRef<int64_t> operand_shape;
2385
2386 if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
2387 if (shaped_type.hasStaticShape()) {
2388 has_static_cond_shape = true;
2389 cond_shape = shaped_type.getShape();
2390 }
2391 }
2392 if (auto shaped_type = operand_type.dyn_cast<ShapedType>()) {
2393 if (shaped_type.hasStaticShape()) {
2394 has_static_operand_shape = true;
2395 operand_shape = shaped_type.getShape();
2396 }
2397 }
2398
2399 SmallVector<int64_t, 4> broadcastedShape;
2400 if (has_static_cond_shape && has_static_operand_shape &&
2401 !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
2402 broadcastedShape)) {
2403 emitError(result.location) << "non-broadcastable operands: " << operand_type
2404 << " and " << cond.getType();
2405 }
2406
2407 result.addOperands({cond, x, y});
2408
2409 auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
2410 if (has_static_cond_shape && has_static_operand_shape) {
2411 result.types.push_back(
2412 RankedTensorType::get(broadcastedShape, elementType));
2413 } else {
2414 result.types.push_back(UnrankedTensorType::get(elementType));
2415 }
2416 }
2417
2418 //===----------------------------------------------------------------------===//
2419 // RangeOp
2420 //===----------------------------------------------------------------------===//
2421
2422 namespace {
2423
2424 // Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`.
2425 // Template parameter `FloatOrInt` must be a standard C integer or
2426 // floating-point type.
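// For example, GetLengthOfRange(2, 9, 3) returns 3 (the values 2, 5, 8), and
// GetLengthOfRange(0.0f, 1.0f, 0.3f) returns 4 (0.0, 0.3, 0.6, 0.9).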
2427 template <typename FloatOrInt>
2428 int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
2429 // Refer to the implementation in
2430 // tensorflow/lite/kernels/range.cc.
2431 return std::is_integral<FloatOrInt>::value
2432 ? ((std::abs(limit - start) + std::abs(delta) - 1) /
2433 std::abs(delta))
2434 : std::ceil(std::abs((limit - start) / delta));
2435 }
2436
2437 // Builds a constant range tensor of `result_elem_type` elements.
2438 // Template parameter `FloatOrIntAtrr` must be mlir::IntegerAttr or
2439 // mlir::FloatAttr.
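// For example, with start = 1, delta = 2, and num_elements = 4 the result is
// the 1-D tensor [1, 3, 5, 7].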
2440 template <typename FloatOrIntAtrr>
2441 DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
2442 FloatOrIntAtrr start_attr,
2443 FloatOrIntAtrr delta_attr) {
2444 using ValueType = typename FloatOrIntAtrr::ValueType; // APInt or APFloat
2445 ValueType start = start_attr.getValue();
2446 ValueType delta = delta_attr.getValue();
2447
2448 SmallVector<ValueType, 16> new_values;
2449 new_values.reserve(num_elements);
2450 ValueType new_value = start;
2451 for (int i = 0; i < num_elements; ++i) {
2452 new_values.push_back(new_value);
2453 new_value = new_value + delta;
2454 }
2455 // Result is always a 1-D tensor.
2456 auto new_result_type =
2457 RankedTensorType::get({num_elements}, result_elem_type);
2458 return DenseElementsAttr::get(new_result_type, new_values);
2459 }
2460 } // namespace
2461
2462 OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
2463 assert(operands.size() == 3);
2464 auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
2465 auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
2466 auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
2467 if (start_tensor && limit_tensor && delta_tensor) {
2468 // Operands should all be scalars
2469 assert(start_tensor.getType().getRank() == 0 &&
2470 limit_tensor.getType().getRank() == 0 &&
2471 delta_tensor.getType().getRank() == 0);
2472 Type elem_type = getType().cast<ShapedType>().getElementType();
2473 if (elem_type.isSignlessInteger()) {
2474 auto start_attr = start_tensor.getValue<IntegerAttr>({});
2475 auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
2476 auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
2477 const int num_elements = GetLengthOfRange(
2478 start_attr.getInt(), limit_attr.getInt(), delta_attr.getInt());
2479 return BuildConstRangeTensor(elem_type, num_elements, start_attr,
2480 delta_attr);
2481 } else if (elem_type.isa<FloatType>()) {
2482 auto start_attr = start_tensor.getValue<FloatAttr>({});
2483 auto limit_attr = limit_tensor.getValue<FloatAttr>({});
2484 auto delta_attr = delta_tensor.getValue<FloatAttr>({});
2485 const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
2486 limit_attr.getValueAsDouble(),
2487 delta_attr.getValueAsDouble());
2488 return BuildConstRangeTensor(elem_type, num_elements, start_attr,
2489 delta_attr);
2490 }
2491 }
2492
2493 return nullptr;
2494 }
2495
2496 //===----------------------------------------------------------------------===//
2497 // TransposeConvOp
2498 //===----------------------------------------------------------------------===//
2499
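// Verifies a tfl.transpose_conv op: the rank of the output must equal the
// number of elements in 'output_shape', and, when 'output_shape' is a
// constant, the output type must be shape compatible with the shape it
// specifies.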
2500 static LogicalResult Verify(TransposeConvOp op) {
2501 ShapedType output_type = op.output().getType().cast<ShapedType>();
2502 ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
2503 if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
2504 if (output_type.getRank() != output_shape_type.getDimSize(0)) {
2505 return op.emitOpError(llvm::formatv(
2506 "expect output type has rank = {0}, got output type {1}",
2507 output_shape_type.getDimSize(0), output_type));
2508 }
2509 }
2510
2511 DenseIntElementsAttr output_shape_elements;
2512 if (!matchPattern(op.output_shape(), m_Constant(&output_shape_elements))) {
2513 return success();
2514 }
2515
2516 llvm::SmallVector<int64_t, 4> output_shape;
2517 output_shape.reserve(output_shape_elements.getNumElements());
2518 for (auto dim : output_shape_elements.getValues<int>()) {
2519 output_shape.push_back(dim);
2520 }
2521
2522 auto expected_output_type =
2523 RankedTensorType::get(output_shape, output_type.getElementType());
2524 if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
2525 return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
2526 expected_output_type, output_type));
2527 }
2528
2529 return success();
2530 }
2531
2532 //===----------------------------------------------------------------------===//
2533 // TransposeOp
2534 //===----------------------------------------------------------------------===//
2535
2536 namespace {
2537
2538 // Computes the permutation of a constant `input_tensor` according to `perm`.
2539 // The function recursively traverses the dimensions of the output tensor in
2540 // a row-major order and writes the value in the output tensor into
2541 // `new_values`.
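// For example, transposing the constant [[1, 2, 3], [4, 5, 6]] of type
// tensor<2x3xi32> with perm = [1, 0] produces [[1, 4], [2, 5], [3, 6]] of
// type tensor<3x2xi32>.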
2542 void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
2543 ArrayRef<int64_t> output_shape, int num_dimensions,
2544 int output_axis, std::vector<uint64_t> *input_indices,
2545 std::vector<Attribute> *new_values) {
2546 // Refer to the implementation of `Transpose` function in
2547 // tensorflow/lite/kernels/internal/reference/reference_ops.h
2548 assert(output_axis < num_dimensions);
2549 const int input_axis = perm[output_axis];
2550 for (int i = 0; i < output_shape[output_axis]; ++i) {
2551 // Update the input indices on `input_axis`.
2552 input_indices->at(input_axis) = i;
2553 // Write the value from `input_tensor` if it is the last axis or
2554 // recurse into the next axis.
2555 const bool is_last_axis = output_axis == num_dimensions - 1;
2556 if (is_last_axis) {
2557 new_values->push_back(input_tensor.getValue(*input_indices));
2558 } else {
2559 ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
2560 output_axis + 1, input_indices, new_values);
2561 }
2562 }
2563 }
2564
2565 } // namespace
2566
2567 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
2568 assert(operands.size() == 2);
2569 auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
2570 auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
2571 if (!input_tensor || !perm_tensor) return nullptr;
2572
2573 // Do not try to fold elements attr of a quant type because
2574 // DenseElementsAttr does not support it.
2575 if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
2576 return nullptr;
2577
2578 assert(perm_tensor.getType().getRank() == 1);
2579 const int num_dimensions = input_tensor.getType().getRank();
2580 assert(perm_tensor.getType().getNumElements() == num_dimensions);
2581
2582 ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
2583 auto output_type = getType().cast<ShapedType>();
2584
2585 SmallVector<int32_t, 4> perm;
2586 SmallVector<int64_t, 4> output_shape;
2587 for (int i = 0; i < num_dimensions; ++i) {
2588 perm.push_back(
2589 perm_tensor.getValue<IntegerAttr>({static_cast<uint64_t>(i)}).getInt());
2590 output_shape.push_back(input_shape[perm[i]]);
2591
2592 // Check that the derived output shape matches the static shape.
2593 assert(!output_type.hasStaticShape() ||
2594 output_type.getShape()[i] == output_shape[i]);
2595 }
2596
2597 std::vector<Attribute> new_values;
2598 new_values.reserve(input_tensor.getType().getNumElements());
2599 std::vector<uint64_t> input_indices(num_dimensions);
2600 ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
2601 /*output_axis=*/0, &input_indices, &new_values);
2602 auto result_type =
2603 RankedTensorType::get(output_shape, output_type.getElementType());
2604 return DenseElementsAttr::get(result_type, new_values);
2605 }
2606
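// Verifies a tfl.transpose op: 'perm' must have as many elements as the input
// has dimensions, each permutation axis must be unique and within [0, rank),
// and, when shapes are static, the output shape must match the permuted input
// shape.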
2607 static LogicalResult Verify(TransposeOp op) {
2608 auto input_type = op.input().getType().cast<ShapedType>();
2609 auto perm_type = op.perm().getType().cast<ShapedType>();
2610 auto output_type = op.output().getType().cast<ShapedType>();
2611 if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
2612 if (perm_type.getNumElements() != input_type.getRank()) {
2613 return op.emitOpError(
2614 "perm tensor elements size is not equal to input tensor rank");
2615 }
2616 }
2617
2618 DenseIntElementsAttr perm;
2619 if (!matchPattern(op.perm(), m_Constant(&perm))) {
2620 return success();
2621 }
2622
2623 int index = 0;
2624 llvm::SmallVector<int64_t, 4> axes;
2625 for (const auto &axis_int : perm.getValues<APInt>()) {
2626 const int64_t axis = axis_int.getSExtValue();
2627 if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
2628 return op.emitOpError(
2629 llvm::formatv("perm[{0}] must be in [0, rank)", index));
2630 }
2631 if (std::count(axes.begin(), axes.end(), axis) > 0) {
2632 return op.emitOpError(
2633 llvm::formatv("perm[{0}] cannot have duplicated axis", index));
2634 }
2635 axes.push_back(axis);
2636 index++;
2637 }
2638
2639 if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
2640 llvm::SmallVector<int64_t, 4> transposed_shape;
2641 for (int64_t axis : axes) {
2642 transposed_shape.push_back(input_type.getDimSize(axis));
2643 }
2644 auto expected_output_type =
2645 RankedTensorType::get(transposed_shape, input_type.getElementType());
2646 if (failed(
2647 mlir::verifyCompatibleShape(output_type, expected_output_type))) {
2648 return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
2649 expected_output_type, output_type));
2650 }
2651 }
2652
2653 return success();
2654 }
2655
2656 //===----------------------------------------------------------------------===//
2657 // WhileOp
2658 //===----------------------------------------------------------------------===//
2659
2660 LogicalResult Verify(WhileOp op) {
2661 if (op.getNumOperands() != op.getNumResults())
2662 return op.emitOpError(llvm::formatv(
2663 "number of operands does not match number of results ({0} != {1})",
2664 op.getNumOperands(), op.getNumResults()));
2665 // TODO(jpienaar): Verify operand, result & block arguments types
2666 return success();
2667 }
2668
2669 namespace {
2670 // Canonicalize While op so that results and operands match and external values
2671 // are captured implicitly rather than passed via block args.
2672 struct WhileResultOperandsMatchAndImplicitCapture
2673 : public OpRewritePattern<WhileOp> {
2674 using OpRewritePattern<WhileOp>::OpRewritePattern;
2675
2676   LogicalResult matchAndRewrite(WhileOp while_op,
2677 PatternRewriter &rewriter) const override {
2678 // Replace values simply passed through the body with extern values
2679 // (in both body and condition regions as well as while result). The
2680 // block arguments of body and while match and so the corresponding cond
2681 // argument can be easily found.
2682 bool unchanged = true;
2683 auto &body_block = while_op.body().front();
2684 auto &cond_block = while_op.cond().front();
2685 auto &yield = *body_block.getTerminator();
2686 for (auto ba : body_block.getArguments()) {
2687 int arg_no = ba.getArgNumber();
2688 // Skip removing resources that are not read-only variables.
2689 if (getElementTypeOrSelf(ba.getType()).isa<TF::ResourceType>()) {
2690 bool has_read_only_variables = true;
2691 for (auto user : ba.getUsers()) {
2692           // Terminator ops, for example tfl::yield, should be ignored since the
2693           // argument can be yielded as the `body` function result, and that tells
2694           // us nothing about whether the given argument is a read-only variable
2695           // or not.
2696 if (user->hasTrait<OpTrait::IsTerminator>()) continue;
2697 if (!llvm::isa<mlir::TF::ReadVariableOp>(user)) {
2698 has_read_only_variables = false;
2699 break;
2700 }
2701 }
2702 if (!has_read_only_variables) continue;
2703 }
      if (ba == yield.getOperand(arg_no)) {
        unchanged = false;
        auto value = while_op.getOperand(arg_no);
        ba.replaceAllUsesWith(value);
        cond_block.getArgument(arg_no).replaceAllUsesWith(value);

        // This could be relaxed and casts inserted.
        if (while_op.getResult(arg_no).getType() == value.getType())
          while_op.getResult(arg_no).replaceAllUsesWith(value);
      }
    }

    // The While op's operand and result types need to match.
    SmallVector<Value, 4> new_operands;
    SmallVector<Value, 4> new_body_yield;
    SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
    llvm::SmallVector<Type, 4> types;
    new_operands.reserve(while_op.getNumOperands());
    new_body_yield.reserve(while_op.getNumOperands());
    types.reserve(while_op.getNumOperands());

    // Remove block arguments that are unused in both cond and body and whose
    // corresponding while result is also unused. This keeps the block
    // arguments of body and cond matching.
    int arg_index = 0;
    for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
         ++while_index) {
      auto value = while_op.getOperand(while_index);
      if (body_block.getArgument(arg_index).use_empty() &&
          cond_block.getArgument(arg_index).use_empty() &&
          // Note: since we are not erasing results, need to use while_index
          // to check if the corresponding result is unused.
          while_op.getResult(while_index).use_empty()) {
        unchanged = false;
        body_block.eraseArgument(arg_index);
        cond_block.eraseArgument(arg_index);

        // Mark operand for removal.
        removed_operand[while_index] = true;
      } else {
        new_operands.push_back(value);
        new_body_yield.push_back(yield.getOperand(while_index));
        auto type = while_op.getResult(while_index).getType();
        types.push_back(type);
        ++arg_index;
      }
    }

    // Done if no values removed from blocks and operands & results match.
    if (unchanged) return failure();

    // Replace with new While with matching operands and results.
    Operation *op = while_op.getOperation();
    Operation *new_op = rewriter.insert(
        Operation::create(op->getLoc(), op->getName(), types, new_operands,
                          op->getAttrs(), {}, /*numRegions=*/2));

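    // Move the cond and body regions from the old While op into the new one.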
    for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
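    // Remap uses of the old results to the new results, skipping results whose
    // operands were pruned above.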
    int new_index = 0;
    for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
      if (removed_operand[op_index]) continue;
      op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
      ++new_index;
    }
    rewriter.eraseOp(op);

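    // Rebuild the body terminator so it only yields values for the remaining
    // operands/results.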
    Block &new_body_block = cast<WhileOp>(new_op).body().front();
    rewriter.setInsertionPointToEnd(&new_body_block);
    rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
                                         new_body_yield);

    return success();
  }
};

}  // namespace

void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<WhileResultOperandsMatchAndImplicitCapture>(context);
}

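// LoopLikeOpInterface hooks. These let generic loop transformations (e.g.
// loop-invariant code motion) query the loop body and hoist ops out of it.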
Region &WhileOp::getLoopBody() { return body(); }

bool WhileOp::isDefinedOutsideOfLoop(Value value) {
  // TODO(jpienaar): This is overly conservative and initially disables
  // anything other than constant hoisting.
  return false;
}

LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
  if (ops.empty()) return success();

  // Move the hoisted values to just before the while.
  Operation *while_op = this->getOperation();
  for (auto op : ops) op->moveBefore(while_op);

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"

}  // namespace TFL
}  // namespace mlir

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"

namespace mlir {
namespace TFL {

#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"

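// Dialect hook used by the folding infrastructure to materialize a constant
// attribute as an op: elements attributes whose type does not match (and
// opaque ones) become a tfl.pseudo_const; otherwise fall back to ConstantOp
// when it can represent the value.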
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
                                                      Attribute value,
                                                      Type type, Location loc) {
  // If this is an opaque elements attribute or the result type doesn't match
  // the attribute type, then generate a tfl.pseudo_const.
  if (value.isa<OpaqueElementsAttr>() ||
      (value.isa<ElementsAttr>() && value.getType() != type))
    return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}

}  // namespace TFL
}  // namespace mlir