1include "mlir/Dialect/Shape/IR/ShapeOps.td"
2include "mlir/Dialect/StandardOps/IR/Ops.td"
3
// Constraint on a bound range of values: holds when every element of the
// range compares equal to the range's first element. For an empty range the
// lambda is never invoked, so the first element is never accessed and the
// predicate holds.
def AllInputShapesEq : Constraint<CPred<[{
  llvm::all_of($0, [&](mlir::Value candidate) {
    return candidate == $0[0];
  })
}]>>;
9
// Constraint on a bound range: holds when the range contains exactly one
// element.
def HasSingleElement : Constraint<CPred<[{
  1 == $0.size()
}]>>;
13
14// Canonicalization patterns.
15
// shape.assuming_all over a single operand is replaced directly by that
// operand.
def AssumingAllOneOp : Pat<
  (Shape_AssumingAllOp $witnesses),
  (replaceWithValue $witnesses),
  [(HasSingleElement $witnesses)]>;
19
// shape.cstr_broadcastable whose two operands are the same value is replaced
// by a constant `true` witness.
def CstrBroadcastableEqOps : Pat<
  (Shape_CstrBroadcastableOp:$op $sameShape, $sameShape),
  (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
22
// shape.cstr_eq whose operands are all the same value is replaced by a
// constant `true` witness.
def CstrEqEqOps : Pat<
  (Shape_CstrEqOp:$op $inputs),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $inputs)]>;
26
// Fold size_to_index(index_to_size(%idx)) back to %idx.
def IndexToSizeToIndexCanonicalization : Pat<
  (Shape_SizeToIndexOp
    (Shape_IndexToSizeOp $idx)),
  (replaceWithValue $idx)>;
30
// Fold index_to_size(size_to_index(%arg)) back to %arg.
// The replacement is only valid when %arg already has the type of the
// index_to_size result. NOTE(review): if size_to_index also accepts operands
// that are not !shape.size (e.g. index -- confirm in ShapeOps.td), an
// unconstrained fold would replace a size-typed value with a differently
// typed one. The constraint below skips the fold in that case.
def SizeToIndexToSizeCanonicalization : Pat<
  (Shape_IndexToSizeOp:$res (Shape_SizeToIndexOp $arg)),
  (replaceWithValue $arg),
  [(Constraint<CPred<"$0.getType() == $1.getType()">,
               "values have the same type"> $arg, $res)]>;
34
// Fold tensor_cast(const_shape) to the const_shape value itself.
// Replacing the cast result with its source is only type-correct when the
// cast does not change the type. A cast to a different extent-tensor type
// (e.g. dynamic <-> static extent) must be kept, or its users would receive
// a value of the wrong type. `$shape` binds the constant's shape attribute.
def TensorCastConstShape : Pat<
  (TensorCastOp:$res (Shape_ConstShapeOp:$c $shape)),
  (replaceWithValue $c),
  [(Constraint<CPred<"$0.getType() == $1.getType()">,
               "values have the same type"> $c, $res)]>;
37