1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// This is the optimization pattern definition file for TensorFlow Lite.
17
18include "mlir/IR/OpBase.td"
19include "mlir/Dialect/StandardOps/IR/Ops.td"
20include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
21include "tensorflow/compiler/mlir/lite/utils/utils.td"
22include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
23
// Attribute constraint: matches an ElementsAttr whose element type is f32.
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && "
        "$_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
  "32 bit float constant tensor">;
28
// Attribute constraint: matches an ElementsAttr whose element type is any
// floating-point type (FloatType).
def FloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && "
        "$_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
  "float constant tensor">;
33
// Holds when the type of the bound entity is MLIR's NoneType.
def IsNoneType : Constraint<
  CPred<"$0.getType().isa<NoneType>()">>;
36
// Rewrite-side attribute transform: passes the bound attribute, cast to
// ElementsAttr, to the C++ helper `ExtractSingleElementAsFloat` (defined
// outside this file; presumably expects exactly one element — confirm there).
def ExtractSingleElementAsFloat
    : NativeCodeCall<"ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;
39
// Checks that the bound value/attribute has a *ranked* shaped type whose rank
// is at most 'n'. Unranked types do not satisfy the constraint, instead of
// hitting the assertion inside ShapedType::getRank().
class HasRankAtMost<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() <= " # n>>;
43
44//===----------------------------------------------------------------------===//
45// Ternary ops patterns.
46//===----------------------------------------------------------------------===//
// Folds a stand-alone activation op that consumes a Conv2D or DepthwiseConv2D
// result into the convolution's fused-activation operand (the slot matched
// as TFL_AF_None). Only applies while the convolution has no fused activation
// yet and the activation op is the convolution's only user.
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
  def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp
      (TFL_Conv2DOp:$conv_result $input, $kernel, $bias, $h_factor,
         $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)),
    (TFL_Conv2DOp $input, $kernel, $bias, $h_factor, $w_factor,
       ActFnAttr, $padding, $stride_h, $stride_w),
    [(HasOneUse $conv_result)]>;

  def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp
      (TFL_DepthwiseConv2DOp:$dw_conv_result $input, $kernel, $bias,
         $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w,
         $multiplier)),
    (TFL_DepthwiseConv2DOp $input, $kernel, $bias, $h_factor, $w_factor,
       ActFnAttr, $padding, $stride_h, $stride_w, $multiplier),
    [(HasOneUse $dw_conv_result)]>;
}
64
// Folds a stand-alone activation op that consumes an AveragePool2D or
// MaxPool2D result into the pooling op's trailing fused-activation operand.
// Only applies while the pool has no fused activation yet (TFL_AF_None) and
// the activation op is the pool's only user. Operands are matched
// positionally; the MaxPool2D order used here differs from AveragePool2D's.
multiclass FuseActFnIntoPoolOpPat<dag ActFnOp, dag ActFnAttr> {
  def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp
      (TFL_AveragePool2DOp:$avg_pooled $input, $filter_height, $filter_width,
         $padding, $stride_h, $stride_w, TFL_AF_None)),
    (TFL_AveragePool2DOp $input, $filter_height, $filter_width, $padding,
       $stride_h, $stride_w, ActFnAttr),
    [(HasOneUse $avg_pooled)]>;

  def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp
      (TFL_MaxPool2DOp:$max_pooled $input, $padding, $stride_w, $stride_h,
         $filter_width, $filter_height, TFL_AF_None)),
    (TFL_MaxPool2DOp $input, $padding, $stride_w, $stride_h,
       $filter_width, $filter_height, ActFnAttr),
    [(HasOneUse $max_pooled)]>;
}
79
// Instantiates the activation-fusion patterns for Relu, Relu6 and Relu1 on
// convolutions and poolings.
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Tanh, sigmoid, hard_swish and other activations that cannot simply be
// expressed as clamping are not fused for now.
foreach fusion = [[TFL_ReluOp, TFL_AF_Relu],
                  [TFL_Relu6Op, TFL_AF_Relu6],
                  [TFL_Relu1Op, TFL_AF_Relu1]] in {
  defm : FuseActFnIntoConvOpPat<fusion[0], fusion[1]>;
  defm : FuseActFnIntoPoolOpPat<fusion[0], fusion[1]>;
}
90
// Delegates to the C++ helper `TFL::CanFuseConvOrDepthwiseConv`, forwarding
// the two bound entities and `is_depthwise` pasted verbatim as a C++ literal
// ("true" / "false").
class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<CPred<
  "TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
