1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 
24 #include "third_party/eigen3/Eigen/Core"
25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/APInt.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/FormatVariadic.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
33 #include "mlir/IR/Attributes.h"  // from @llvm-project
34 #include "mlir/IR/Builders.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
36 #include "mlir/IR/Location.h"  // from @llvm-project
37 #include "mlir/IR/Matchers.h"  // from @llvm-project
38 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
39 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
40 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
41 #include "mlir/Support/LLVM.h"  // from @llvm-project
42 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
43 #include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
44 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
45 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
46 #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
49 
50 namespace mlir {
51 namespace TFL {
52 
53 // Returns true when the given operand arguments have the same shape or a
54 // broadcastable shape within the given rank. If any given shape is
55 // non-static and the maximum rank is within the given rank, this method
56 // returns true.
57 bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
58     Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
59   if (indices.empty()) return true;
60 
61   // First, check whether any input has an unknown rank.
62   bool has_unknown_shape_input = false;
63   bool has_same_shape = true;
64   bool reach_first_known_shape = false;
65   int64_t max_rank = -1;
66 
67   ArrayRef<int64_t> pivot_shape;
68   SmallVector<int64_t, 4> current_shape;
69   SmallVector<int64_t, 4> result_shape;
70 
71   for (unsigned index : indices) {
72     ShapedType shaped_type =
73         op->getOperand(index).getType().dyn_cast<ShapedType>();
74     if (!shaped_type || !shaped_type.hasRank()) {
75       // Marks that we have an unknown rank input.
76       has_unknown_shape_input = true;
77       continue;
78     }
79     max_rank = std::max(max_rank, shaped_type.getRank());
80     if (!shaped_type.hasStaticShape()) {
81       // Marks that we have an unknown shape input.
82       has_unknown_shape_input = true;
83       continue;
84     }
85 
86     ArrayRef<int64_t> shape = shaped_type.getShape();
87     if (!reach_first_known_shape) {
88       pivot_shape = shape;
89       current_shape.assign(shape.begin(), shape.end());
90       reach_first_known_shape = true;
91       continue;
92     }
93 
94     if (!pivot_shape.equals(shape)) {
95       has_same_shape = false;
96     }
97   // Check that all the inputs are broadcastable, since they do not all have
98   // the same shape.
99     if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
100                                             result_shape)) {
101       return false;
102     }
103     current_shape = result_shape;
104   }
105 
106   // Treat inputs with unknown shapes as acceptable for model compatibility,
107   // unless there is a known rank that is bigger than the allowed maximum
108   // broadcast rank.
109   if (has_unknown_shape_input) return max_rank <= max_bcast_rank;
110 
111   // If all shapes are known and identical, CPU kernels are able to handle
112   // inputs regardless of dimension size.
113   return has_same_shape || max_rank <= max_bcast_rank;
114 }
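// Illustration (hypothetical operand shapes, not tied to any particular op):
// with static operand shapes {2, 3} and {1, 3}, the running broadcasted shape
// is {2, 3}; adding a third operand of shape {4, 1, 3} gives {4, 2, 3}. Since
// the shapes are not all equal, the check passes only when max_bcast_rank >= 3.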
115 
116 // Return true when the given element_type is QI8.
117 bool IsQI8Type(Type element_type) {
118   auto quantized_type = element_type.dyn_cast<QuantizedType>();
119   return quantized_type != nullptr &&
120          quantized_type.getStorageTypeIntegralWidth() == 8 &&
121          quantized_type.isSigned();
122 }
123 
124 // Return true when the given element_type is QUI8.
125 bool IsQUI8Type(Type element_type) {
126   auto quantized_type = element_type.dyn_cast<QuantizedType>();
127   return quantized_type != nullptr &&
128          quantized_type.getStorageTypeIntegralWidth() == 8 &&
129          !quantized_type.isSigned();
130 }
131 
132 // Return true when the given element_type is QI16.
133 bool IsQI16Type(Type element_type) {
134   auto quantized_type = element_type.dyn_cast<QuantizedType>();
135   return quantized_type != nullptr &&
136          quantized_type.getStorageTypeIntegralWidth() == 16 &&
137          quantized_type.isSigned();
138 }
139 
140 // Return true when the given element_type is I32.
141 bool IsI32Type(Type element_type) {
142   return element_type.isInteger(32) && !element_type.isUnsignedInteger();
143 }
144 
145 // Return true when the given element_type is I64.
146 bool IsI64Type(Type element_type) {
147   return element_type.isInteger(64) && !element_type.isUnsignedInteger();
148 }
149 
150 // Return true if the value is a splat tensor constant zero.
151 bool EqualsZero(Value value) {
152   DenseElementsAttr constant;
153   if (!matchPattern(value, m_Constant(&constant)) || !constant.isSplat()) {
154     return false;
155   }
156 
157   Type element_type = value.getType().cast<ShapedType>().getElementType();
158   if (element_type.isa<FloatType>()) {
159     return constant.getSplatValue<APFloat>().isZero();
160   } else {
161     return false;
162   }
163 }
164 
165 // Replaces the bias operand with a "none" type value if the bias value is
166 // constant zero.
167 // `ConcreteOpType` must be a concrete MLIR op class that has an optional
168 // bias operand named 'bias'.
169 template <typename ConcreteOpType>
170 struct RemoveOptionalZeroBias : public OpRewritePattern<ConcreteOpType> {
171   using OpRewritePattern<ConcreteOpType>::OpRewritePattern;
172 
173   LogicalResult matchAndRewrite(ConcreteOpType op,
174                                 PatternRewriter &rewriter) const override {
175     if (EqualsZero(op.bias())) {
176       auto none_value = rewriter.create<mlir::ConstantOp>(
177           rewriter.getUnknownLoc(), rewriter.getUnitAttr());
178       op.biasMutable().assign(none_value);
179     }
180 
181     return success();
182   }
183 };
184 
185 // Return true if the given Add operation has CPU-kernel-supported shapes.
186 bool VerifyAddOpShapeConstraints(AddOp op) {
187   auto element_type = getElementTypeOrSelf(op.output().getType());
188 
189   // Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
190   // which are broadcastable shapes up to four dimensions or the same shape.
191   if (element_type.isF32() || IsQI8Type(element_type) ||
192       IsQUI8Type(element_type) || IsI32Type(element_type)) {
193     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
194         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
195         /*max_bcast_rank=*/4);
196   }
197 
198   // Allows QI16 output when operands have the same shape.
199   if (IsQI16Type(element_type)) {
200     return succeeded(
201         mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
202   }
203   return false;
204 }
205 
206 // Return true if the given Sub operation has CPU-kernel-supported shapes.
207 bool VerifySubOpShapeConstraints(SubOp op) {
208   auto element_type = getElementTypeOrSelf(op.output().getType());
209 
210   // Allows F32, I32, I64, QUI8 and QI16 outputs when the operands have valid
211   // shapes: broadcastable shapes up to five dimensions or the same shape.
212   if (element_type.isF32() || IsI32Type(element_type) ||
213       IsI64Type(element_type) || IsQUI8Type(element_type) ||
214       IsQI16Type(element_type)) {
215     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
216         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
217         /*max_bcast_rank=*/5);
218   }
219 
220   // Allows QI8 output when the operands have valid shapes, which are
221   // broadcastable shapes up to four dimensions or the same shape.
222   if (IsQI8Type(element_type)) {
223     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
224         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
225         /*max_bcast_rank=*/4);
226   }
227   return false;
228 }
229 
230 // Return true if the given Mul operation has CPU-kernel-supported shapes.
231 bool VerifyMulOpShapeConstraints(MulOp op) {
232   auto element_type = getElementTypeOrSelf(op.output().getType());
233 
234   // Allows QI8 and QUI8 outputs with broadcasting up to four dimensions,
235   // unless the lhs operand type is QI16, in which case only operands with the
236   // same shape are allowed.
237   if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
238     if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
239       return succeeded(
240           mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
241     }
242     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
243         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
244         /*max_bcast_rank=*/4);
245   }
246 
247   // Allows I32, QI16 and F32 outputs when the operands have valid shapes,
248   // which are broadcastable shapes up to four dimensions or the same shape.
249   if (IsI32Type(element_type) || IsQI16Type(element_type) ||
250       element_type.isF32()) {
251     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
252         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
253         /*max_bcast_rank=*/4);
254   }
255   return false;
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // TensorFlowLiteDialect
260 //===----------------------------------------------------------------------===//
261 
262 struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
263   using DialectInlinerInterface::DialectInlinerInterface;
264 
265   //===--------------------------------------------------------------------===//
266   // Analysis Hooks
267   //===--------------------------------------------------------------------===//
268 
269   // Allow all call operations to be inlined.
270   bool isLegalToInline(Operation *call, Operation *callable,
271                        bool wouldBeCloned) const final {
272     return true;
273   }
274   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
275                        BlockAndValueMapping &) const final {
276     // No TFLite op restricts inlining today, revise as needed in the future.
277     return true;
278   }
279   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
280                        BlockAndValueMapping &valueMapping) const final {
281     return isa<WhileOp>(dest->getParentOp());
282   }
283 };
284 
285 struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface {
286   using DialectFoldInterface::DialectFoldInterface;
287 
288   // Registered hook to check if the given region, which is attached to an
289   // operation that is *not* isolated from above (i.e. no internal regions
290   // reference values defined in an enclosing region), should be used when
291   // materializing constants.
292   // In the TFLite dialect we materialize inside WhileOp regions, as it is
293   // slightly more efficient computationally.
294   bool shouldMaterializeInto(Region *region) const final {
295     return isa<WhileOp>(region->getParentOp());
296   }
297 };
298 
299 TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
300     : Dialect(/*name=*/"tfl", context, TypeID::get<TensorFlowLiteDialect>()) {
301   addOperations<
302 #define GET_OP_LIST
303 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
304       >();
305   addInterfaces<TensorFlowLiteInlinerInterface,
306                 TensorFlowLiteDialectFoldInterface>();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // Common support logic
311 //===----------------------------------------------------------------------===//
312 
313 namespace {
314 
315 // Returns true if the dimensions in `a` are a suffix of the ones in `b`.
316 // For example, dimensions {2}, {1, 2}, and {3, 1, 2} are all suffixes to
317 // {5, 4, 3, 1, 2}, while {1}, {5, 4}, and {1, 3, 2} are all not.
318 inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
319   if (a.size() > b.size()) return false;
320 
321   return std::equal(a.rbegin(), a.rend(), b.rbegin());
322 }
323 
324 // Returns true if it is a shaped type of f32 elements.
325 inline bool IsF32ShapedType(Type t) {
326   if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
327     return shaped_type.getElementType().isF32();
328   }
329   return false;
330 }
331 
332 // Returns true if it is a shaped type of bf16 elements.
333 inline bool IsBF16ShapedType(Type t) {
334   if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
335     return shaped_type.getElementType().isBF16();
336   }
337   return false;
338 }
339 
340 // Returns a new shape with rank 'new_dims', padded with ones on the
341 // left if needed.
342 inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
343                                            int new_dims) {
344   std::vector<int64_t> new_shape(new_dims, 1);
345   std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
346   return new_shape;
347 }
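// For example (hypothetical values):
//   GetPaddedShape(/*old_shape=*/{3, 2}, /*new_dims=*/4) returns {1, 1, 3, 2}.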
348 
349 // Helper method that, given a 'current_index' representing an index in the
350 // broadcasted tensor, returns the index in the flat original tensor.
351 // 'shape' is the original shape with padding to match the result shape.
352 int64_t GetElementIndex(const std::vector<int64_t> &shape,
353                         const std::vector<int64_t> &current_index) {
354   int64_t ind = 0;
355   int64_t mul = 1;
356   for (int i = shape.size() - 1; i >= 0; --i) {
357     ind += (current_index[i] % shape[i]) * mul;
358     mul *= shape[i];
359   }
360   return ind;
361 }
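// For example (hypothetical values), with padded original shape {1, 3} and
// broadcasted index {2, 1}, the flat index is (1 % 3) * 1 + (2 % 1) * 3 = 1,
// i.e. the broadcast re-reads the single row of the original tensor.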
362 
363 // Helper method that increments the index represented by 'current_index_ptr'
364 // within the shape of 'result_shape'.
365 void IncrementIndex(ArrayRef<int64_t> result_shape,
366                     std::vector<int64_t> *current_index_ptr) {
367   std::vector<int64_t> &current_index = *current_index_ptr;
368   for (int i = result_shape.size() - 1; i >= 0; --i) {
369     current_index[i]++;
370     if (current_index[i] == result_shape[i]) {
371       current_index[i] = 0;
372     } else {
373       break;
374     }
375   }
376 }
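// For example (hypothetical values), incrementing index {0, 2} within result
// shape {2, 3} yields {1, 0}: the last dimension wraps to 0 and the carry
// increments the next dimension to the left.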
377 
378 /// Performs const folding `calculate` with broadcast behavior on the two
379 /// attributes `lhs` and `rhs` and returns the result if possible.
380 /// This function assumes both operands are verified to have value
381 /// attributes of broadcastable types.
382 template <class AttrElementT,
383           class ElementValueT = typename AttrElementT::ValueType,
384           class CalculationT =
385               llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
386 Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
387                                       DenseElementsAttr rhs,
388                                       const CalculationT &calculate) {
389   auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
390                   .dyn_cast_or_null<ShapedType>();
391   if (!type) {
392     return {};
393   }
394 
395   const bool rhs_is_splat = rhs.isSplat();
396   const bool lhs_is_splat = lhs.isSplat();
397 
398   // If both of them are splat, compute and return.
399   if (lhs_is_splat && rhs_is_splat) {
400     auto element_result = AttrElementT::get(
401         type.getElementType(), calculate(lhs.getSplatValue<ElementValueT>(),
402                                          rhs.getSplatValue<ElementValueT>()));
403     if (!element_result) return {};
404 
405     return DenseElementsAttr::get(type, element_result);
406   }
407 
408   auto num_elements = type.getNumElements();
409 
410   SmallVector<ElementValueT, 16> new_values;
411   new_values.reserve(num_elements);
412   const auto result_shape = type.getShape();
413   std::vector<int64_t> current_index(type.getRank(), 0);
414   // Create the new shape with ones padded to the left.
415   const std::vector<int64_t> lhs_new_shape =
416       GetPaddedShape(lhs.getType().getShape(), type.getRank());
417   const std::vector<int64_t> rhs_new_shape =
418       GetPaddedShape(rhs.getType().getShape(), type.getRank());
419 
420   auto lhs_old_values = lhs.getValues<ElementValueT>();
421   auto rhs_old_values = rhs.getValues<ElementValueT>();
422 
423   // Combine each pair of corresponding values in the dense elements
424   // attributes using `calculate`.
425   for (int64_t i = 0; i < num_elements; ++i) {
426     // current_index represents the index
427     // in the N-dimension tensor. GetElementIndex returns
428     // the index in the flat representation of the original tensor
429     // to use.
430     const int64_t lhs_index =
431         lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
432     const int64_t rhs_index =
433         rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);
434 
435     new_values.push_back(calculate(*(lhs_old_values.begin() + lhs_index),
436                                    *(rhs_old_values.begin() + rhs_index)));
437     IncrementIndex(result_shape, &current_index);
438   }
439   return DenseElementsAttr::get(type, ArrayRef<ElementValueT>(new_values));
440 }
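// As an illustration (hypothetical operands, not taken from a real model):
// folding lhs = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> with
// rhs = dense<[10, 20]> : tensor<2xi32> under addition walks the 2x2 result
// index space, maps each result index back into the (padded) operand shapes,
// and yields dense<[[11, 22], [13, 24]]>.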
441 
442 /// Performs const folding `calculate` with broadcast behavior on the two
443 /// attributes `operand1` and `operand2` and returns the result if possible.
444 /// This function assumes the two operands are verified to have value
445 /// attributes of broadcastable types.
446 template <class AttrElementT,
447           class ElementValueT = typename AttrElementT::ValueType,
448           class CalculationT =
449               llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
450 Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
451                             Attribute operand2, const CalculationT &calculate) {
452   if (operand1.dyn_cast_or_null<DenseElementsAttr>() &&
453       operand2.dyn_cast_or_null<DenseElementsAttr>()) {
454     return ConstFoldBinaryOpDenseDense<AttrElementT, ElementValueT>(
455         result_type, operand1.cast<DenseElementsAttr>(),
456         operand2.cast<DenseElementsAttr>(), calculate);
457   }
458 
459   // TODO: support other attribute kinds
460 
461   return {};
462 }
463 
464 /// Performs const folding with broadcast behavior on the two attributes in
465 /// `operands` and returns the result if possible.
466 /// Depending on the given `result_type`, either `float_calculate` or
467 /// `int_calculate` is chosen to perform the calculation.
468 Attribute ConstFoldBinaryOp(
469     Type result_type, ArrayRef<Attribute> operands,
470     llvm::function_ref<APFloat(APFloat, APFloat)> float_calculate,
471     llvm::function_ref<APInt(APInt, APInt)> int_calculate) {
472   // Note: All types are wrapped in tensor types in TFLite. E.g., f32 is
473   // represented as tensor<f32>. So we are only handling tensor types here.
474   auto type = result_type.dyn_cast<ShapedType>();
475   if (!type) return {};
476 
477   auto elemType = type.getElementType();
478 
479   if (elemType.isa<FloatType>())
480     return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
481                                         float_calculate);
482 
483   if (elemType.isSignlessInteger())
484     return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
485                                           int_calculate);
486 
487   return {};
488 }
489 
490 /// Performs const folding of `calculate` on the attribute `operand` and
491 /// returns the result if possible.
492 /// The function currently asserts that `result_type` is an f32 or bf16
493 /// tensor type.
494 /// TODO: Extend this function to handle integral tensor for ops like
495 /// "tfl.logical_not".
496 Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
497                            llvm::function_ref<APFloat(APFloat)> calculate) {
498   assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
499   auto result_shape_type = result_type.cast<ShapedType>();
500 
501   if (!result_shape_type.hasStaticShape()) return {};
502 
503   if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
504     SmallVector<APFloat, 16> new_values;
505     const int num_elements = result_shape_type.getNumElements();
506     new_values.reserve(num_elements);
507 
508     for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
509       new_values.push_back(calculate(old_value));
510     }
511 
512     return DenseElementsAttr::get(result_shape_type, new_values);
513   }
514 
515   return {};
516 }
517 
518 void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
519                           Value rhs) {
520   auto result_type =
521       OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
522   if (!result_type)
523     emitError(result.location)
524         << "non-broadcastable operands: " << lhs.getType() << " and "
525         << rhs.getType();
526   result.addOperands({lhs, rhs});
527   // Comparison binary ops always return i1 tensor.
528   if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
529     auto result_shape = shaped_type.getShape();
530     result.types.push_back(
531         RankedTensorType::get(result_shape, builder->getI1Type()));
532   } else {
533     result.types.push_back(UnrankedTensorType::get(builder->getI1Type()));
534   }
535 }
536 
537 void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
538                                   Value lhs, Value rhs,
539                                   StringAttr fused_activation_function) {
540   auto result_type =
541       OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
542 
543   if (!result_type)
544     emitError(result.location)
545         << "non-broadcastable operands: " << lhs.getType() << " and "
546         << rhs.getType();
547 
548   result.addOperands({lhs, rhs});
549   result.addAttribute("fused_activation_function", fused_activation_function);
550   result.types.push_back(result_type);
551 }
552 
553 }  // end anonymous namespace
554 
555 //===----------------------------------------------------------------------===//
556 // AddOp
557 //===----------------------------------------------------------------------===//
558 
559 OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
560   // TODO(b/142478136): Handle fused ops.
561   if (fused_activation_function() != "NONE") return {};
562   return ConstFoldBinaryOp(
563       getType(), operands, [](APFloat a, APFloat b) { return a + b; },
564       [](APInt a, APInt b) { return a + b; });
565 }
566 
567 //===----------------------------------------------------------------------===//
568 // ConcatenationOp
569 //===----------------------------------------------------------------------===//
570 // TODO(ashwinm): Implement shape inference for Concatenation
571 
572 namespace {
573 
574 int64_t GetConcatenationOpAxis(ConcatenationOp op) {
575   auto output_type = op.output().getType().cast<RankedTensorType>();
576   int32_t axis = op.axis();
577   if (axis < 0) axis += output_type.getRank();
578   return axis;
579 }
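// For example (hypothetical attribute values), axis = -1 on an output of
// rank 4 resolves to axis = 3.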
580 
581 // Verify operand types and the result type:
582 //
583 // 1. Operand type ranks must be equal to the output type rank.
584 //
585 // 2. Operand dimension sizes (except dimension `axis`) must be equal to
586 //    previously seen dimension sizes of the same dimension.
587 //
588 // 3. Sum of operand dimension sizes of the `axis` dimension must be equal to
589 //    the dimension size of the `axis` dimension of output.
590 //
591 // Note: If an operand has unranked tensor type or has dynamic dimension size,
592 // those dimensions will be skipped.
593 LogicalResult VerifyConcatenationOpTypes(Operation *op,
594                                          RankedTensorType output_type,
595                                          ArrayRef<TensorType> operand_types,
596                                          int64_t axis) {
597   const int64_t output_rank = output_type.getRank();
598 
599   constexpr int64_t kDynamicSize = -1;
600   SmallVector<int64_t, 4> result_dim_sizes_loc(output_rank, -1);
601   SmallVector<int64_t, 4> result_dim_sizes(output_type.getShape().begin(),
602                                            output_type.getShape().end());
603   result_dim_sizes[axis] = 0;
604 
605   auto FormatLoc = [&result_dim_sizes_loc](int64_t dim) {
606     const int64_t loc = result_dim_sizes_loc[dim];
607     if (loc == -1) return std::string("output");
608     return llvm::formatv("operand #{0}", loc).str();
609   };
610 
611   for (auto operand : llvm::enumerate(operand_types)) {
612     auto operand_type = operand.value().dyn_cast<RankedTensorType>();
613     if (!operand_type) {
614       result_dim_sizes[axis] = kDynamicSize;
615       continue;
616     }
617 
618     const int64_t operand_rank = operand_type.getRank();
619     if (operand_rank != output_rank)
620       return op->emitOpError() << "rank of operand #" << operand.index()
621                                << " must be equal to rank of output, expected "
622                                << output_rank << ", got " << operand_rank;
623 
624     for (int64_t dim = 0; dim < output_rank; ++dim) {
625       const int64_t operand_dim_size = operand_type.getDimSize(dim);
626       const int64_t result_dim_size = result_dim_sizes[dim];
627 
628       if (dim == axis) {
629         if (RankedTensorType::isDynamic(operand_dim_size) ||
630             RankedTensorType::isDynamic(result_dim_size))
631           result_dim_sizes[axis] = kDynamicSize;
632         else
633           result_dim_sizes[axis] += operand_dim_size;
634         continue;
635       }
636 
637       if (RankedTensorType::isDynamic(operand_dim_size)) continue;
638 
639       if (RankedTensorType::isDynamic(result_dim_size)) {
640         result_dim_sizes[dim] = operand_dim_size;
641         result_dim_sizes_loc[dim] = operand.index();
642         continue;
643       }
644 
645       if (result_dim_size != operand_dim_size)
646         return op->emitOpError()
647                << "dimension size of dimension #" << dim << " of operand #"
648                << operand.index() << " must be equal to "
649                << "dimension size of dimension #" << dim << " of "
650                << FormatLoc(dim) << ", expected " << result_dim_size << ", got "
651                << operand_dim_size;
652     }
653   }
654 
655   const int64_t output_concated_dim_size = output_type.getDimSize(axis);
656   if (!RankedTensorType::isDynamic(output_concated_dim_size) &&
657       !RankedTensorType::isDynamic(result_dim_sizes[axis]) &&
658       result_dim_sizes[axis] != output_concated_dim_size)
659     return op->emitOpError()
660            << "dimension size of dimension #" << axis << " of output "
661            << "must be equal to the sum of dimension sizes of dimension #"
662            << axis << ", expected " << result_dim_sizes[axis] << ", got "
663            << output_concated_dim_size;
664 
665   return success();
666 }
667 
668 LogicalResult Verify(ConcatenationOp op) {
669   auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
670 
671   // If the output type is unranked, there is nothing else to be verified.
672   if (!output_type) return success();
673 
674   const int64_t axis = GetConcatenationOpAxis(op);
675   if (axis < 0 || axis >= output_type.getRank())
676     return op.emitOpError("concatenation dimension must be in [-rank, rank)");
677 
678   SmallVector<TensorType, 4> operand_types;
679   for (Value operand : op.values())
680     operand_types.push_back(operand.getType().cast<TensorType>());
681 
682   return VerifyConcatenationOpTypes(op.getOperation(), output_type,
683                                     operand_types, axis);
684 }
685 
686 // Returns true when all operands are instances of DenseElementsAttr and the
687 // output type has a static shape.
688 bool IsConcatenationOpConstFoldable(ConcatenationOp op,
689                                     ArrayRef<Attribute> operands,
690                                     RankedTensorType output_type,
691                                     int64_t axis) {
692   if (operands.empty()) return false;
693   if (!output_type.hasStaticShape()) return false;
694   if (axis < 0) return false;
695 
696   return llvm::all_of(operands, [](Attribute operand) {
697     return operand && operand.isa<DenseElementsAttr>();
698   });
699 }
700 
701 DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
702                                               RankedTensorType output_type,
703                                               int64_t axis) {
704   const auto outer_dims = output_type.getShape().take_front(axis);
705   const int64_t outer_size = std::accumulate(
706       outer_dims.begin(), outer_dims.end(), 1, std::multiplies<int64_t>());
707 
708   const auto base_inner_dims = output_type.getShape().drop_front(axis + 1);
709   const int64_t base_inner_size =
710       std::accumulate(base_inner_dims.begin(), base_inner_dims.end(), 1,
711                       std::multiplies<int64_t>());
712 
713   // Splits each input operand into outer_size pieces and combines them in
714   // round-robin ordering.
715   std::vector<Attribute> out_attrs(output_type.getNumElements());
716   int64_t out = 0;
717   for (int64_t outer = 0; outer < outer_size; ++outer) {
718     for (auto op : operands) {
719       const int64_t dim_size =
720           op.getType().cast<RankedTensorType>().getDimSize(axis);
721       const int64_t inner_size = dim_size * base_inner_size;
722 
723       auto input_attrs = op.cast<DenseElementsAttr>().getValues<Attribute>();
724       auto input_iter = input_attrs.begin() + outer * inner_size;
725       for (int64_t inner = 0; inner < inner_size; ++inner)
726         out_attrs[out++] = *input_iter++;
727     }
728   }
729 
730   return DenseElementsAttr::get(output_type, out_attrs);
731 }
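// Worked example (hypothetical operands): concatenating
// dense<[[1, 2], [3, 4]]> and dense<[[5], [6]]> along axis = 1 gives
// outer_size = 2 and base_inner_size = 1, so the round-robin copy emits
// 1, 2, 5, 3, 4, 6, i.e. dense<[[1, 2, 5], [3, 4, 6]]>.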
732 
733 }  // end anonymous namespace
734 
735 OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
736   if (fused_activation_function() == "NONE") {
737     if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
738       const int64_t axis = GetConcatenationOpAxis(*this);
739       if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
740         return ConstFoldConcatenateOpDense(operands, output_type, axis);
741     }
742   }
743 
744   // Remove all empty values.
745   SmallVector<Value, 4> non_empty_values;
746   for (Value value : this->values()) {
747     const auto shaped_type = value.getType().cast<ShapedType>();
748     if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
749       continue;
750     }
751     non_empty_values.push_back(value);
752   }
753 
754   // If no operands were empty, do nothing.
755   if (non_empty_values.size() == getNumOperands()) return nullptr;
756 
757   // If only one input is non-empty, just return it as the result of folding.
758   if (non_empty_values.size() == 1) {
759     return non_empty_values[0];
760   }
761 
762   // Otherwise, build a new concatenation op with non-empty values.
763   mlir::OpBuilder builder(getOperation());
764   auto new_concat = builder.create<TFL::ConcatenationOp>(
765       getLoc(), getType(), non_empty_values,
766       builder.getIntegerAttr(builder.getIntegerType(32), axis()),
767       builder.getStringAttr(fused_activation_function()));
768   return new_concat.getResult();
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // CustomOp
773 //===----------------------------------------------------------------------===//
774 
775 static LogicalResult Verify(CustomOp op) {
776   OpaqueElementsAttr opaque_attr =
777       op.custom_option().cast<OpaqueElementsAttr>();
778   if (!opaque_attr.getType().hasStaticShape())
779     return op.emitOpError("custom_option should have a static shape.");
780   const int attribute_size = opaque_attr.getValue().size();
781   if (attribute_size != opaque_attr.getType().cast<ShapedType>().getDimSize(0))
782     return op.emitOpError(
783         "custom_option should have the same length of content with shape.");
784   return success();
785 }
786 
787 //===----------------------------------------------------------------------===//
788 // FullyConnectedOp
789 //===----------------------------------------------------------------------===//
790 
791 LogicalResult Verify(FullyConnectedOp op) {
792   ShapedType input_type = op.input().getType().cast<ShapedType>();
793   ShapedType filter_type = op.filter().getType().cast<ShapedType>();
794   if (filter_type.hasRank() && filter_type.getRank() != 2) {
795     return op.emitOpError("expect 2d filter, got ") << filter_type;
796   }
797 
798   if (!input_type.hasStaticShape() || !filter_type.hasStaticShape()) {
799     return mlir::success();
800   }
801 
802   // The input's element count must be a multiple of the filter's z_in dimension.
803   const int z_in = filter_type.getDimSize(1);
804   const int num_input_elements = input_type.getNumElements();
805   if (num_input_elements % z_in != 0) {
806     return op.emitOpError(llvm::formatv(
807                "expect 'input' num_elements % {0} == 0, got input type ", z_in))
808            << input_type;
809   }
810 
811   // TODO(jpienaar): Include more shape verification for SHUFFLED4x16INT8
812   // format.
813   if (op.weights_format() == "DEFAULT") {
814     ShapedType output_type =
815         (*op.output().begin()).getType().cast<ShapedType>();
816     if (!output_type.hasStaticShape()) {
817       return mlir::success();
818     }
819 
820     const int num_output_elements = output_type.getNumElements();
821     const int z_out = filter_type.getDimSize(0);
822     if (num_output_elements % z_out != 0) {
823       return op.emitOpError(llvm::formatv(
824                  "expect 'output' num_elements % {0} == 0, got ", z_out))
825              << output_type;
826     }
827 
828     if (num_input_elements / z_in != num_output_elements / z_out) {
829       return op.emitOpError(
830           "num_input_elements / z_in != num_output_elements / z_out");
831     }
832   }
833 
834   return mlir::success();
835 }
836 
837 LogicalResult FullyConnectedOp::fold(ArrayRef<Attribute> operands,
838                                      SmallVectorImpl<OpFoldResult> &results) {
839   assert(operands.size() == 3);
840 
841   // Folding not implemented with any activation function or any weight type
842   // besides the default.
843   if (fused_activation_function() != "NONE") return failure();
844   if (weights_format() != "DEFAULT") return failure();
845 
846   // Bias tensor is optional.
847   const bool has_bias = !(!bias() || bias().getType().isa<NoneType>());
848 
849   // Get the tensors.
850   DenseElementsAttr input_tensor, weights_tensor, bias_tensor;
851   if (!matchPattern(input(), m_Constant(&input_tensor)) ||
852       !matchPattern(filter(), m_Constant(&weights_tensor)) ||
853       (has_bias && !matchPattern(bias(), m_Constant(&bias_tensor)))) {
854     return failure();
855   }
856 
857   // Get the tensor types.
858   const auto input_type = input_tensor.getType().cast<ShapedType>();
859   const auto weights_type = weights_tensor.getType().cast<ShapedType>();
860   const auto bias_type =
861       has_bias ? bias_tensor.getType().cast<ShapedType>() : ShapedType{};
862 
863   const auto output_type = getType(0).cast<ShapedType>();
864 
865   // Folding only implemented for float tensors.
866   if (!input_type.getElementType().isF32() ||
867       !weights_type.getElementType().isF32() ||
868       !output_type.getElementType().isF32() ||
869       (has_bias && !bias_type.getElementType().isF32())) {
870     return failure();
871   }
872 
873   // Folding only implemented for static shapes
874   if (!input_type.hasStaticShape() || !weights_type.hasStaticShape() ||
875       (has_bias && !bias_type.hasStaticShape())) {
876     return failure();
877   }
878 
879   // Folding only implemented for 1D input, 2D weights and 1D bias
880   if (input_type.getShape().size() != 1 ||
881       weights_type.getShape().size() != 2 ||
882       (has_bias && bias_type.getShape().size() != 1)) {
883     return failure();
884   }
885 
886   // Get the sizes
887   const auto input_size = input_type.getNumElements();
888   const auto output_size = output_type.getNumElements();
889 
890   // Get iterators to the tensors.
891   const auto input_values_it = input_tensor.getValues<float>().begin();
892   const auto weights_values_ptr = weights_tensor.getValues<float>().begin();
893   auto weights_row_it = weights_values_ptr;
894   // The 'else' case could be nullptr, but the types don't match.
895   auto bias_values_it =
896       has_bias ? bias_tensor.getValues<float>().begin() : input_values_it;
897 
898   // Do the actual folding, one output at a time.
899   std::vector<float> result_values;
900   result_values.reserve(output_size);
901 
902   for (int i = 0; i < output_size; ++i) {
903     // Dot product with Kahan/Neumaier summation to minimize numeric errors.
904     float sum = has_bias ? *bias_values_it : 0.0f;
905     float compensation = 0.0f;
906     for (int j = 0; j < input_size; ++j) {
907       const float addend = input_values_it[j] * weights_row_it[j];
908       const float new_sum = sum + addend;
909       // DO NOT enable -funsafe-math-optimizations here.
910       // There is a test detecting unsafe optimizations.
911       // Unsafe math optimizations can reorder float formulas, and set the
912       // compensation to constant 0. The formula must be evaluated as written
913       // for the algorithm to work.
914       // (Note: -ffast-math is a superset of -funsafe-math-optimizations.)
915       if (std::abs(sum) >= std::abs(addend)) {
916         compensation += (sum - new_sum) + addend;
917       } else {
918         compensation += (addend - new_sum) + sum;
919       }
920       sum = new_sum;
921     }
922     result_values.push_back(sum + compensation);
923     weights_row_it += input_size;
924     bias_values_it++;
925   }
926 
927   // Set result tensor
928   const auto folded =
929       DenseElementsAttr::get(output_type, ArrayRef<float>(result_values));
930   results.assign({folded});
931 
932   return success();
933 }
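// A small worked example of the fold above (hypothetical constants): with
// input = [1, 2], filter = [[3, 4], [5, 6]] and bias = [10, 20], the folded
// output is [1*3 + 2*4 + 10, 1*5 + 2*6 + 20] = [21, 37].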
934 
935 void FullyConnectedOp::getCanonicalizationPatterns(
936     OwningRewritePatternList &results, MLIRContext *context) {
937   results.insert<RemoveOptionalZeroBias<FullyConnectedOp>>(context);
938 }
939 
940 //===----------------------------------------------------------------------===//
941 // Conv2DOp
942 //===----------------------------------------------------------------------===//
943 
944 void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
945                                            MLIRContext *context) {
946   // TODO(b/180121750): Enable the pattern after the integration tests are
947   // fixed.
948   // results.insert<RemoveOptionalZeroBias<Conv2DOp>>(context);
949 }
950 
951 //===----------------------------------------------------------------------===//
952 // DepthwiseConv2DOp
953 //===----------------------------------------------------------------------===//
954 
955 void DepthwiseConv2DOp::getCanonicalizationPatterns(
956     OwningRewritePatternList &results, MLIRContext *context) {
957   // TODO(b/180121750): Enable the pattern after the integration tests are
958   // fixed.
959   // results.insert<RemoveOptionalZeroBias<DepthwiseConv2DOp>>(context);
960 }
961 
962 //===----------------------------------------------------------------------===//
963 // GatherOp
964 //===----------------------------------------------------------------------===//
965 
966 static void BuildGatherOp(OpBuilder *builder, OperationState &result,
967                           Value params, Value indices, IntegerAttr axis) {
968   auto params_type = params.getType().cast<TensorType>();
969   auto indices_type = indices.getType().cast<TensorType>();
970 
971   // If params/indices is unranked, then output is unranked.
972   if (!params_type.hasRank() || !indices_type.hasRank())
973     return TFL::GatherOp::build(
974         *builder, result, UnrankedTensorType::get(params_type.getElementType()),
975         params, indices, axis);
976 
977   int64_t params_rank = params_type.getRank();
978   int64_t indices_rank = indices_type.getRank();
979 
980   // params rank is guaranteed to be at least 1.
981   // Produces an output tensor with shape:
982   // params.shape[:axis] + indices.shape + params.shape[axis + 1:]
983   std::vector<int64_t> shape(params_type.getShape());
984   int64_t axis_i = axis.getInt();
985 
986   // For neg axis values, we wrap around params, e.g. axis = -1 => params[:-1]
987   if (axis_i < 0) {
988     axis_i += params_rank;
989   }
990 
991   // params must be at least rank axis + 1
992   if (params_rank < axis_i + 1) {
993     emitError(result.location, "params must be at least rank axis + 1");
994   }
995 
996   if (indices_rank == 0) {
997     // Scalar indices (output is rank(params) - 1).
998     // Erase shape[axis]
999     shape.erase(shape.begin() + axis_i);
1000   } else if (indices_rank == 1) {
1001     // Vector indices (output is rank(params)).
1002     // Copy indices.shape into params.shape[axis]
1003     std::copy(std::begin(indices_type.getShape()),
1004               std::end(indices_type.getShape()), std::begin(shape) + axis_i);
1005   } else {
1006     // Higher rank indices (output is rank(params) + rank(indices) - 1).
1007     shape.resize(params_rank + indices_rank - 1);
1008     // Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
1009     std::copy(std::begin(params_type.getShape()) + axis_i + 1,
1010               std::end(params_type.getShape()),
1011               std::begin(shape) + axis_i + indices_rank);
1012 
1013     // Copy indices.shape into params.shape[axis]
1014     std::copy(std::begin(indices_type.getShape()),
1015               std::end(indices_type.getShape()), std::begin(shape) + axis_i);
1016   }
1017 
1018   TFL::GatherOp::build(
1019       *builder, result,
1020       RankedTensorType::get(shape, params_type.getElementType()), params,
1021       indices, axis);
1022 }
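// Shape computation example (hypothetical types): gathering from params of
// shape [4, 5, 6] with indices of shape [2, 3] along axis = 1 produces an
// output of shape [4, 2, 3, 6], i.e. params.shape[:1] + indices.shape +
// params.shape[2:].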
1023 
1024 //===----------------------------------------------------------------------===//
1025 // ScatterNdOp
1026 //===----------------------------------------------------------------------===//
1027 
1028 static LogicalResult Verify(ScatterNdOp op) {
1029   auto indices = op.indices();
1030   auto updates = op.updates();
1031   auto shape = op.shape();
1032   auto output = op.output();
1033 
1034   auto updates_type = updates.getType().cast<ShapedType>();
1035   auto indices_type = indices.getType().cast<ShapedType>();
1036 
1037   if (!indices_type.hasStaticShape() || !updates_type.hasStaticShape()) {
1038     return success();
1039   }
1040 
1041   // Checks if the shape of `updates` is a tensor of shape
1042   // `indices.shape[:-1] + shape[indices.shape[-1]:]`, as described in
1043   // ScatterNd op description.
1044 
1045   auto outer_dims = indices_type.getRank() - 1;
1046   auto outermost_dim = indices_type.getDimSize(outer_dims);
1047   // Checks whether the first `outer_dims` dimensions of `indices` and
1048   // `updates` are equal.
1049   for (auto i = 0; i < outer_dims; i++) {
1050     if (indices_type.getDimSize(i) != updates_type.getDimSize(i)) {
1051       return op.emitOpError()
1052              << "indices.Dims(" << i << ") == " << indices_type.getDimSize(i)
1053              << ", but updates.Dims(" << i
1054              << ") == " << updates_type.getDimSize(i);
1055     }
1056   }
1057 
1058   auto output_type = output.getType().cast<ShapedType>();
1059   auto shape_type = shape.getType().cast<ShapedType>();
1060   if (shape_type.hasStaticShape()) {
1061     // Check the rank of `shape`.
1062     auto output_rank = outermost_dim + updates_type.getRank() - outer_dims;
1063     if (shape_type.getDimSize(0) != output_rank) {
1064       return op.emitOpError()
1065              << "shape must be a vector of length " << output_rank;
1066     }
1067     if (output_type.hasRank()) {
1068       if (output_type.getRank() != output_rank) {
1069         return op.emitOpError()
1070                << "output must have the same rank with the length of shape = "
1071                << output_rank;
1072       }
1073     }
1074   }
1075 
1076   DenseIntElementsAttr shape_value;
1077   if (matchPattern(shape, m_Constant(&shape_value))) {
1078     for (const auto shape_elem : shape_value) {
1079       if (shape_elem.getSExtValue() <= 0) {
1080         return op.emitOpError("all elements of shape must be > 0");
1081       }
1082     }
1083 
1084     // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)`
1085     // dimensions of `updates` and `shape` are equal.
1086     for (auto shape_it : llvm::enumerate(shape_value)) {
1087       int64_t i = shape_it.index();
1088       auto value = shape_it.value().getSExtValue();
1089       if (i >= outermost_dim) {
1090         auto corresponding_dim = i - outermost_dim + outer_dims;
1091         if (value != updates_type.getDimSize(corresponding_dim)) {
1092           return op.emitOpError()
1093                  << "updates.Dims(" << i
1094                  << ") == " << updates_type.getDimSize(corresponding_dim)
1095                  << ", but shape[" << i << "] == " << value;
1096         }
1097       }
1098     }
1099 
1100     // Checks if the output has the shape specified by `shape`.
1101     if (output_type.hasStaticShape()) {
1102       for (auto shape_it : llvm::enumerate(shape_value)) {
1103         int i = shape_it.index();
1104         auto value = shape_it.value().getSExtValue();
1105         if (output_type.getDimSize(i) != value) {
1106           return op.emitOpError()
1107                  << "output shape [" << output_type.getShape()
1108                  << "] must be equal to the value of shape " << shape_value;
1109         }
1110       }
1111     }
1112   }
1113   return success();
1114 }
1115 
1116 //===----------------------------------------------------------------------===//
1117 // MulOp
1118 //===----------------------------------------------------------------------===//
1119 
1120 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1121   // TODO(b/142478136): Handle fused ops.
1122   if (fused_activation_function() != "NONE") return {};
1123 
1124   // This function is performance critical for op fusion patterns, e.g.
1125   // FuseBinaryOpToPrecedingAffine and FuseMulOrDivWithConv2dOrDepthwiseConv2d.
1126   // So a few specializations are provided to evaluate the math operation
1127   // more efficiently.
1128 
1129   // Specialization for f32 type.
1130   if (getType().cast<ShapedType>().getElementType().isF32()) {
1131     return ConstFoldBinaryOp<FloatAttr, float>(
1132         getType(), operands[0], operands[1],
1133         [](float a, float b) { return a * b; });
1134   }
1135 
1136   // Specialization for bf16 type.
1137   if (getType().cast<ShapedType>().getElementType().isBF16()) {
1138     return ConstFoldBinaryOp<FloatAttr, Eigen::bfloat16>(
1139         getType(), operands[0], operands[1],
1140         [](Eigen::bfloat16 a, Eigen::bfloat16 b) { return a * b; });
1141   }
1142 
1143   // Generic fallback with APFloat
1144   return ConstFoldBinaryOp(
1145       getType(), operands, [](APFloat a, APFloat b) { return a * b; },
1146       [](APInt a, APInt b) { return a * b; });
1147 }
1148 
1149 //===----------------------------------------------------------------------===//
1150 // DivOp
1151 //===----------------------------------------------------------------------===//
1152 
1153 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1154   // TODO(b/142478136): Handle fused ops.
1155   if (fused_activation_function() != "NONE") return {};
1156   return ConstFoldBinaryOp(
1157       getType(), operands, [](APFloat a, APFloat b) { return a / b; },
1158       [](APInt a, APInt b) { return a.sdiv(b); });
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // PackOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 // TODO(b/133486129): Implement shape inference for pack
1166 
1167 static LogicalResult Verify(PackOp op) {
1168   // TODO(antiagainst): Implement other checks as in
1169   // tensorflow/lite/kernels/pack.cc
1170 
1171   if (op.getOperation()->getNumOperands() != op.values_count())
1172     return op.emitOpError("input count should match 'values_count' attribute");
1173 
1174   Value operand0 = op.getOperand(0);
1175   auto input_type = operand0.getType().cast<ShapedType>();
1176 
1177   // Check axis bounds.
1178   if (input_type.hasRank()) {
1179     int32_t axis_value = op.axis();
1180     if (axis_value < 0) axis_value += input_type.getRank() + 1;
1181     if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
1182       return op.emitOpError()
1183              << "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
1184              << "got rank = " << input_type.getRank()
1185              << ", and axis = " << op.axis();
1186   }
1187 
1188   // Make sure all inputs have the same shape and element type.
1189   // TODO(b/135032063): Simplify once fixed.
1190   for (Type operand_type : op.getOperandTypes()) {
1191     if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
1192       return op.emitOpError("operands should be of the same type. got ")
1193              << input_type << ", " << operand_type;
1194   }
1195 
1196   return success();
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // PReluOp
1201 //===----------------------------------------------------------------------===//
1202 
1203 static LogicalResult Verify(PReluOp op) {
1204   auto input_type = op.input().getType().cast<ShapedType>();
1205   auto alpha_type = op.alpha().getType().cast<ShapedType>();
1206   auto output_type = op.output().getType().cast<ShapedType>();
1207 
1208   if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
1209     if (input_type.getRank() != alpha_type.getRank() + 1) {
1210       return op.emitOpError("'alpha' should have one less rank than 'input'.");
1211     }
1212 
1213     // Check if alpha is broadcastable
1214     for (int i = 0; i < alpha_type.getRank(); i++) {
1215       if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) &&
1216           alpha_type.getDimSize(i) != 1) {
1217         return op.emitOpError(
1218             llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i));
1219       }
1220     }
1221   }
1222 
1223   if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
1224     if (input_type.getRank() != output_type.getRank()) {
1225       return op.emitOpError("'input' and 'output' should have the same rank.");
1226     }
1227 
1228     // Check if input and output shapes are same
1229     for (int i = 0; i < input_type.getRank(); i++) {
1230       if (input_type.getDimSize(i) != output_type.getDimSize(i)) {
1231         return op.emitOpError(
1232             "'input' and 'output' should have the same shape.");
1233       }
1234     }
1235   }
1236   return success();
1237 }
1238 
1239 //===----------------------------------------------------------------------===//
1240 // ReshapeOp
1241 //===----------------------------------------------------------------------===//
1242 
1243 namespace {
1244 // This pattern matches and merges a tfl.reshape under the following
1245 // condition:
1246 // * The input's defining op is another tfl.reshape.
1247 // TODO(antiagainst): This pattern probably should be moved to the peephole
1248 // category, after we have the infra for peephole passes.
1249 struct RemoveAdjacentReshape : public RewritePattern {
1250   RemoveAdjacentReshape(MLIRContext *context)
1251       : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}
1252 
1253   LogicalResult match(Operation *op) const override {
1254     auto thisOp = cast<ReshapeOp>(op);
1255     auto prevOp = thisOp.getOperand(0).getDefiningOp();
1256     return isa_and_nonnull<ReshapeOp>(prevOp) ? success() : failure();
1257   }
1258 
1259   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
1260     auto thisOp = cast<ReshapeOp>(op);
1261     auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
1262 
1263     // Replace
1264     //   %1 = "tfl.reshape"(%0, %shape0)
1265     //   %2 = "tfl.reshape"(%1, %shape1)
1266     // With
1267     //   %2 = "tfl.reshape"(%0, %shape1)
1268     rewriter.replaceOpWithNewOp<ReshapeOp>(
1269         op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
1270   }
1271 };
1272 
1273 // The kernel expects a 1-D tensor for the shape operand if it is present. If
1274 // all the dimensions except the last one are '1's, the shape tensor will be
1275 // reshaped to a 1-D tensor.
1276 // Note that this pattern doesn't check or change the content of the shape
1277 // tensor.
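// For example, a constant 'shape' operand of type tensor<1x1x2xi32> holding
// the values [-1, 4] is rewritten into a tensor<2xi32> constant holding the
// same values.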
1278 struct ConvertShapeTo1D : public OpRewritePattern<ReshapeOp> {
1279   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1280 
1281   LogicalResult matchAndRewrite(ReshapeOp reshape,
1282                                 PatternRewriter &rewriter) const override {
1283     if (!reshape.shape().hasOneUse()) return failure();
1284 
1285     DenseIntElementsAttr shape;
1286     if (!matchPattern(reshape.shape(), m_Constant(&shape))) {
1287       return failure();
1288     }
1289     // It is already a 1-D constant, no change.
1290     auto old_shape = shape.getType().getShape();
1291     if (old_shape.size() == 1) {
1292       return failure();
1293     }
1294     // Verify that all dimensions except the last one have length one.
1295     for (auto it = ++old_shape.rbegin(); it != old_shape.rend(); ++it) {
1296       if (*it != 1) {
1297         reshape->emitError(
1298             "Non-vector shape input is used, might cause runtime error");
1299         return failure();
1300       }
1301     }
1302     auto new_shape = shape.reshape(RankedTensorType::get(
1303         {*old_shape.rbegin()}, shape.getType().getElementType()));
1304     rewriter.replaceOpWithNewOp<TFL::ConstOp>(reshape.shape().getDefiningOp(),
1305                                               new_shape);
1306     return success();
1307   }
1308 };
1309 
1310 }  // end anonymous namespace
1311 
1312 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
1313   // Remove identity reshape with both static result and input shape.
1314   auto result_type = getType().cast<ShapedType>();
1315   auto input_type = getOperand(0).getType().cast<ShapedType>();
1316   if (result_type.hasStaticShape() && result_type == input_type) {
1317     return getOperand(0);
1318   }
1319 
1320   // Constant folding
1321   if (auto dense_elements = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
1322     // If the result type isn't static, try to derive the result type from
1323     // the second (shape) operand.
1324     if (!result_type.hasStaticShape()) {
1325       auto shape_elements = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1326       if (!shape_elements) return nullptr;
1327 
1328       SmallVector<int64_t, 4> shape_data;
1329       for (const auto &it : shape_elements.getValues<APInt>()) {
1330         shape_data.push_back(it.getSExtValue());
1331       }
1332       result_type =
1333           RankedTensorType::get(shape_data, input_type.getElementType());
1334     }
1335     return dense_elements.reshape(result_type);
1336   }
1337 
1338   return nullptr;
1339 }
1340 
1341 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1342                                             MLIRContext *context) {
1343   results.insert<RemoveAdjacentReshape, ConvertShapeTo1D>(context);
1344 }
1345 
1346 using ReshapeErrorHandler =
1347     llvm::function_ref<LogicalResult(const llvm::Twine &)>;
1348 
1349 LogicalResult GetReshapeOutputType(Value input, Value shape,
1350                                    ReshapeErrorHandler error_handler,
1351                                    TensorType &output_ty) {
1352   auto input_ty = input.getType().cast<TensorType>();
1353   auto element_ty = input_ty.getElementType();
1354   output_ty = UnrankedTensorType::get(element_ty);
1355 
1356   auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
1357   if (!shape_ty) return success();
1358   if (shape_ty.getRank() != 1)
1359     return error_handler(llvm::formatv(
1360         "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));
1361 
1362   DenseIntElementsAttr shape_attr;
1363   if (!matchPattern(shape, m_Constant(&shape_attr))) {
1364     // If only the shape of `shape` is known, return a ranked but dynamic
1365     // output shape.
1366     if (shape_ty.hasStaticShape()) {
1367       llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
1368                                                   ShapedType::kDynamicSize);
1369       output_ty = RankedTensorType::get(dynamic_shape, element_ty);
1370     }
1371     return success();
1372   }
1373 
1374   // Detect if reshape output shape is folded.
1375   bool shape_ty_zero_dim = false;
1376   int unknown_index = -1;
1377   // The product of the constant shape values, excluding the unknown dimension.
1378   int64_t shape_ty_size = 1;
1379   llvm::SmallVector<int64_t, 8> output_ty_shape;
1380   output_ty_shape.reserve(shape_attr.getNumElements());
1381   for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) {
1382     const int64_t size = dim.value().getSExtValue();
1383     if (size == ShapedType::kDynamicSize) {
1384       if (unknown_index != -1)
1385         return error_handler(llvm::formatv(
1386             "requires 'shape' to have at most one dynamic dimension, but got "
1387             "multiple dynamic dimensions at indices {0} and {1}. You need to "
1388             "set up the unspecified size(s) to avoid this problem, for "
1389             "example, by setting the batch size in the Keras model or setting "
1390             "the unspecified input size(s) to fixed ones.",
1391             unknown_index, dim.index()));
1392 
1393       unknown_index = dim.index();
1394     } else if (size == 0) {
1395       shape_ty_zero_dim = true;
1396     } else if (size > 0) {
1397       shape_ty_size *= size;
1398     } else {
1399       return error_handler(
1400           llvm::formatv("requires 'shape' to have dimensions greater than -1, "
1401                         "but got {0} at index {1}",
1402                         size, dim.index()));
1403     }
1404     output_ty_shape.push_back(size);
1405   }
1406 
1407   if (!input_ty.hasStaticShape()) {
1408     output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1409     return success();
1410   }
1411 
1412   // Compute the value of the unknown dimension.
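  // For example, an input of type tensor<2x3xf32> (6 elements) with a constant
  // shape of [3, -1] yields missing_dim = 6 / 3 = 2, i.e. an output type of
  // tensor<3x2xf32>.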
1413   if (unknown_index != -1) {
1414     // Compute number of elements in tensor shape.
1415     int64_t input_ty_size = 1;
1416     bool input_ty_zero_dim = false;
1417     for (const auto &dim : input_ty.getShape()) {
1418       if (dim > 0 || !shape_ty_zero_dim) {
1419         input_ty_size *= dim;
1420       } else {
1421         input_ty_zero_dim = true;
1422       }
1423     }
1424 
1425     const int64_t missing_dim = input_ty_size / shape_ty_size;
1426     if (!input_ty_zero_dim && shape_ty_size * missing_dim != input_ty_size)
1427       return error_handler(
1428           llvm::formatv("requires 'input' number of elements be a multiple of "
1429                         "{0}, but got {1}",
1430                         shape_ty_size, input_ty_size));
1431 
1432     // Set the unknown dimension such that the total number of elements
1433     // remains constant.
1434     output_ty_shape[unknown_index] = missing_dim;
1435   }
1436 
1437   output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1438 
1439   return success();
1440 }
1441 
1442 static LogicalResult Verify(ReshapeOp op) {
1443   auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
1444     return op.emitOpError() << message;
1445   };
1446   TensorType expected_ty;
1447   if (failed(GetReshapeOutputType(op.input(), op.shape(), error_handler,
1448                                   expected_ty)))
1449     return failure();
1450 
1451   auto output_ty = op.getType().dyn_cast<RankedTensorType>();
1452   if (!output_ty) return success();
1453   auto input_ty = op.input().getType().cast<TensorType>();
1454   if (output_ty.hasStaticShape() && input_ty.hasStaticShape()) {
1455     const int64_t output_ty_size = output_ty.getNumElements();
1456     const int64_t input_ty_size = input_ty.getNumElements();
1457     if (input_ty_size != output_ty_size)
1458       return op.emitOpError() << "requires 'output' number of elements to "
1459                                  "match 'input' number of elements, but got "
1460                               << output_ty_size << " and " << input_ty_size;
1461   }
1462 
1463   if (!TF::AreCastCompatible({output_ty, expected_ty}))
1464     return op.emitOpError()
1465            << "requires 'output' type " << output_ty
1466            << " to be cast compatible with expected type " << expected_ty;
1467 
1468   return success();
1469 }
1470 
1471 //===----------------------------------------------------------------------===//
1472 // PackOp
1473 //===----------------------------------------------------------------------===//
1474 
1475 // Removes a redundant unpack-pack op pair.
1476 // If an unpack op is followed by a pack op, the pack op can be removed; if the
1477 // unpack op is only consumed by that pack op, it will be removed as well.
1478 // An example illustration is:
1479 //                  Unpack [5, 8, 9], axis = 1
1480 //                /       \
1481 //            value  ...  value [5, 9], 8 values in total
1482 //              \           /
1483 //                 pack,   axis = 1
1484 //                   |
1485 //               value   [5, 8, 9]
1486 //
1487 //   This can actually be simplified into just:
1488 //
1489 //           =>   Value [5, 8, 9]
1490 // TODO(b/133341698): Move to tablegen when variadic is supported.
1491 struct RemoveRedundantUnpackPack : public RewritePattern {
1492   explicit RemoveRedundantUnpackPack(MLIRContext *context)
1493       : RewritePattern(PackOp::getOperationName(), 2, context) {}
1494 
1495   LogicalResult matchAndRewrite(Operation *op,
1496                                 PatternRewriter &rewriter) const override {
1497     TFL::PackOp pack_op = cast<TFL::PackOp>(op);
1498     Operation *first_input = pack_op.getOperand(0).getDefiningOp();
1499     if (!first_input) return failure();
1500     auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
1501     if (!input_unpack_op) return failure();
1502 
1503     // The unpack & pack should have the same axis & num inputs/outputs.
1504     if (pack_op.axis() != input_unpack_op.axis() ||
1505         pack_op.values_count() != input_unpack_op.num())
1506       return failure();
1507 
1508     const int total_pack_inputs = pack_op.getNumOperands();
1509     const int num_results = input_unpack_op.getNumResults();
1510     if (total_pack_inputs != num_results) return failure();
1511     for (auto input_output :
1512          llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
1513       Value pack_input = std::get<0>(input_output);
1514       Value unpack_output = std::get<1>(input_output);
1515       // Make sure the ordering is the same for the pack op & unpack op.
1516       if (pack_input != unpack_output) return failure();
1517     }
1518 
1519     // Replace the pack's output with the unpack's input.
1520     rewriter.replaceOp(pack_op, input_unpack_op.getOperand());
1521     // At this point, we don't manually remove the redundant pack op & unpack op
1522     // (we actually cannot), but trust the PatternRewriter to garbage collect
1523     // these two ops.
1524     return success();
1525   }
1526 };
1527 
1528 void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1529                                          MLIRContext *context) {
1530   results.insert<RemoveRedundantUnpackPack>(context);
1531 }
1532 
1533 //===----------------------------------------------------------------------===//
1534 // SliceOp
1535 //===----------------------------------------------------------------------===//
1536 
1537 static LogicalResult Verify(SliceOp op) {
1538   auto input_type = op.input().getType().cast<ShapedType>();
1539   auto begin_type = op.begin().getType().cast<ShapedType>();
1540   auto size_type = op.size().getType().cast<ShapedType>();
1541   if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
1542       size_type.hasStaticShape()) {
1543     if (input_type.getRank() != begin_type.getNumElements()) {
1544       return op.emitError(
1545           "begin tensor elements size is not equal to input tensor rank");
1546     }
1547 
1548     if (input_type.getRank() != size_type.getNumElements()) {
1549       return op.emitError(
1550           "size tensor elements size is not equal to input tensor rank");
1551     }
1552   }
1553 
1554   DenseIntElementsAttr begin;
1555   if (matchPattern(op.begin(), m_Constant(&begin))) {
1556     int axis = 0;
1557     for (auto begin_i : llvm::enumerate(begin)) {
1558       if (begin_i.value().getSExtValue() < 0) {
1559         return op.emitError(
1560             llvm::formatv("begin[{0}] cannot be negative", axis));
1561       }
1562       axis++;
1563     }
1564   }
1565 
1566   DenseIntElementsAttr size;
1567   if (matchPattern(op.size(), m_Constant(&size))) {
1568     int axis = 0;
1569     for (auto size_i : llvm::enumerate(size)) {
1570       if (size_i.value().getSExtValue() < -1) {
1571         return op.emitError(
1572             llvm::formatv("size[{0}] cannot be negative other than -1", axis));
1573       }
1574       axis++;
1575     }
1576   }
1577 
1578   if (begin && size && input_type.hasStaticShape()) {
1579     for (uint64_t i = 0, end = begin.getNumElements(); i < end; i++) {
1580       int begin_i =
1581           begin.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
1582       int size_i =
1583           size.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
1584       int dim_i = input_type.getShape()[i];
1585       if (begin_i > dim_i) {
1586         return op.emitOpError(llvm::formatv(
1587             "begin[{0}] cannot exceed dimension length: {1}", i, dim_i));
1588       }
1589       if (size_i >= 0 && begin_i + size_i > dim_i) {
1590         return op.emitError(llvm::formatv(
1591             "begin[{0}] + size[{0}] cannot exceed dimension length: {1}", i,
1592             dim_i));
1593       }
1594     }
1595   }
1596 
1597   return success();
1598 }
1599 
1600 TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op,
1601                                              RankedTensorType value_type,
1602                                              Location loc, OpBuilder *builder) {
1603   if (input_op == nullptr) return nullptr;
1604 
1605   mlir::DenseIntElementsAttr attr;
1606   if (!matchPattern(input_op, m_Constant(&attr))) {
1607     return nullptr;
1608   }
1609 
1610   auto value_shape_type = mlir::RankedTensorType::get(
1611       value_type.getShape(), builder->getIntegerType(32));
1612 
1613   SmallVector<int32_t, 4> value_i32;
1614   value_i32.reserve(value_type.getRank());
1615   for (const auto &size : attr) {
1616     value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
1617   }
1618   auto new_value_i32_attr =
1619       mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);
1620 
1621   return builder->create<TFL::ConstOp>(loc, new_value_i32_attr);
1622 }
1623 
1624 // Casts down constant int64 'begin' and 'size' values of a TFL slice op to
1625 // int32. This requires that 'begin' and 'size' are constants.
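// For example, a constant 'begin' of type tensor<2xi64> holding [0, 1] is
// replaced by a tensor<2xi32> constant holding the same values.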
1626 struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern<TFL::SliceOp> {
1627   using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;
1628 
1629   LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
1630                                 PatternRewriter &rewriter) const override {
1631     auto begin = slice_op.begin();
1632     auto size = slice_op.size();
1633     auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
1634     auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
1635     auto begin_op = begin.getDefiningOp();
1636     auto size_op = size.getDefiningOp();
1637 
1638     if (begin_op == nullptr && size_op == nullptr) return failure();
1639 
1640     if (begin_type == nullptr && size_type == nullptr) return failure();
1641 
1642     // Handle begin.
1643     if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) {
1644       auto new_begin = NarrowDownInt64InputValuesForOp(
1645           begin_op, begin_type, slice_op.getLoc(), &rewriter);
1646       if (new_begin != nullptr) {
1647         slice_op.setOperand(1, new_begin);
1648       }
1649     }
1650 
1651     // Handle size.
1652     if (size_op && size_type && size_type.getElementType().isInteger(64)) {
1653       auto new_size = NarrowDownInt64InputValuesForOp(
1654           size_op, size_type, slice_op.getLoc(), &rewriter);
1655       if (new_size != nullptr) {
1656         slice_op.setOperand(2, new_size);
1657       }
1658     }
1659 
1660     return success();
1661   }
1662 };
1663 
1664 void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1665                                           MLIRContext *context) {
1666   results.insert<CastDonwInt64BeginEndToInt32>(context);
1667 }
1668 
1669 //===----------------------------------------------------------------------===//
1670 // SubOp
1671 //===----------------------------------------------------------------------===//
1672 
1673 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
1674   // TODO(b/142478136): Handle fused ops.
1675   if (fused_activation_function() != "NONE") return {};
1676   return ConstFoldBinaryOp(
1677       getType(), operands, [](APFloat a, APFloat b) { return a - b; },
1678       [](APInt a, APInt b) { return a - b; });
1679 }
1680 
1681 //===----------------------------------------------------------------------===//
1682 // TopKOp
1683 //===----------------------------------------------------------------------===//
1684 
1685 static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input,
1686                         Value k) {
1687   // Output size is only known if k is a constant value. A negative dimension
1688   // is considered dynamic, so use -1 here if k is not a constant value.
1689   int const_k = -1;
1690   ElementsAttr cst;
1691   if (matchPattern(k, m_Constant(&cst)))
1692     // These casts should all be valid due to how Tensor constants are stored.
1693     // TODO(jpienaar): This should use a helper function.
1694     const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
1695 
1696   auto val_type = input.getType().cast<TensorType>();
1697   // If the value is unranked, then so are the results.
1698   if (!val_type.hasRank())
1699     return TFL::TopKV2Op::build(
1700         *builder, result, UnrankedTensorType::get(val_type.getElementType()),
1701         UnrankedTensorType::get(builder->getIntegerType(32)), input, k);
1702 
1703   // Resultant shape is value.shape[:-1] + [k]
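  // For example, an 'input' of type tensor<3x5xf32> with a constant k = 2
  // produces results of type tensor<3x2xf32> (values) and tensor<3x2xi32>
  // (indices).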
1704   std::vector<int64_t> shape(val_type.getShape());
1705   shape[shape.size() - 1] = const_k;
1706   TFL::TopKV2Op::build(
1707       *builder, result, RankedTensorType::get(shape, val_type.getElementType()),
1708       RankedTensorType::get(shape, builder->getIntegerType(32)), input, k);
1709 }
1710 
1711 //===----------------------------------------------------------------------===//
1712 // FakeQuantOp
1713 //===----------------------------------------------------------------------===//
1714 
1715 // Returns true if the op has a valid "minmax" attribute with exactly two
1716 // elements.
1716 static inline bool HasValidMinMaxAttribute(Operation *op) {
1717   auto minmax = op->getAttrOfType<ArrayAttr>("minmax");
1718   return minmax && minmax.getValue().size() == 2;
1719 }
1720 
1721 namespace {
1722 
1723 /// This pattern matches and removes a tfl.fake_quant if the op itself and all
1724 /// of its users have a valid "minmax" attribute set.
1725 struct DropFakeQuant : public RewritePattern {
1726   explicit DropFakeQuant(MLIRContext *context)
1727       : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}
1728 
1729   LogicalResult match(Operation *op) const override {
1730     // We only match the op with valid "minmax" attribute.
1731     if (!HasValidMinMaxAttribute(op)) return failure();
1732 
1733     // If all the users of this op have valid "minmax" attributes, it is matched
1734     // and can be removed.
1735     auto fakeQuantOp = cast<FakeQuantOp>(op);
1736     for (auto *operand : fakeQuantOp.getResult().getUsers())
1737       if (!HasValidMinMaxAttribute(operand)) return failure();
1738 
1739     return success();
1740   }
1741 
1742   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
1743     // Replace the matched FakeQuantOp by its primary operand.
1744     rewriter.replaceOp(op, op->getOperand(0));
1745   }
1746 };
1747 }  // end anonymous namespace
1748 
1749 void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1750                                               MLIRContext *context) {
1751   results.insert<DropFakeQuant>(context);
1752 }
1753 
1754 //===----------------------------------------------------------------------===//
1755 // UnpackOp
1756 //===----------------------------------------------------------------------===//
1757 
1758 // TODO(b/133486129): Implement shape inference for unpack
1759 
1760 LogicalResult UnpackOp::inferReturnTypes(
1761     MLIRContext *context, Optional<Location> loc, ValueRange operands,
1762     DictionaryAttr attributes, RegionRange regions,
1763     SmallVectorImpl<Type> &inferredReturnTypes) {
1764   UnpackOpAdaptor op(operands, attributes);
1765   // TODO(jpienaar): Refactor verify
1766   if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context))))
1767     return failure();
1768 
1769   if (operands.size() != 1) {
1770     return emitOptionalError(loc, "input count should be equal to 1");
1771   }
1772 
1773   const int64_t num_value = op.num().getInt();
1774   auto input_type = operands[0].getType().dyn_cast<ShapedType>();
1775   if (!input_type || !input_type.hasRank()) {
1776     // If input is unranked, then so is output.
1777     inferredReturnTypes.assign(
1778         num_value, UnrankedTensorType::get(input_type.getElementType()));
1779     return success();
1780   }
1781 
1782   if (input_type.hasStaticShape() && input_type.getNumElements() <= 0) {
1783     return emitOptionalError(
1784         loc, "number of elements in input should be larger than 0");
1785   }
1786 
1787   const int64_t rank = input_type.getRank();
1788   if (rank <= 0) {
1789     return emitOptionalError(loc, "input should be of rank larger than 0");
1790   }
1791 
1792   int64_t axis_value = op.axis().getInt();
1793   if (axis_value < 0) {
1794     axis_value += rank;
1795   }
1796   if (axis_value < 0 || axis_value >= rank) {
1797     return emitOptionalError(
1798         loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
1799         op.axis().getInt(), ", and rank = ", rank);
1800   }
1801 
1802   if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
1803       input_type.getDimSize(axis_value) != num_value) {
1804     return emitOptionalError(loc, "output count should match 'num' attribute");
1805   }
1806 
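  // For example, unpacking a tensor<5x8x9xf32> along axis = 1 with num = 8
  // yields eight results of type tensor<5x9xf32>.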
1807   auto output_shape = llvm::to_vector<4>(input_type.getShape());
1808   output_shape.erase(output_shape.begin() + axis_value);
1809 
1810   auto output_type =
1811       RankedTensorType::get(output_shape, input_type.getElementType());
1812   inferredReturnTypes.assign(num_value, output_type);
1813 
1814   return success();
1815 }
1816 
1817 bool UnpackOp::isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
1818   if (lhs.size() != rhs.size()) return false;
1819   for (auto pair : llvm::zip(lhs, rhs)) {
1820     if (failed(
1821             mlir::verifyCompatibleShape(std::get<0>(pair), std::get<1>(pair))))
1822       return false;
1823   }
1824   return true;
1825 }
1826 
1827 //===----------------------------------------------------------------------===//
1828 // SplitOp
1829 //===----------------------------------------------------------------------===//
1830 
1831 // Extracts and returns the signed integer constant from a 0-rank integer
1832 // tensor or a 1-element 1-rank integer tensor, if 'value' is a constant.
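// For example, a constant of type tensor<i32> holding 3, or of type
// tensor<1xi64> holding [3], yields 3; a non-constant value yields llvm::None.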
1833 static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
1834   ElementsAttr attr;
1835   if (!matchPattern(value, m_Constant(&attr))) return {};
1836   if (attr.getNumElements() != 1) return {};
1837   IntegerAttr int_attr = *attr.getValues<IntegerAttr>().begin();
1838   return int_attr.getValue().getSExtValue();
1839 }
1840 
1841 // Returns a RankedTensorType which is similar to `input_type` but replaces the
1842 // dimension size of `dim` with `dim_size`.  For example,
1843 // `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
1844 // `tensor<3x2xi32>`.
1845 static RankedTensorType SubstituteRankedTensorTypeDimSize(
1846     RankedTensorType input_type, int64_t dim, int64_t dim_size) {
1847   auto shape = input_type.getShape().vec();
1848   shape[dim] = dim_size;
1849   return RankedTensorType::get(shape, input_type.getElementType());
1850 }
1851 
1852 // Verifies the output tensor types of SplitOp or SplitVOp.
1853 template <typename ExpectedOutputTypeGetter>
1854 static LogicalResult VerifySplitOpOutputTypes(
1855     Operation *op, int64_t num_splits,
1856     ExpectedOutputTypeGetter get_expected_output_type) {
1857   for (int64_t i = 0; i < num_splits; ++i) {
1858     auto expected_output_type = get_expected_output_type(i);
1859     Value output = op->getResult(i);
1860     if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
1861       return op->emitOpError()
1862              << "output #" << i << " should be " << expected_output_type
1863              << " instead got " << output.getType();
1864   }
1865   return success();
1866 }
1867 
1868 static LogicalResult Verify(SplitOp op) {
1869   int64_t num_splits = op.num_splits();
1870   if (op.getNumResults() != num_splits)
1871     return op.emitOpError("output count should match 'num_splits' attribute");
1872 
1873   // If 'split_dim' is not a constant, there are no other checks.
1874   llvm::Optional<int64_t> split_dim_opt =
1875       ExtractConstantIntFromTensor(op.split_dim());
1876   if (!split_dim_opt) return success();
1877 
1878   // If 'input' is not a ranked tensor, there are no other checks.
1879   auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
1880   if (!input_type) return success();
1881 
1882   int64_t split_dim = split_dim_opt.getValue();
1883   const int64_t rank = input_type.getRank();
1884   if (split_dim < 0) split_dim += rank;
1885   if (split_dim < 0 || split_dim >= rank)
1886     return op.emitOpError("'split_dim' should be in [-rank, rank)");
1887 
1888   // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
1889   // there are no other checks.
1890   const int64_t dim_size = input_type.getDimSize(split_dim);
1891   if (ShapedType::isDynamic(dim_size)) return success();
1892 
1893   if (dim_size % num_splits != 0)
1894     return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");
1895 
1896   // Verifies output tensor types.
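  // For example, splitting a tensor<6x4xf32> along split_dim = 0 with
  // num_splits = 3 gives an expected type of tensor<2x4xf32> for each result.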
1897   RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
1898       input_type, split_dim, dim_size / num_splits);
1899   return VerifySplitOpOutputTypes(
1900       op.getOperation(), num_splits,
1901       [expected_output_type](int64_t) { return expected_output_type; });
1902 }
1903 
1904 static LogicalResult Verify(SplitVOp op) {
1905   int64_t num_splits = op.num_splits();
1906   if (op.getNumResults() != num_splits)
1907     return op.emitOpError("output count should match 'num_splits' attribute");
1908 
1909   // If 'split_dim' is not a constant, there are no other checks.
1910   llvm::Optional<int64_t> split_dim_opt =
1911       ExtractConstantIntFromTensor(op.split_dim());
1912   if (!split_dim_opt) return success();
1913 
1914   // If 'input' is not a ranked tensor, there are no other checks.
1915   auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
1916   if (!input_type) return success();
1917 
1918   int64_t split_dim = split_dim_opt.getValue();
1919   const int64_t rank = input_type.getRank();
1920   if (split_dim < 0) split_dim += rank;
1921   if (split_dim < 0 || split_dim >= rank)
1922     return op.emitOpError("'split_dim' should be in [-rank, rank)");
1923 
1924   // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
1925   // there are no other checks.
1926   const int64_t dim_size = input_type.getDimSize(split_dim);
1927   if (ShapedType::isDynamic(dim_size)) return success();
1928 
1929   // If 'size_splits' is not a constant, there are no other checks.
1930   ElementsAttr size_splits_attr;
1931   if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
1932     return success();
1933 
1934   if (size_splits_attr.getNumElements() != num_splits) {
1935     auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
1936     RankedTensorType expected_size_splits_type =
1937         RankedTensorType::get({num_splits}, size_splits_type.getElementType());
1938     return op.emitOpError("'size_splits' should be ")
1939            << expected_size_splits_type;
1940   }
1941 
1942   // Normalizes and verifies 'size_splits'.
1943   // Note: TensorFlow allows one -1 element in 'size_splits'.  The -1 element
1944   // means the remainder of the dimension size along 'split_dim'.
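  // For example, with a 'split_dim' dimension size of 10, a 'size_splits' of
  // [3, -1, 2] is normalized to [3, 5, 2].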
1945   llvm::SmallVector<int64_t, 4> size_splits;
1946   size_splits.reserve(num_splits);
1947 
1948   int64_t negative_size_split_loc = -1;
1949   int64_t total_size_splits = 0;
1950 
1951   for (int64_t i = 0; i < num_splits; ++i) {
1952     auto size_split_attr = size_splits_attr.getValue<IntegerAttr>(i);
1953     int64_t size_split = size_split_attr.getValue().getSExtValue();
1954     size_splits.push_back(size_split);
1955     if (size_split >= 0) {
1956       total_size_splits += size_split;
1957       continue;
1958     }
1959     if (size_split < -1)
1960       return op.emitOpError(
1961           "elements of 'size_splits' should be greater than or equal to -1");
1962     if (negative_size_split_loc != -1)
1963       return op.emitOpError("'size_splits' can only have one -1");
1964     negative_size_split_loc = i;
1965   }
1966 
1967   if (negative_size_split_loc != -1) {
1968     if (total_size_splits > dim_size)
1969       return op.emitOpError(
1970           "sum of non-negative elements of 'size_splits' is greater than the "
1971           "dimension size of 'split_dim' axis");
1972     size_splits[negative_size_split_loc] = dim_size - total_size_splits;
1973     total_size_splits = dim_size;
1974   }
1975 
1976   if (total_size_splits != dim_size)
1977     return op.emitOpError(
1978         "sum of 'size_splits' should match the dimension size of 'split_dim' "
1979         "axis");
1980 
1981   // Verifies result tensor types.
1982   auto get_expected_output_type = [input_type, split_dim,
1983                                    &size_splits](int64_t i) {
1984     return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
1985                                              size_splits[i]);
1986   };
1987   return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
1988                                   get_expected_output_type);
1989 }
1990 
1991 //===----------------------------------------------------------------------===//
1992 // MeanOp
1993 //===----------------------------------------------------------------------===//
1994 
1995 // TODO(b/133854225): Implement shape inference to Mean
1996 
1997 //===----------------------------------------------------------------------===//
1998 // LSTMOp
1999 //===----------------------------------------------------------------------===//
2000 
2001 static LogicalResult Verify(LSTMOp op) {
2002   auto operands = op.GetStatefulOperands();
2003   if (operands.size() != 2 || operands[0] != 18 || operands[1] != 19) {
2004     return op.emitOpError("LSTMOp expected to have two stateful operands");
2005   }
2006 
2007   const auto input_type = op.input().getType().cast<ShapedType>();
2008   // Since TFLite runtime generally supports dynamic shape/rank, if `input_type`
2009   // doesn't have static shape, we skip the shape check below.
2010   if (!input_type.hasStaticShape()) return success();
2011   // The input should be at least a 2-D tensor since it will go through a
2012   // fully connected layer.
2013   if (!input_type.hasRank() || input_type.getRank() < 2)
2014     return op.emitOpError(
2015         "the first input operand should have at least 2 dimensions.");
2016 
2017   const auto activation_state =
2018       op.input_activation_state().getType().cast<ShapedType>();
2019   const auto cell_state = op.input_cell_state().getType().cast<ShapedType>();
2020   const auto input_to_output_weights =
2021       op.input_to_output_weights().getType().cast<ShapedType>();
2022   const auto recurrent_to_output_weights =
2023       op.recurrent_to_output_weights().getType().cast<ShapedType>();
2024   if (activation_state.hasStaticShape() && cell_state.hasStaticShape() &&
2025       input_to_output_weights.hasStaticShape() &&
2026       recurrent_to_output_weights.hasStaticShape()) {
2027     const int n_input = input_type.getDimSize(input_type.getRank() - 1);
2028     const int n_cell = input_to_output_weights.getDimSize(0);
2029     const int n_output = recurrent_to_output_weights.getDimSize(1);
2030     const int output_state_size = activation_state.getNumElements();
2031     const int n_batch = input_type.getRank() == 2 ? input_type.getDimSize(0)
2032                                                   : input_type.getDimSize(1);
2033     const int state_size = cell_state.getNumElements();
2034 
2035     // Check if the dimension of the inputs matches.
2036     if ((output_state_size != n_batch * n_output) ||
2037         (state_size != n_batch * n_cell) ||
2038         (input_to_output_weights.getDimSize(1) != n_input) ||
2039         (recurrent_to_output_weights.getRank() != 2) ||
2040         (recurrent_to_output_weights.getDimSize(0) != n_cell) ||
2041         (input_to_output_weights.getRank() != 2)) {
2042       return op.emitOpError("inputs don't match with the dimensions.");
2043     }
2044 
2045     const bool is_layer_norm_lstm =
2046         !op.forget_layer_norm_coefficients().getType().isa<NoneType>();
2047     if (is_layer_norm_lstm) {
2048       const auto forget_layer_norm_coefficients =
2049           op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
2050       // If this LSTM has layer normalization, this input value,
2051       // "forget_layer_norm_coefficients", should be a 1-D tensor.
2052       if (!forget_layer_norm_coefficients.hasRank() ||
2053           forget_layer_norm_coefficients.getRank() != 1 ||
2054           forget_layer_norm_coefficients.getDimSize(0) != n_cell)
2055         return op.emitOpError(
2056             "coefficient inputs have more than 2 dimensions or "
2057             "don't match the dimension with input operand "
2058             "`input_to_output_weights`.");
2059     }
2060   }
2061 
2062   return success();
2063 }
2064 
2065 namespace {
2066 
2067 // Replaces the optional bias operands with a "none" type value if the bias
2068 // values are constant zeros.
2069 struct RemoveLSTMOpZeroBias : public OpRewritePattern<LSTMOp> {
2070   using OpRewritePattern<LSTMOp>::OpRewritePattern;
2071 
2072   LogicalResult matchAndRewrite(LSTMOp op,
2073                                 PatternRewriter &rewriter) const override {
2074     if (EqualsZero(op.input_gate_bias())) {
2075       auto none_value = rewriter.create<mlir::ConstantOp>(
2076           rewriter.getUnknownLoc(), rewriter.getUnitAttr());
2077       op.input_gate_biasMutable().assign(none_value);
2078     }
2079 
2080     if (EqualsZero(op.projection_bias())) {
2081       auto none_value = rewriter.create<mlir::ConstantOp>(
2082           rewriter.getUnknownLoc(), rewriter.getUnitAttr());
2083       op.projection_biasMutable().assign(none_value);
2084     }
2085 
2086     return success();
2087   }
2088 };
2089 
2090 }  // namespace
2091 
2092 void LSTMOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2093                                          MLIRContext *context) {
2094   results.insert<RemoveLSTMOpZeroBias>(context);
2095 }
2096 
2097 //===----------------------------------------------------------------------===//
2098 // UnidirectionalSequenceLSTMOp
2099 //===----------------------------------------------------------------------===//
2100 
2101 static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
2102   auto operands = op.GetStatefulOperands();
2103   if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) {
2104     return success();
2105   }
2106   return op.emitError(
2107       "UnidirectionalSequenceLSTMOp expected to have two stateful operands");
2108 }
2109 
2110 //===----------------------------------------------------------------------===//
2111 // BidirectionalSequenceLSTMOp
2112 //===----------------------------------------------------------------------===//
2113 
2114 static LogicalResult Verify(BidirectionalSequenceLSTMOp op) {
2115   auto operands = op.GetStatefulOperands();
2116   if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 &&
2117       operands[2] == 37 && operands[3] == 38) {
2118     return success();
2119   }
2120   return op.emitError(
2121       "BidirectionalSequenceLSTMOp expected to have four stateful operands");
2122 }
2123 
2124 //===----------------------------------------------------------------------===//
2125 // UnidirectionalSequenceRNNOp
2126 //===----------------------------------------------------------------------===//
2127 
2128 static LogicalResult Verify(UnidirectionalSequenceRNNOp op) {
2129   auto operands = op.GetStatefulOperands();
2130   if (operands.size() == 1 && operands[0] == 4) {
2131     return success();
2132   }
2133   return op.emitError(
2134       "UnidirectionalSequenceRNNOp expected to have one stateful operand");
2135 }
2136 
2137 //===----------------------------------------------------------------------===//
2138 // SvdfOp
2139 //===----------------------------------------------------------------------===//
2140 
2141 static LogicalResult Verify(SVDFOp op) {
2142   auto operands = op.GetStatefulOperands();
2143   if (operands.size() == 1 && operands[0] == 4) {
2144     return success();
2145   }
2146   return op.emitError("SvdfOp expected to have one stateful operand");
2147 }
2148 
2149 //===----------------------------------------------------------------------===//
2150 // AbsOp
2151 //===----------------------------------------------------------------------===//
2152 
2153 OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
2154   Type result_type = getType();
2155   // Only constant fold for tensor of f32 is implemented.
2156   if (!IsF32ShapedType(result_type)) return nullptr;
2157 
2158   auto compute = [](APFloat value) -> APFloat { return llvm::abs(value); };
2159   return ConstFoldUnaryOp(result_type, operands[0], compute);
2160 }
2161 
2162 //===----------------------------------------------------------------------===//
2163 // NegOp
2164 //===----------------------------------------------------------------------===//
2165 
2166 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
2167   Type result_type = getType();
2168   // Only constant fold for tensor of f32 is implemented.
2169   if (!IsF32ShapedType(result_type)) return nullptr;
2170 
2171   auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
2172   return ConstFoldUnaryOp(result_type, operands[0], compute);
2173 }
2174 
2175 //===----------------------------------------------------------------------===//
2176 // SinOp
2177 //===----------------------------------------------------------------------===//
2178 
2179 OpFoldResult SinOp::fold(ArrayRef<Attribute> operands) {
2180   Type result_type = getType();
2181   // Only constant fold for tensor of f32 is implemented.
2182   if (!IsF32ShapedType(result_type)) return nullptr;
2183 
2184   auto compute = [](APFloat value) -> APFloat {
2185     float f = value.convertToFloat();
2186     float result = std::sin(f);
2187     return APFloat(result);
2188   };
2189   return ConstFoldUnaryOp(result_type, operands[0], compute);
2190 }
2191 
2192 //===----------------------------------------------------------------------===//
2193 // CosOp
2194 //===----------------------------------------------------------------------===//
2195 
2196 OpFoldResult CosOp::fold(ArrayRef<Attribute> operands) {
2197   Type result_type = getType();
2198   // Only constant fold for tensor of f32 is implemented.
2199   if (!IsF32ShapedType(result_type)) return nullptr;
2200 
2201   auto compute = [](APFloat value) -> APFloat {
2202     float f = value.convertToFloat();
2203     float result = std::cos(f);
2204     return APFloat(result);
2205   };
2206   return ConstFoldUnaryOp(result_type, operands[0], compute);
2207 }
2208 
2209 //===----------------------------------------------------------------------===//
2210 // LogOp
2211 //===----------------------------------------------------------------------===//
2212 
2213 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
2214   Type result_type = getType();
2215   // Only constant fold for tensor of f32 is implemented.
2216   if (!IsF32ShapedType(result_type)) return nullptr;
2217 
2218   auto compute = [](APFloat value) -> APFloat {
2219     float f = value.convertToFloat();
2220     float result = std::log(f);
2221     return APFloat(result);
2222   };
2223   return ConstFoldUnaryOp(result_type, operands[0], compute);
2224 }
2225 
2226 //===----------------------------------------------------------------------===//
2227 // SqrtOp
2228 //===----------------------------------------------------------------------===//
2229 
2230 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
2231   Type result_type = getType();
2232   // Only constant fold for tensor of f32 is implemented.
2233   if (!IsF32ShapedType(result_type)) return nullptr;
2234 
2235   auto compute = [](APFloat value) -> APFloat {
2236     float f = value.convertToFloat();
2237     float result = std::sqrt(f);
2238     return APFloat(result);
2239   };
2240   return ConstFoldUnaryOp(result_type, operands[0], compute);
2241 }
2242 
2243 //===----------------------------------------------------------------------===//
2244 // RsqrtOp
2245 //===----------------------------------------------------------------------===//
2246 
2247 OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
2248   Type result_type = getType();
2249   // Only constant fold for tensor of f32/bf16 is implemented.
2250   if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
2251     return nullptr;
2252 
2253   auto compute = [](APFloat value) -> APFloat {
2254     bool loseInfo;
2255     const llvm::fltSemantics &original_float_semantics = value.getSemantics();
2256     value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
2257                   &loseInfo);
2258     float f = value.convertToFloat();
2259     APFloat result(1.f / std::sqrt(f));
2260     result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
2261                    &loseInfo);
2262     return result;
2263   };
2264   return ConstFoldUnaryOp(result_type, operands[0], compute);
2265 }
2266 
2267 //===----------------------------------------------------------------------===//
2268 // SquareOp
2269 //===----------------------------------------------------------------------===//
2270 
2271 OpFoldResult SquareOp::fold(ArrayRef<Attribute> operands) {
2272   Type result_type = getType();
2273   // Only constant fold for tensor of f32 is implemented.
2274   if (!IsF32ShapedType(result_type)) return nullptr;
2275 
2276   auto compute = [](APFloat value) -> APFloat { return value * value; };
2277   return ConstFoldUnaryOp(result_type, operands[0], compute);
2278 }
2279 
2280 //===----------------------------------------------------------------------===//
2281 // RankOp
2282 //===----------------------------------------------------------------------===//
2283 
2284 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2285   assert(operands.size() == 1);
2286   auto result_type = getType().cast<ShapedType>();
2287   if (auto elements_attr = operands[0].dyn_cast_or_null<ElementsAttr>()) {
2288     auto rank = static_cast<int32_t>(elements_attr.getType().getRank());
2289     return DenseElementsAttr::get(result_type, {rank});
2290   }
2291 
2292   // Also fold if `input` has a known rank.
2293   auto input_type = input().getType().cast<ShapedType>();
2294   // Do not fold if rank is zero because the TFLite converter doesn't
2295   // distinguish between unranked input and scalar input due to b/138865275.
2296   // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
2297   // predicate and fold the op when rank is zero.
2298   if (input_type.hasRank() && input_type.getRank() != 0) {
2299     auto rank = static_cast<int32_t>(input_type.getRank());
2300     return DenseElementsAttr::get(result_type, {rank});
2301   }
2302 
2303   return nullptr;
2304 }
2305 
2306 //===----------------------------------------------------------------------===//
2307 // ConstOp
2308 //===----------------------------------------------------------------------===//
2309 
2310 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
2311   assert(operands.empty() && "constant has no operands");
2312 
2313   // Return the held attribute value.
2314   return value();
2315 }
2316 
2317 //===----------------------------------------------------------------------===//
2318 // CastOp
2319 //===----------------------------------------------------------------------===//
2320 
2321 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
2322   assert(operands.size() == 1);
2323   if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
2324     return input();
2325   }
2326 
2327   // For now, only supports cast between integer types.
2328   auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2329   if (!elements_attr) {
2330     return nullptr;
2331   }
2332 
2333   auto result_element_type =
2334       getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
2335   auto operand_element_type = input()
2336                                   .getType()
2337                                   .cast<ShapedType>()
2338                                   .getElementType()
2339                                   .dyn_cast<IntegerType>();
2340   // Returns nullptr if either result/operand element type is not integer.
2341   if (!result_element_type || !operand_element_type) {
2342     return nullptr;
2343   }
2344 
2345   const bool is_unsigned = operand_element_type.isUnsigned();
2346   const bool involves_bool = operand_element_type.getWidth() == 1 ||
2347                              result_element_type.getWidth() == 1;
2348   const int output_bitwidth = result_element_type.getWidth();
2349   // The integer cast op behaves like a C integer cast. Depending on the
2350   // operand type's signedness, we determine whether or not sign extension is
2351   // needed.
2352   auto cast = [&](APInt value) {
2353     if (involves_bool) {
2354       // Handle boolean inputs or outputs explicitly as it doesn't have the same
2355       // behavior as extension or truncation.
2356       // true input should always be cast to 1 and not -1 as the sign extension
2357       // would do for signed outputs. Similarly, non-zero inputs should be cast
2358       // to true. Truncating even numbers to one bit will result in `false`.
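      // For example, casting the i32 value 2 to i1 yields true, and casting an
      // i1 true to i8 yields 1 rather than -1.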
2359       return APInt(result_element_type.getWidth(), value != 0);
2360     }
2361     return is_unsigned ? value.zextOrTrunc(output_bitwidth)
2362                        : value.sextOrTrunc(output_bitwidth);
2363   };
2364 
2365   return elements_attr.mapValues(result_element_type, cast);
2366 }
2367 
2368 //===----------------------------------------------------------------------===//
2369 // SelectV2Op
2370 //===----------------------------------------------------------------------===//
2371 
2372 static void BuildSelectV2Op(Builder *builder, OperationState &result,
2373                             Value cond, Value x, Value y) {
2374   auto operand_type =
2375       OpTrait::util::getBroadcastedType(x.getType(), y.getType());
2376 
2377   if (!operand_type)
2378     emitError(result.location) << "non-broadcastable operands: " << x.getType()
2379                                << " and " << y.getType();
2380 
2381   bool has_static_cond_shape = false;
2382   bool has_static_operand_shape = false;
2383   ArrayRef<int64_t> cond_shape;
2384   ArrayRef<int64_t> operand_shape;
2385 
2386   if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
2387     if (shaped_type.hasStaticShape()) {
2388       has_static_cond_shape = true;
2389       cond_shape = shaped_type.getShape();
2390     }
2391   }
2392   if (auto shaped_type = operand_type.dyn_cast<ShapedType>()) {
2393     if (shaped_type.hasStaticShape()) {
2394       has_static_operand_shape = true;
2395       operand_shape = shaped_type.getShape();
2396     }
2397   }
2398 
2399   SmallVector<int64_t, 4> broadcastedShape;
2400   if (has_static_cond_shape && has_static_operand_shape &&
2401       !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
2402                                           broadcastedShape)) {
2403     emitError(result.location) << "non-broadcastable operands: " << operand_type
2404                                << " and " << cond.getType();
2405   }
2406 
2407   result.addOperands({cond, x, y});
2408 
2409   auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
2410   if (has_static_cond_shape && has_static_operand_shape) {
2411     result.types.push_back(
2412         RankedTensorType::get(broadcastedShape, elementType));
2413   } else {
2414     result.types.push_back(UnrankedTensorType::get(elementType));
2415   }
2416 }
2417 
2418 //===----------------------------------------------------------------------===//
2419 // RangeOp
2420 //===----------------------------------------------------------------------===//
2421 
2422 namespace {
2423 
2424 // Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`.
2425 // The template parameter `FloatOrInt` must be a standard C integer or
2426 // floating-point type.
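// For example, start = 0, limit = 10, delta = 3 gives a length of 4 on both
// the integer path ((10 + 3 - 1) / 3) and the floating-point path
// (ceil(10 / 3)).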
2427 template <typename FloatOrInt>
2428 int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
2429   // Refer to the implementation in
2430   // tensorflow/lite/kernels/range.cc.
2431   return std::is_integral<FloatOrInt>::value
2432              ? ((std::abs(limit - start) + std::abs(delta) - 1) /
2433                 std::abs(delta))
2434              : std::ceil(std::abs((limit - start) / delta));
2435 }
2436 
2437 // Builds a constant range tensor of `result_elem_type` elements.
2438 // Template parameter `FloatOrIntAtrr` must be mlir::IntegerAttr or
2439 // mlir::FloatAttr.
2440 template <typename FloatOrIntAtrr>
2441 DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
2442                                         FloatOrIntAtrr start_attr,
2443                                         FloatOrIntAtrr delta_attr) {
2444   using ValueType = typename FloatOrIntAtrr::ValueType;  // APInt or APFloat
2445   ValueType start = start_attr.getValue();
2446   ValueType delta = delta_attr.getValue();
2447 
2448   SmallVector<ValueType, 16> new_values;
2449   new_values.reserve(num_elements);
2450   ValueType new_value = start;
2451   for (int i = 0; i < num_elements; ++i) {
2452     new_values.push_back(new_value);
2453     new_value = new_value + delta;
2454   }
2455   // Result is always a 1-D tensor.
2456   auto new_result_type =
2457       RankedTensorType::get({num_elements}, result_elem_type);
2458   return DenseElementsAttr::get(new_result_type, new_values);
2459 }
2460 }  // namespace
2461 
2462 OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
2463   assert(operands.size() == 3);
2464   auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
2465   auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
2466   auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
2467   if (start_tensor && limit_tensor && delta_tensor) {
2468     // Operands should all be scalars
2469     assert(start_tensor.getType().getRank() == 0 &&
2470            limit_tensor.getType().getRank() == 0 &&
2471            delta_tensor.getType().getRank() == 0);
2472     Type elem_type = getType().cast<ShapedType>().getElementType();
2473     if (elem_type.isSignlessInteger()) {
2474       auto start_attr = start_tensor.getValue<IntegerAttr>({});
2475       auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
2476       auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
2477       const int num_elements = GetLengthOfRange(
2478           start_attr.getInt(), limit_attr.getInt(), delta_attr.getInt());
2479       return BuildConstRangeTensor(elem_type, num_elements, start_attr,
2480                                    delta_attr);
2481     } else if (elem_type.isa<FloatType>()) {
2482       auto start_attr = start_tensor.getValue<FloatAttr>({});
2483       auto limit_attr = limit_tensor.getValue<FloatAttr>({});
2484       auto delta_attr = delta_tensor.getValue<FloatAttr>({});
2485       const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
2486                                                 limit_attr.getValueAsDouble(),
2487                                                 delta_attr.getValueAsDouble());
2488       return BuildConstRangeTensor(elem_type, num_elements, start_attr,
2489                                    delta_attr);
2490     }
2491   }
2492 
2493   return nullptr;
2494 }
2495 
2496 //===----------------------------------------------------------------------===//
2497 // TransposeConvOp
2498 //===----------------------------------------------------------------------===//
2499 
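// Illustrative example of the checks below: if `output_shape` is the constant
// dense<[1, 64, 64, 32]> : tensor<4xi32>, then the output type must have
// rank 4 and be shape-compatible with tensor<1x64x64x32xf32> (assuming an f32
// element type).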
static LogicalResult Verify(TransposeConvOp op) {
  ShapedType output_type = op.output().getType().cast<ShapedType>();
  ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
  if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
    if (output_type.getRank() != output_shape_type.getDimSize(0)) {
      return op.emitOpError(llvm::formatv(
          "expect output type has rank = {0}, got output type {1}",
          output_shape_type.getDimSize(0), output_type));
    }
  }

  DenseIntElementsAttr output_shape_elements;
  if (!matchPattern(op.output_shape(), m_Constant(&output_shape_elements))) {
    return success();
  }

  llvm::SmallVector<int64_t, 4> output_shape;
  output_shape.reserve(output_shape_elements.getNumElements());
  for (auto dim : output_shape_elements.getValues<int>()) {
    output_shape.push_back(dim);
  }

  auto expected_output_type =
      RankedTensorType::get(output_shape, output_type.getElementType());
  if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
    return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
                                        expected_output_type, output_type));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

namespace {

// Computes the permutation of a constant `input_tensor` according to `perm`.
// The function recursively traverses the dimensions of the output tensor in
// row-major order and writes the corresponding values of the output tensor
// into `new_values`.
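// For example (illustrative values), a 2x3 input [[1, 2, 3], [4, 5, 6]] with
// perm = [1, 0] has output shape [3, 2]; traversing the output row-major
// appends 1, 4, 2, 5, 3, 6 to `new_values`.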
void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
                        ArrayRef<int64_t> output_shape, int num_dimensions,
                        int output_axis, std::vector<uint64_t> *input_indices,
                        std::vector<Attribute> *new_values) {
  // Refer to the implementation of the `Transpose` function in
  // tensorflow/lite/kernels/internal/reference/reference_ops.h
  assert(output_axis < num_dimensions);
  const int input_axis = perm[output_axis];
  for (int i = 0; i < output_shape[output_axis]; ++i) {
    // Update the input indices on `input_axis`.
    input_indices->at(input_axis) = i;
    // Write the value from `input_tensor` if this is the last axis, or
    // recurse into the next axis otherwise.
    const bool is_last_axis = output_axis == num_dimensions - 1;
    if (is_last_axis) {
      new_values->push_back(input_tensor.getValue(*input_indices));
    } else {
      ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
                         output_axis + 1, input_indices, new_values);
    }
  }
}

}  // namespace

OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2);
  auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
  auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
  if (!input_tensor || !perm_tensor) return nullptr;

  // Do not try to fold elements attr of a quant type because
  // DenseElementsAttr does not support it.
  if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
    return nullptr;

  assert(perm_tensor.getType().getRank() == 1);
  const int num_dimensions = input_tensor.getType().getRank();
  assert(perm_tensor.getType().getNumElements() == num_dimensions);

  ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
  auto output_type = getType().cast<ShapedType>();

  SmallVector<int32_t, 4> perm;
  SmallVector<int64_t, 4> output_shape;
  for (int i = 0; i < num_dimensions; ++i) {
    perm.push_back(
        perm_tensor.getValue<IntegerAttr>({static_cast<uint64_t>(i)}).getInt());
    output_shape.push_back(input_shape[perm[i]]);

    // Check that the derived output shape matches the static shape.
    assert(!output_type.hasStaticShape() ||
           output_type.getShape()[i] == output_shape[i]);
  }

  std::vector<Attribute> new_values;
  new_values.reserve(input_tensor.getType().getNumElements());
  std::vector<uint64_t> input_indices(num_dimensions);
  ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
                     /*output_axis=*/0, &input_indices, &new_values);
  auto result_type =
      RankedTensorType::get(output_shape, output_type.getElementType());
  return DenseElementsAttr::get(result_type, new_values);
}

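// The verifier below checks that `perm` holds a valid permutation and that
// the output type matches the transposed input shape. For example
// (illustrative values), with a rank-3 input, dense<[2, 0, 0]> is rejected
// because axis 0 is duplicated, and dense<[0, 1, 3]> is rejected because 3 is
// outside [0, rank).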
static LogicalResult Verify(TransposeOp op) {
  auto input_type = op.input().getType().cast<ShapedType>();
  auto perm_type = op.perm().getType().cast<ShapedType>();
  auto output_type = op.output().getType().cast<ShapedType>();
  if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
    if (perm_type.getNumElements() != input_type.getRank()) {
      return op.emitOpError(
          "perm tensor elements size is not equal to input tensor rank");
    }
  }

  DenseIntElementsAttr perm;
  if (!matchPattern(op.perm(), m_Constant(&perm))) {
    return success();
  }

  int index = 0;
  llvm::SmallVector<int64_t, 4> axes;
  for (const auto &axis_int : perm.getValues<APInt>()) {
    const int64_t axis = axis_int.getSExtValue();
    if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
      return op.emitOpError(
          llvm::formatv("perm[{0}] must be in [0, rank)", index));
    }
    if (std::count(axes.begin(), axes.end(), axis) > 0) {
      return op.emitOpError(
          llvm::formatv("perm[{0}] cannot have duplicated axis", index));
    }
    axes.push_back(axis);
    index++;
  }

  if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
    llvm::SmallVector<int64_t, 4> transposed_shape;
    for (int64_t axis : axes) {
      transposed_shape.push_back(input_type.getDimSize(axis));
    }
    auto expected_output_type =
        RankedTensorType::get(transposed_shape, input_type.getElementType());
    if (failed(
            mlir::verifyCompatibleShape(output_type, expected_output_type))) {
      return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
                                          expected_output_type, output_type));
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//

LogicalResult Verify(WhileOp op) {
  if (op.getNumOperands() != op.getNumResults())
    return op.emitOpError(llvm::formatv(
        "number of operands does not match number of results ({0} != {1})",
        op.getNumOperands(), op.getNumResults()));
  // TODO(jpienaar): Verify operand, result & block argument types
  return success();
}

namespace {
// Canonicalizes the While op so that results and operands match and external
// values are captured implicitly rather than passed as block arguments.
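// For example (illustrative), if body block argument %arg1 is simply yielded
// back unchanged, all uses of %arg1 (and of the matching cond argument) are
// replaced with the corresponding while operand, making it an implicit
// capture; arguments that are then unused in both regions and whose results
// are unused are dropped, and the while is rebuilt with the reduced operand
// and result lists.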
struct WhileResultOperandsMatchAndImplicitCapture
    : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp while_op,
                                PatternRewriter &rewriter) const override {
    // Replace values simply passed through the body with external values
    // (in both the body and condition regions, as well as the while result).
    // The block arguments of body and while match, so the corresponding cond
    // argument can easily be found.
    bool unchanged = true;
    auto &body_block = while_op.body().front();
    auto &cond_block = while_op.cond().front();
    auto &yield = *body_block.getTerminator();
    for (auto ba : body_block.getArguments()) {
      int arg_no = ba.getArgNumber();
      // Skip removing resources that are not read-only variables.
      if (getElementTypeOrSelf(ba.getType()).isa<TF::ResourceType>()) {
        bool has_read_only_variables = true;
        for (auto user : ba.getUsers()) {
          // Terminator ops (for example, the tfl::yield op) should be ignored,
          // since the argument being yielded as the `body` function result
          // says nothing about whether the argument is a read-only variable.
          if (user->hasTrait<OpTrait::IsTerminator>()) continue;
          if (!llvm::isa<mlir::TF::ReadVariableOp>(user)) {
            has_read_only_variables = false;
            break;
          }
        }
        if (!has_read_only_variables) continue;
      }
      if (ba == yield.getOperand(arg_no)) {
        unchanged = false;
        auto value = while_op.getOperand(arg_no);
        ba.replaceAllUsesWith(value);
        cond_block.getArgument(arg_no).replaceAllUsesWith(value);

        // This could be relaxed and casts inserted.
        if (while_op.getResult(arg_no).getType() == value.getType())
          while_op.getResult(arg_no).replaceAllUsesWith(value);
      }
    }

    // The While op's operands and result types need to match.
    SmallVector<Value, 4> new_operands;
    SmallVector<Value, 4> new_body_yield;
    SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
    llvm::SmallVector<Type, 4> types;
    new_operands.reserve(while_op.getNumOperands());
    new_body_yield.reserve(while_op.getNumOperands());
    types.reserve(while_op.getNumOperands());

    // Remove block arguments not used in either cond or body. This keeps the
    // block arguments of body and cond matching.
    int arg_index = 0;
    for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
         ++while_index) {
      auto value = while_op.getOperand(while_index);
      if (body_block.getArgument(arg_index).use_empty() &&
          cond_block.getArgument(arg_index).use_empty() &&
          // Note: since we are not erasing results, we need to use while_index
          // to check whether the corresponding result is unused.
          while_op.getResult(while_index).use_empty()) {
        unchanged = false;
        body_block.eraseArgument(arg_index);
        cond_block.eraseArgument(arg_index);

        // Mark operand for removal.
        removed_operand[while_index] = true;
      } else {
        new_operands.push_back(value);
        new_body_yield.push_back(yield.getOperand(while_index));
        auto type = while_op.getResult(while_index).getType();
        types.push_back(type);
        ++arg_index;
      }
    }

    // Done if no values were removed from the blocks and operands & results
    // match.
    if (unchanged) return failure();

    // Replace with a new While with matching operands and results.
    Operation *op = while_op.getOperation();
    Operation *new_op = rewriter.insert(
        Operation::create(op->getLoc(), op->getName(), types, new_operands,
                          op->getAttrs(), {}, /*numRegions=*/2));

    for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
    int new_index = 0;
    for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
      if (removed_operand[op_index]) continue;
      op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
      ++new_index;
    }
    rewriter.eraseOp(op);

    Block &new_body_block = cast<WhileOp>(new_op).body().front();
    rewriter.setInsertionPointToEnd(&new_body_block);
    rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
                                         new_body_yield);

    return success();
  }
};

}  // namespace

void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<WhileResultOperandsMatchAndImplicitCapture>(context);
}

Region &WhileOp::getLoopBody() { return body(); }

bool WhileOp::isDefinedOutsideOfLoop(Value value) {
  // TODO(jpienaar): This is overly conservative and initially disables
  // anything other than constant hoisting.
  return false;
}

LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
  if (ops.empty()) return success();

  // Move the hoisted values to just before the while.
  Operation *while_op = this->getOperation();
  for (auto op : ops) op->moveBefore(while_op);

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"

}  // namespace TFL
}  // namespace mlir

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"

namespace mlir {
namespace TFL {

#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"

Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
                                                      Attribute value,
                                                      Type type, Location loc) {
  // If this is an opaque elements attribute or the result type doesn't match
  // the attribute type, then generate a tfl.pseudo_const.
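  // For example (illustrative), an OpaqueElementsAttr always becomes a
  // tfl.pseudo_const, whereas a dense elements attribute whose type already
  // matches `type` can be materialized as a standard constant below.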
  if (value.isa<OpaqueElementsAttr>() ||
      (value.isa<ElementsAttr>() && value.getType() != type))
    return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}

}  // namespace TFL
}  // namespace mlir