93
// If a binary op (add, sub) combines a constant with the result of a
// convolution whose bias is a constant, the binary op can be fused into the
// convolution by combining the bias with the binary op's constant operand
// (the combining op is created on two constants, presumably to be constant
// folded afterwards). The following patterns restrict to float constant
// values for now.
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
  // conv2d(x, w, bias) op c  ==>  conv2d(x, w, bias op c)
  def FuseBinaryOpWithConv#binaryOp : Pat<
    (binaryOp (TFL_Conv2DOp:$output $input, $filter,
                (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
                TFL_AF_None, $padding, $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input, $filter,
      (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $output)]>;

  // depthwise_conv2d(x, w, bias) op c  ==>  depthwise_conv2d(x, w, bias op c)
  def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
    (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                (ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                $stride_w, $multiplier),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input, $filter,
      (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
      $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasOneUse $output)]>;

  // transpose_conv(shape, w, x, bias) op c
  //   ==>  transpose_conv(shape, w, x, bias op c)
  // The matched TransposeConv form has no fused-activation operand, so only
  // a binary op without a fused activation is accepted.
  def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                (ConstantOp FloatElementsAttr:$bias), $padding,
                $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
      (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
  defm : FuseBinaryOpToPrecedingAffine<binaryOp>;

// Fuse an Add into a TransposeConv that has a none-typed bias: the added
// constant simply becomes the bias.
//   transpose_conv(shape, w, x, none) + c  ==>  transpose_conv(shape, w, x, c)
// This is only valid for Add. For Sub the new bias would have to be the
// negated constant; substituting `c` as-is would turn `conv - c` into
// `conv + c`, so no Sub variant is defined.
def FuseAddWithTransposeConvNoneBias : Pat<
  (TFL_AddOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
               (ConstantOp $bias), $padding, $stride_h, $stride_w),
             (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
  (TFL_TransposeConvOp $output_shape, $weights, $inputs,
    (ConstantOp $value),
    $padding, $stride_h, $stride_w),
  [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
   (IsNoneType $bias),
   (HasOneUse $output)]>;
148
// Rewrite-side helpers that pass the bound constant to the C++ functions
// `ExpandTo4DForConv` / `ExpandTo4DForDepthwiseConv` (defined outside this
// file). Judging by the names they reshape the constant to 4-D so it can be
// combined with a convolution filter — TODO confirm the exact layout there.
def ExpandTo4DForConv : NativeCodeCall<"ExpandTo4DForConv($0)">;

def ExpandTo4DForDepthwiseConv
    : NativeCodeCall<"ExpandTo4DForDepthwiseConv($0)">;
153
// If a Mul or Div scales the output of a convolution (Conv2D, DepthwiseConv2D
// or TransposeConv) with constant float filter/weights by another constant,
// the scaling is folded into the convolution:
//  * the filter is combined with the constant after it is passed through the
//    matching ExpandTo4D* helper;
//  * a constant bias is combined with the constant as-is.
// The combining ops are built on constants, presumably so they are constant
// folded afterwards. These patterns restrict to float constant values for now.
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
  // depthwise_conv(x, F, B) op c  ==>  depthwise_conv(x, F op c', B op c)
  def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
    (BinaryOp
      (TFL_DepthwiseConv2DOp:$dw_conv_out $input,
        (ConstantOp FloatElementsAttr:$filter),
        (ConstantOp FloatElementsAttr:$bias),
        $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w,
        $multiplier),
      (ConstantOp FloatElementsAttr:$coeff), $act_fn),
    (TFL_DepthwiseConv2DOp $input,
      (BinaryOp (ConstantOp $filter),
                (ConstantOp (ExpandTo4DForDepthwiseConv $coeff)),
                TFL_AF_None),
      (BinaryOp (ConstantOp $bias), (ConstantOp $coeff), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
      $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $coeff),
     (HasOneUse $dw_conv_out)]>;

  // conv2d(x, F, B) op c  ==>  conv2d(x, F op c', B op c)
  def FuseMulOrDivWithConv#BinaryOp : Pat<
    (BinaryOp
      (TFL_Conv2DOp:$conv_out $input,
        (ConstantOp FloatElementsAttr:$filter),
        (ConstantOp FloatElementsAttr:$bias),
        $h_factor, $w_factor, TFL_AF_None,
        $padding, $stride_h, $stride_w),
      (ConstantOp FloatElementsAttr:$coeff), $act_fn),
    (TFL_Conv2DOp $input,
      (BinaryOp (ConstantOp $filter),
                (ConstantOp (ExpandTo4DForConv $coeff)),
                TFL_AF_None),
      (BinaryOp (ConstantOp $bias), (ConstantOp $coeff), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $coeff),
     (HasOneUse $conv_out)]>;

  // transpose_conv(shape, W, x, B) op c
  //   ==>  transpose_conv(shape, W op c', x, B op c)
  def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
    (BinaryOp
      (TFL_TransposeConvOp:$tconv_out $output_shape,
        (ConstantOp FloatElementsAttr:$weights), $input,
        (ConstantOp FloatElementsAttr:$bias),
        $padding, $stride_h, $stride_w),
      (ConstantOp $coeff), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
      (BinaryOp (ConstantOp $weights),
                (ConstantOp (ExpandTo4DForConv $coeff)),
                TFL_AF_None),
      $input,
      (BinaryOp (ConstantOp $bias), (ConstantOp $coeff), TFL_AF_None),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $coeff),
     (HasOneUse $tconv_out)]>;

  // Same rewrite for a TransposeConv whose bias is a none-typed constant:
  // only the weights are rescaled and the bias is passed through unchanged.
  def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
    (BinaryOp
      (TFL_TransposeConvOp:$tconv_out $output_shape,
        (ConstantOp FloatElementsAttr:$weights), $input,
        (ConstantOp $bias),
        $padding, $stride_h, $stride_w),
      (ConstantOp $coeff), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape,
      (BinaryOp (ConstantOp $weights),
                (ConstantOp (ExpandTo4DForConv $coeff)),
                TFL_AF_None),
      $input,
      (ConstantOp $bias),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $coeff),
     (IsNoneType $bias),
     (HasOneUse $tconv_out)]>;
}

foreach ScaleOp = [TFL_DivOp, TFL_MulOp] in
  defm : FuseMulOrDivWithConv2dOrDepthwiseConv2d<ScaleOp>;
235
236
// Removes a redundant dequantize -> quantize round trip by forwarding the
// original quantized value, when `NotFromQuantOpOrSameQuantType` accepts the
// value together with the re-quantization type.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
def eliminate_dq_q_pairs : Pat<
  (TFL_QuantizeOp (TFL_DequantizeOp $quantized), $target_qtype),
  (replaceWithValue $quantized),
  [(NotFromQuantOpOrSameQuantType $quantized, $target_qtype)]>;
244
245
246
247
// Checks that the operand has a *ranked* shaped type of rank exactly 'n'.
// Unranked operands do not match (rather than asserting inside getRank()).
class OperandHasRank<int n> : Constraint<
  CPred<"$0.getType().cast<ShapedType>().hasRank() && "
        "$0.getType().cast<ShapedType>().getRank() == " # n>>;
251
// Recognizes hard_swish written as (x * relu6(x + 3)) * 0.166666666, with the
// Relu6 fused into the Add, and replaces it with a single HardSwish op.
def MatchHardSwishPattern1 : Pat<
  (TFL_MulOp
    (TFL_MulOp
      $x,
      (TFL_AddOp
        $x,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6),
      TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;
264
// Recognizes hard_swish written as x * (relu6(x + 3) * 0.166666666), with the
// Relu6 fused into the Add, and replaces it with a single HardSwish op.
def MatchHardSwishPattern2 : Pat<
  (TFL_MulOp
    $x,
    (TFL_MulOp
      (TFL_AddOp
        $x,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6),
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
      TFL_AF_None),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;
277
// Variant of the (x * relu6(x + 3)) * 0.166666666 hard_swish pattern in which
// quantize/dequantize pairs were inserted after the Add and after the inner
// Mul. These FakeQuant ops were due to incorrect placement in the
// quantization aware training; the quantization types are ignored and the
// whole expression becomes a single HardSwish op.
// TODO(b/149735743): We should make the placement automatically.
def MatchHardSwishQuantized : Pat<
  (TFL_MulOp
    (TFL_DequantizeOp (TFL_QuantizeOp
      (TFL_MulOp
        $x,
        (TFL_DequantizeOp (TFL_QuantizeOp
          (TFL_AddOp
            $x,
            (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
            TFL_AF_Relu6),
          $add_qtype)),
        TFL_AF_None),
      $mul_qtype)),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;
292
// Constraint that the attribute is a single-element DenseElementsAttr of any
// floating-point element type whose absolute value is less than 'n'. The
// element is read as an APFloat and widened to double. Non-float attributes
// do not match. This avoids getValues<float>(), which requires 32-bit float
// storage and asserts on f64/integer data.
class ConstDoubleValueLessThan<string n> : Constraint<
  CPred<"$0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().getElementType().isa<FloatType>() && "
  "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
  "std::abs(FloatAttr::getValueAsDouble("
  "*$0.cast<DenseElementsAttr>().getFloatValues().begin())) < "
  # n>>;
299
// Delegates to the C++ helper `L2NormalizeReduceAxis`, passing the bound
// value and the axis attribute cast to DenseElementsAttr.
def L2NormValidReduceIndex : Constraint<CPred<
  "L2NormalizeReduceAxis($0, $1.cast<DenseElementsAttr>())">>;
302
// Builds L2NormalizationOp from its decomposed form. The generated op always
// uses TFL_AF_None, because L2Normalization currently does not support a
// fused activation function in TFLite.
// TODO(karimnosseir): Add constraints that the kernel code assumes.
// constraint on axis and depth.
multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
  // x FirstOp SecondOp(sum(square(x), axis)), i.e.
  //   Mul -> Rsqrt -> Sum -> Square, or
  //   Div -> Sqrt  -> Sum -> Square
  def L2NormalizePattern1#FirstOp#SecondOp : Pat<
    (FirstOp $x,
      (SecondOp
        (TFL_SumOp (TFL_SquareOp:$squared $x),
                   (ConstantOp I32ElementsAttr:$axis), $keep_dims)),
      TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $squared, $axis)]>;

  // Same shape, with a small constant epsilon (|epsilon| < 1e-3) added to
  // the sum before the (r)sqrt.
  def L2NormalizePattern2#FirstOp#SecondOp : Pat<
    (FirstOp $x,
      (SecondOp
        (TFL_AddOp
          (TFL_SumOp (TFL_SquareOp:$squared $x),
                     (ConstantOp I32ElementsAttr:$axis), $keep_dims),
          (ConstantOp $epsilon), TFL_AF_None)),
      TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $squared, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

  // Same shape, with the sum clamped from below by a small constant epsilon
  // (|epsilon| < 1e-3) via Maximum.
  def L2NormalizePattern3#FirstOp#SecondOp : Pat<
    (FirstOp $x,
      (SecondOp
        (TFL_MaximumOp
          (TFL_SumOp (TFL_SquareOp:$squared $x),
                     (ConstantOp I32ElementsAttr:$axis), $keep_dims),
          (ConstantOp $epsilon))),
      TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $squared, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
}

foreach normPair = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]] in
  defm : L2NormalizePatterns<normPair[0], normPair[1]>;
356
357//===----------------------------------------------------------------------===//
358// Binary ops patterns.
359//===----------------------------------------------------------------------===//
// Forwards the types of both bound entities to
// `TFL::IsBroadcastableElementsAttrAndType`.
def AreBroadcastableTypes : Constraint<CPred<
  "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;

// Forwards the types of the two operands and of the result to
// `TFL::OperandsBroadcastToOutputType`.
def OperandsBroadcastToOutputType : Constraint<CPred<
  "TFL::OperandsBroadcastToOutputType($0.getType(), $1.getType(), "
  "$2.getType())">>;

// Forwards the types of both bound entities to `TFL::IsTailOfShape`.
def IsTailOfShape : Constraint<CPred<
  "TFL::IsTailOfShape($0.getType(), $1.getType())">>;
369
// Drops a Tile that feeds a binary op when the Tile only materializes a
// broadcast the binary op can perform itself. The untiled operand must still
// broadcast with the other operand to the original result type, and both
// operands must have rank <= 4. There is one pattern per operand position.
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
  def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
    (BinaryOp:$out (TFL_TileOp $untiled, (ConstantOp $multiples)),
              $other, $act_func),
    (BinaryOp $untiled, $other, $act_func),
    [(OperandsBroadcastToOutputType $untiled, $other, $out),
     (HasRankAtMost<4> $untiled),
     (HasRankAtMost<4> $other)]>;

  def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
    (BinaryOp:$out $other,
              (TFL_TileOp $untiled, (ConstantOp $multiples)), $act_func),
    (BinaryOp $other, $untiled, $act_func),
    [(OperandsBroadcastToOutputType $other, $untiled, $out),
     (HasRankAtMost<4> $other),
     (HasRankAtMost<4> $untiled)]>;
}
389
// Folds a Relu / Relu6 / Relu1 that is the only user of a binary op with no
// fused activation into that binary op's fused-activation operand.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
  foreach act = [[TFL_ReluOp, TFL_AF_Relu],
                 [TFL_Relu6Op, TFL_AF_Relu6],
                 [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def FuseBinaryWithActivation#BinaryOp#act[0] : Pat<
      (act[0] (BinaryOp:$unactivated $lhs, $rhs, TFL_AF_None)),
      (BinaryOp $lhs, $rhs, act[1]),
      [(HasOneUse $unactivated)]>;
  }
}
401
foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  // Drop Tile ops that only implement broadcasting.
  defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;

  // Fold Relu/Relu6/Relu1 consumers into the binary op.
  defm : FusedBinaryActivationFuncOpPat<BinaryOp>;

  // Hoist the binary op above a reshape:
  //   binary(reshape(x, shape), c)  ==>  reshape(binary(x, c), shape)
  // Requirements:
  //  * the second operand `c` is a constant;
  //  * c's shape is a tail of the reshape result's shape AND of x's shape.
  //    Broadcasting then only extends `c` across leading dimensions, so the
  //    new binary result keeps x's shape and the Reshape's shape operand
  //    stays valid. Counter-example without the check on x: x is [1600],
  //    the reshape result is [1,40,40] and c is [40,1]. binary(x, c) would be
  //    [40,1600], which cannot be reshaped to [1,40,40];
  //  * the reshape has no other users;
  //  * every involved value has rank <= 4.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$reshaped $input,
                (ConstantOp:$shape $shape_attr)),
              (ConstantOp:$const_rhs $rhs_attr), $act_fn),
    (TFL_ReshapeOp (BinaryOp $input, $const_rhs, $act_fn), $shape),
    [(IsTailOfShape $const_rhs, $reshaped),
     (HasOneUse $reshaped),
     (IsTailOfShape $const_rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $reshaped),
     (HasRankAtMost<4> $const_rhs)]>;
}
436
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
                    TFL_GreaterEqualOp] in {
  // Hoist the binary op (these ops take no fused activation) above a reshape:
  //   binary(reshape(x, shape), c)  ==>  reshape(binary(x, c), shape)
  // Requirements:
  //  * the second operand `c` is a constant;
  //  * c's shape is a tail of the reshape result's shape AND of x's shape.
  //    Broadcasting then only extends `c` across leading dimensions, so the
  //    new binary result keeps x's shape and the Reshape's shape operand
  //    stays valid. Counter-example without the check on x: x is [1600],
  //    the reshape result is [1,40,40] and c is [40,1]. binary(x, c) would be
  //    [40,1600], which cannot be reshaped to [1,40,40];
  //  * the reshape has no other users;
  //  * every involved value has rank <= 4.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp (TFL_ReshapeOp:$reshaped $input,
                (ConstantOp:$shape $shape_attr)),
              (ConstantOp:$const_rhs $rhs_attr)),
    (TFL_ReshapeOp (BinaryOp $input, $const_rhs), $shape),
    [(IsTailOfShape $const_rhs, $reshaped),
     (HasOneUse $reshaped),
     (IsTailOfShape $const_rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $reshaped),
     (HasRankAtMost<4> $const_rhs)]>;
}
468
// Moves element-wise value ops ahead of data-movement ops:
//   ValueOp(MoveOp(x, arg))  ==>  MoveOp(ValueOp(x), arg)
// applied only when the data-movement op has no other users.
foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
                   TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
                   TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp] in {
  foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
                    TFL_ReshapeOp, TFL_TransposeOp] in {
    def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
      (ValueOp:$elementwise (MoveOp:$moved $x, $move_arg)),
      (MoveOp (ValueOp $x), $move_arg),
      [(HasOneUse $moved)]>;
  }
}
482
// Rewrite-side helper: calls the C++ `GetShape` on the bound value to obtain
// its shape. Only valid for ranked tensors; it fails otherwise.
def GetShape : NativeCodeCall<"GetShape($0)">;

// True when the value's type is a RankedTensorType with at most one dynamic
// dimension.
def HasValidRankedTensor : Constraint<CPred<
  "$0.getType().isa<RankedTensorType>() && "
  "$0.getType().cast<RankedTensorType>().getNumDynamicDims() <= 1">>;
491
// Squeeze ==> Reshape to the squeeze result's shape, when that result is a
// ranked tensor with at most one dynamic dimension.
def ConvertSqueezeToReshape : Pat<
  (TFL_SqueezeOp:$squeezed $input, $squeeze_dims),
  (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeezed))),
  [(HasValidRankedTensor $squeezed)]>;

// ExpandDims ==> Reshape to the expand_dims result's shape, when that result
// satisfies AnyStaticShapeTensor.
def ConvertExpandDimsToReshape : Pat<
  (TFL_ExpandDimsOp:$expanded $input, $dim),
  (TFL_ReshapeOp $input, (ConstantOp (GetShape $expanded))),
  [(AnyStaticShapeTensor $expanded)]>;
502
// Delegates to the C++ helper `FloatValueEquals(attr, val)`. `val` is pasted
// verbatim into the call, so it must be a valid C++ expression.
class FloatValueEquals<string val>
    : Constraint<CPred<"FloatValueEquals($0, " # val # ")">>;
505
// ReLU patterns.

// max(x, 0)  ==>  relu(x)
def MatchReluPattern : Pat<
  (TFL_MaximumOp $input, (ConstantOp $zero)),
  (TFL_ReluOp $input),
  [(FloatValueEquals<"0"> $zero)]>;

// min(max(x, -1), 1)  ==>  relu1(x)
def MatchRelu1Pattern1 : Pat<
  (TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $lower)),
                 (ConstantOp $upper)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $lower), (FloatValueEquals<"1"> $upper)]>;

// max(min(x, 1), -1)  ==>  relu1(x)
def MatchRelu1Pattern2 : Pat<
  (TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $upper)),
                 (ConstantOp $lower)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $lower), (FloatValueEquals<"1"> $upper)]>;
523
// max(x * alpha, x)  ==>  leaky_relu(x, alpha), for a one-element f32
// constant alpha with |alpha| < 1, when the Mul has no other users.
def MatchLeakyRelu : Pat<
  (TFL_MaximumOp
    (TFL_MulOp:$scaled $x, (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
    $x),
  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
  [(ConstDoubleValueLessThan<"1"> $alpha),
   (HasOneUse $scaled)]>;
532
// A Cast whose input already has the output's element type is a no-op.
def RemoveTrivialCast : Pat<
  (TFL_CastOp:$casted $input),
  (replaceWithValue $input),
  [(SameElementType $input, $casted)]>;
536
// Checks that operand0's rank is exactly one less than operand1's rank.
// Both operands must have ranked types; unranked operands do not match
// (rather than asserting inside getRank()).
def PReluAlphaRankCheck : Constraint<
  CPred<"$0.getType().cast<ShapedType>().hasRank() && "
  "$1.getType().cast<ShapedType>().hasRank() && "
  "$0.getType().cast<ShapedType>().getRank() == "
  "$1.getType().cast<ShapedType>().getRank() - 1">>;
541
// PReLU as emitted by Keras:
//   f(x) = Relu(x) + (-alpha * Relu(-x))
// The matched multiplier is -alpha, so the PRelu gets Neg(multiplier). The
// pattern requires the multiplier's rank to be one less than x's, and the
// positive branch, the scaled negative branch and the Neg of x to each have a
// single use.
def MatchPRelu : Pat<
  (TFL_AddOp
    (TFL_ReluOp:$pos_part $x),
    (TFL_MulOp:$neg_part
      (TFL_ReluOp (TFL_NegOp:$negated $x)),
      $neg_alpha,
      TFL_AF_None),
    TFL_AF_None),
  (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)),
  [(PReluAlphaRankCheck $neg_alpha, $x),
   (HasOneUse $pos_part),
   (HasOneUse $neg_part),
   (HasOneUse $negated)]>;
557
// Constant folding in this pass may produce constants in the TF dialect.
// Rewrite any TF_ConstOp carrying an ElementsAttr into a TFL_ConstOp with the
// same value.
def LegalizeConstOp : Pat<
  (TF_ConstOp ElementsAttr:$value),
  (TFL_ConstOp $value)>;
562
// Re-associates two chained constant additions so the constants meet:
//   (x + c1) + c2  ==>  x + (c1 + c2)
// The outer Add's fused activation (Relu, Relu6, Relu1 or none) is kept on
// the outer Add. Applies only when the inner Add has no fused activation and
// no other users, and every involved value has rank <= 4.
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
  def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
    (TFL_AddOp
      (TFL_AddOp:$inner_sum $input, (ConstantOp $c1), TFL_AF_None),
      (ConstantOp $c2), ActFun),
    (TFL_AddOp $input,
      (TFL_AddOp (ConstantOp $c1), (ConstantOp $c2), TFL_AF_None),
      ActFun),
    [(HasOneUse $inner_sum),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $c1),
     (HasRankAtMost<4> $c2)]>;
}
582
// Relu(SquaredDifference(x, y))  ==>  SquaredDifference(x, y)
// The Relu is redundant because a squared difference is never negative.
// This also matters in practice: the TFLite interpreter does not support
// Relu on int32 yet, so without this rewrite such models fail to run.
def OptimizeReluSquaredDifference : Pat<
  (TFL_ReluOp (TFL_SquaredDifferenceOp $lhs, $rhs)),
  (TFL_SquaredDifferenceOp $lhs, $rhs)>;
591
// pow(x, 1)  ==>  x   (exponent: scalar ranked f32 constant equal to 1.0f)
def OptimizePow1ToIdentity : Pat<
  (TFL_PowOp $base,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">)),
  (replaceWithValue $base)>;

// pow(x, 2)  ==>  x * x   (exponent: scalar ranked f32 constant equal to 2.0f)
def OptimizePow2ToSquare : Pat<
  (TFL_PowOp $base,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "2.0f">)),
  (TFL_MulOp $base, $base, TFL_AF_None)>;
603
// True when the C++ helper `TFL::CanOptimizeIdentityGatherNdOrScatterNdOp`
// accepts the bound value together with the indices attribute (cast to
// DenseIntElementsAttr).
def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint<CPred<
  "TFL::CanOptimizeIdentityGatherNdOrScatterNdOp("
  "$0, $1.cast<DenseIntElementsAttr>())">>;

// gather_nd(params, indices)  ==>  params, for constant i32 indices accepted
// by the helper constraint above.
def OptimizeIdentityGatherNdOp : Pat<
  (TFL_GatherNdOp $params, (ConstantOp I32ElementsAttr:$indices)),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;

// scatter_nd(indices, updates, <third operand>)  ==>  updates, for constant
// i32 indices accepted by the helper constraint above.
def OptimizeIdentityScatterNdOp : Pat<
  (TFL_ScatterNdOp (ConstantOp I32ElementsAttr:$indices), $params, $unused),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;
617
// Delegates to the C++ helper `ShapeMatchesReduceWithKeepAxes(input, axes,
// shape)`.
def ShapeMatchesReduceWithKeepAxes : Constraint<CPred<
  "ShapeMatchesReduceWithKeepAxes($0, $1, $2)">>;

// reshape(reduce(x, axes, keep_dims=false), shape)
//   ==>  reduce(x, axes, keep_dims=true)
// i.e. a reshape that re-inserts the reduced dimensions is folded into the
// reduction. The helper constraint presumably checks that `shape` equals the
// keep_dims=true result shape — confirm in its C++ definition.
foreach ReduceOp = [TFL_ReduceMaxOp, TFL_ReduceMinOp, TFL_ReduceProdOp,
                    TFL_SumOp] in {
  def FoldReshapeTo#ReduceOp : Pat<
    (TFL_ReshapeOp
      (ReduceOp:$reduced $input, (ConstantOp I32ElementsAttr:$axes),
                         ConstBoolAttrFalse),
      (ConstantOp I32ElementsAttr:$shape)),
    (ReduceOp $input, (ConstantOp $axes), ConstBoolAttrTrue),
    [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape),
     (HasOneUse $reduced)]>;
}
634
635
// True when the two bound entities compare equal with C++ `==`.
def IsSame : Constraint<CPred<"$0 == $1">>;

// True when the bound value has exactly two uses.
def HasTwoUse : Constraint<
  CPred<"std::distance($0.use_begin(), $0.use_end()) == 2">>;
// True when the axes attribute ($0) holds exactly one axis and that axis
// addresses the last dimension of $1: either -1, or rank($1) - 1 when $1 has
// a ranked type. The -1 case is checked first and the rank is only queried
// for ranked types, so an unranked $1 no longer trips the getRank()
// assertion. The axis is sign-extended before comparing, so the comparison
// is on signed values.
def AxesIsLastDimension : Constraint<CPred<
  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
  "((*$0.cast<DenseIntElementsAttr>().begin()).getSExtValue() == -1 || "
  "($1.getType().cast<ShapedType>().hasRank() && "
  "(*$0.cast<DenseIntElementsAttr>().begin()).getSExtValue() == "
  "$1.getType().cast<ShapedType>().getRank() - 1))">>;
643
// exp(x) / sum(exp(x), last_axis, keep_dims=true)  ==>  softmax(x, beta=1.0)
// The Sum must reduce the same exp(x) value over the last dimension, exp(x)
// must have exactly two uses (the Div and the Sum), and the Sum only one.
def OptimizeToSoftmax : Pat<
  (TFL_DivOp
    (TFL_ExpOp:$exp_x $input),
    (TFL_SumOp:$denominator $summed, (ConstantOp I32ElementsAttr:$axes),
                            ConstBoolAttrTrue),
    TFL_AF_None),
  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
  [(IsSame $exp_x, $summed),
   (AxesIsLastDimension $axes, $summed),
   (HasTwoUse $exp_x),
   (HasOneUse $denominator)]>;
654
// softmax(x - max(x, last_axis, keep_dims=true), beta)  ==>  softmax(x, beta)
// The max-subtraction is redundant because the softmax op already performs
// this normalization itself.
def FoldNormalizationIntoSoftmax : Pat<
  (TFL_SoftmaxOp
    (TFL_SubOp:$shifted $input,
      (TFL_ReduceMaxOp:$row_max $reduced, (ConstantOp I32ElementsAttr:$axes),
                                ConstBoolAttrTrue),
      TFL_AF_None),
    $beta),
  (TFL_SoftmaxOp $input, $beta),
  [(IsSame $input, $reduced),
   (AxesIsLastDimension $axes, $reduced),
   (HasOneUse $shifted),
   (HasOneUse $row_max)]>;
669
// True when both bound entities have exactly the same type.
def HaveSameType : Constraint<CPred<"($0.getType() == $1.getType())">>;

// True when the bound attribute is an f32 DenseElementsAttr in which every
// element compares equal (==) to `expected`, a C++ expression pasted verbatim.
class AllElementsAreF32<string expected> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
  "std::all_of($0.cast<DenseElementsAttr>().getValues<float>().begin(), "
  "$0.cast<DenseElementsAttr>().getValues<float>().end(), "
  "[](float v){ return v == " # expected # ";}))">>;
678
// x * 1  ==>  x, when the constant is an all-ones f32 tensor of exactly the
// same type as x (so no broadcasting is involved) and there is no fused
// activation.
def OptimizeMul1ToIdentity : Pat<
  (TFL_MulOp $input, (ConstantOp $ones), TFL_AF_None),
  (replaceWithValue $input),
  [(HaveSameType $input, $ones),
   (AllElementsAreF32<"1.0f"> $ones)]>;
687
// True when the bound attribute is an i1 DenseElementsAttr in which every
// element compares equal (==) to `expected`, a C++ expression pasted verbatim.
class AllElementsAreBool<string expected> : Constraint<CPred<
  "($0.isa<DenseElementsAttr>() && "
  "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(1) && "
  "std::all_of($0.cast<DenseElementsAttr>().getValues<bool>().begin(), "
  "$0.cast<DenseElementsAttr>().getValues<bool>().end(), "
  "[](bool v){ return v == " # expected # ";}))">>;
694
// Removes Select / SelectV2 whose condition is a constant that is uniformly
// true or uniformly false. Both branches must have the same type, and the
// condition's shape must be a tail of the first branch's shape and vice
// versa.
foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
  // select(all-true, A, B)  ==>  A
  def Optimize#SelectOp#True : Pat<
    (SelectOp (ConstantOp $cond), $on_true, $on_false),
    (replaceWithValue $on_true),
    [(HaveSameType $on_true, $on_false),
     (IsTailOfShape $on_true, $cond),
     (IsTailOfShape $cond, $on_true),
     (AllElementsAreBool<"true"> $cond)]>;

  // select(all-false, A, B)  ==>  B
  def Optimize#SelectOp#False : Pat<
    (SelectOp (ConstantOp $cond), $on_true, $on_false),
    (replaceWithValue $on_false),
    [(HaveSameType $on_true, $on_false),
     (IsTailOfShape $on_true, $cond),
     (IsTailOfShape $cond, $on_true),
     (AllElementsAreBool<"false"> $cond)]>;
}
718
// True when the bound float attribute is strictly greater than zero.
def FloatAttrIsStrictlyPositive : Constraint<CPred<
  "$0.cast<FloatAttr>().getValueAsDouble() > 0.0">>;

// Remove (log-)softmax over the last axis before arg-min/arg-max, because
// (log-)softmax is order-preserving in its logits.
// Softmax is only strictly increasing in the logits when beta > 0. With
// beta == 0 (still accepted by TFL_FloatNonNegative) every output is equal,
// and the arg-min/arg-max of the softmax result no longer reflects the
// logits. A strictly positive beta is therefore required.
// NOTE(review): in floating point, softmax can still map distinct logits to
// equal outputs (e.g. underflow to 0), so ties may resolve differently after
// this rewrite.
foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
  def RemoveSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_SoftmaxOp:$softmax $logits, TFL_FloatNonNegative:$beta),
                 (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $softmax),
     (FloatAttrIsStrictlyPositive $beta),
     (AxesIsLastDimension $axes, $logits)]>;

  def RemoveLogSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp (TFL_LogSoftmaxOp:$log_softmax $logits),
                 (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $log_softmax),
     (AxesIsLastDimension $axes, $logits)]>;
}
735