1//===- Shape.td - Shape operations definition --------------*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the operation definition file for Shape dialect operations.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef SHAPE_OPS
14#define SHAPE_OPS
15
16include "mlir/Dialect/Shape/IR/ShapeBase.td"
17include "mlir/Interfaces/ControlFlowInterfaces.td"
18include "mlir/Interfaces/InferTypeOpInterface.td"
19include "mlir/Interfaces/SideEffectInterfaces.td"
20include "mlir/IR/OpAsmInterface.td"
21include "mlir/IR/SymbolInterfaces.td"
22
23//===----------------------------------------------------------------------===//
24// Shape op definitions
25//===----------------------------------------------------------------------===//
26
// Common base class for every operation in this dialect. It binds the op to
// `ShapeDialect` and forwards the mnemonic and trait list unchanged.
class Shape_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<ShapeDialect, mnemonic, traits>;
30
// `shape.add`: sum of two values, each of type `size` or `index`.
def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
  let summary = "Addition of sizes and indices";
  let description = [{
    Adds two sizes or indices. If either operand is an error it will be
    propagated to the result. The operands can be of type `size` or `index`. If
    at least one of the operands can hold an error, i.e. if it is of type `size`,
    then also the result must be of type `size`.
  }];

  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
  let results = (outs Shape_SizeOrIndexType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  // Reference the helper with a leading `::`, the same way `shape.mul` and the
  // other size/index ops in this file do.
  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}
49
// `shape.broadcast`: computes the broadcasted shape of two shapes or extent
// tensors. The optional `error` attribute is printed as part of `attr-dict`.
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
  let summary = "Returns the broadcasted output shape of two inputs";
  let description = [{
    Returns the broadcasted shape for two input shapes or extent tensors. Both
    operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
    type `shape.shape` and, if both operands are tensors, may be of type
    `tensor<?xindex>`.

    If the two operand shapes are of different rank the smaller one is padded
    with 1's from the left. The resulting broadcasted shape is then defined as

        result[i] = lhs[i] if lhs[i] == rhs[i]
                  = lhs[i] if rhs[i] == 1
                  = rhs[i] if lhs[i] == 1.

    In case the resulting shape is undefined, i.e. if corresponding extents are
    different from each other but none is 1, the result is an error shape.
    Likewise error values are propagated if any of the operands holds an error
    value. If the result type is an extent tensor (and can therefore not hold
    the error value) the behavior may be undefined. The optional string
    attribute can be used to describe the error case.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs,
                       OptionalAttr<StrAttr>:$error);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  // The verifier is set exactly once; a second identical `let verifier`
  // assignment used to follow `hasFolder` and was redundant.
  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
  let hasFolder = 1;
}
87
// `shape.const_shape`: constant shape or extent tensor whose extents are
// stored in the `shape` attribute.
def Shape_ConstShapeOp
    : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
  let summary = "Creates a constant shape or extent tensor";
  let description = [{
    Creates a constant shape or extent tensor. The individual extents are given
    as the `shape` attribute. The number of these values equals the shape's
    rank.

    ```mlir
    %0 = shape.const_shape [] : !shape.shape
    %1 = shape.const_shape [1, 2, 3] : !shape.shape
    %2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
    ```
  }];

  // The only argument is an attribute; the op takes no SSA operands.
  let results = (outs Shape_ShapeOrExtentTensorType:$result);
  let arguments = (ins IndexElementsAttr:$shape);

  let hasCanonicalizer = 1;
  let hasFolder = 1;

  // Hand-written assembly hooks.
  // TODO: Move this to main so that all shape ops implement these.
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let printer = [{ return ::print(p, *this); }];
}
110
// `shape.const_size`: constant `shape.size` taken from the `value` attribute.
def Shape_ConstSizeOp
    : Shape_Op<"const_size", [ConstantLike, NoSideEffect,
                              DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
  let summary = "Creates a constant of type `shape.size`";
  let description = [{
    Creates a `shape.size` type representing the constant size given by `value`.

    ```mlir
    %x = shape.const_size 10
    ```
  }];

  let results = (outs Shape_SizeType:$result);
  let arguments = (ins IndexAttr:$value);

  let hasFolder = 1;
  let assemblyFormat = "$value attr-dict";

  // Extra builder that accepts the constant as a plain `int64_t`.
  let builders = [OpBuilderDAG<(ins "int64_t":$value)>];
}
133
// `shape.shape_eq`: equality predicate (`i1`) over two shapes or extent
// tensors.
def Shape_ShapeEqOp
    : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
  let summary = "Returns whether the input shapes or extent tensors are equal";
  let description = [{
    Takes two shape or extent tensor operands and determines whether they are
    equal. When extent tensors are compared to shapes they are regarded as their
    equivalent non-error shapes. Error shapes can be tested for equality like
    any other shape value, meaning that the error value is equal to itself.
  }];

  let results = (outs I1:$result);
  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs);

  let hasFolder = 1;
  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
}
150
// `shape.from_extents`: builds a shape from a variadic list of `index`
// extents.
def Shape_FromExtentsOp
    : Shape_Op<"from_extents", [NoSideEffect]> {
  let summary = "Creates a shape from extents";
  let description = [{
    Creates a shape from multiple SSA values representing the extents of
    the shape.

    ```mlir
    // Rank 2 shape.
    %s0 = shape.from_extents %a, %b
    // Rank 0 shape.
    %s1 = shape.from_extents
    ```
  }];

  let results = (outs Shape_ShapeType:$shape);
  let arguments = (ins Variadic<Index>:$extents);

  let hasFolder = 1;

  let assemblyFormat = "$extents attr-dict";
}
171
// `shape.from_extent_tensor`: converts an index tensor of extents into a
// `shape.shape`.
def Shape_FromExtentTensorOp
    : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
  let summary = "Creates a shape from a tensor of extents";
  let description = [{
    Creates a shape from a 1D integral tensor of extents. The rank of the
    resulting shape equals the number of elements in the tensor, and the
    extents match the values of the elements.
  }];

  let results = (outs Shape_ShapeType:$result);
  let arguments = (ins IndexTensor:$input);

  let assemblyFormat = "$input attr-dict `:` type($input)";
}
185
// `shape.is_broadcastable`: `i1` predicate telling whether two shapes or
// extent tensors can be broadcast together.
def Shape_IsBroadcastableOp
    : Shape_Op<"is_broadcastable", [Commutative]> {
  let summary = "Determines if 2 shapes can be successfully broadcasted";
  let description = [{
    Given two input shapes or extent tensors, return a predicate specifying if
    they are broadcastable. This broadcastable follows the same logic as what
    shape.broadcast documents.

    Concretely, shape.is_broadcastable returning true implies that
    shape.broadcast will not give an error, and shape.cstr_broadcastable will
    not result in an assertion failure. Similarly, false implies an error or
    assertion failure.

    Example:
    ```mlir
    %true = shape.is_broadcastable [2,2], [3,1,2]
    %false = shape.is_broadcastable [2,2], [3,2]
    ```
  }];

  let results = (outs I1:$result);
  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs);

  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
}
211
// `shape.rank`: number of extents of a shape or extent tensor, returned as
// `size` or `index`.
def Shape_RankOp
    : Shape_Op<"rank", [NoSideEffect]> {
  let summary = "Gets the rank of a shape";
  let description = [{
    Returns the rank of the shape or extent tensor, i.e. the number of extents.
  }];

  let results = (outs Shape_SizeOrIndexType:$rank);
  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);

  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
  let hasCanonicalizer = 1;
  let hasFolder = 1;

  let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($rank)";
}
227
// `shape.to_extent_tensor`: converts a shape (or extent tensor) into an index
// tensor of extents.
def Shape_ToExtentTensorOp
    : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
  let summary = "Creates a dimension tensor from a shape";
  let description = [{
    Converts a shape to a 1D integral tensor of extents. The number of elements
    in the tensor equals the rank of the shape, and the elements equal the
    extents of the shape.

    If the shape represents an error, this op's behavior is undefined.
  }];

  let results = (outs IndexTensor:$result);
  let arguments = (ins Shape_ShapeOrExtentTensorType:$input);

  let hasFolder = 1;

  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
}
245
// `shape.get_extent`: reads the extent at position `dim` of a shape or extent
// tensor.
def Shape_GetExtentOp
    : Shape_Op<"get_extent", [NoSideEffect]> {
  let summary = "Gets the specified extent from a shape or extent tensor";
  let description = [{
    Gets the extent indexed by `dim` from the `shape` operand. If the shape is
    an error then it returns an error size.
  }];

  let results = (outs Shape_SizeOrIndexType:$extent);
  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
                       Shape_SizeOrIndexType:$dim);

  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
  let hasFolder = 1;

  let assemblyFormat = "$shape `,` $dim attr-dict `:` type($shape) `,` type($dim) `->` "
                       "type($extent)";

  let extraClassDeclaration = [{
    /// Get the `dim` value as integer if it is constant.
    Optional<int64_t> getConstantDim();
  }];

  let builders = [
    // Builder that allows passing a constant dimension as a simple integer.
    OpBuilderDAG<(ins "Value":$shape, "int64_t":$dim)>
  ];
}
271
// `shape.index_to_size`: converts a standard `index` into a `shape.size`.
def Shape_IndexToSizeOp
    : Shape_Op<"index_to_size", [NoSideEffect]> {
  let summary = "Converts a standard index to a shape size";
  let description = [{
    Converts a standard index to a `shape.size`. This operation and its
    inverse, `size_to_index`, facilitate index conversion between the standard
    and the shape dialect.

    The behavior is undefined for negative indices.
  }];

  let results = (outs Shape_SizeType:$result);
  let arguments = (ins Index:$arg);

  let hasCanonicalizer = 1;
  let hasFolder = 1;

  let assemblyFormat = "$arg attr-dict";
}
290
// `shape.join`: combines two shapes/sizes into the least general one, with an
// optional `error` string attribute. No assembly format is declared here.
def Shape_JoinOp
    : Shape_Op<"join", [Commutative]> {
  let summary = "Returns the least general shape.size of its operands";
  let description = [{
    An operation that computes the least general shape of input operands.
    This effectively asserts that corresponding static dimensions are equal.
    The behavior is to match each element of the `shape.shape` and propagate the
    most restrictive information, returning an invalid shape if there are
    contradictory requirements. E.g., using pseudo code

    ```
    shape.join([*], [*]) -> [*]
    shape.join([*], [1, ?]) -> [1, ?]
    shape.join([1, 2], [1, ?]) -> [1, 2]
    shape.join([*], [1, 2]) -> [1, 2]
    shape.join([], []) -> []
    shape.join([], [*]) -> []
    shape.join([], [?, ?]) -> [invalid]
    shape.join([1, ?], [2, ?, ?]) -> [invalid]
    ```

    `shape.join` also allows specifying an optional error string, that may be
    used to return an error to the user upon mismatch of dimensions.

    ```mlir
    %c = shape.join %a, %b, error="<reason>" : !shape.shape
    ```
  }];

  let results = (outs Shape_ShapeOrSizeType:$result);
  let arguments = (ins Shape_ShapeOrSizeType:$arg0,
                       Shape_ShapeOrSizeType:$arg1,
                       OptionalAttr<StrAttr>:$error);
}
323
// `shape.mul`: product of two values, each of type `size` or `index`.
def Shape_MulOp
    : Shape_Op<"mul", [Commutative, NoSideEffect]> {
  let summary = "Multiplication of sizes and indices";
  let description = [{
    Multiplies two sizes or indices. If either operand is an error it will be
    propagated to the result. The operands can be of type `size` or `index`. If
    at least one of the operands can hold an error, i.e. if it is of type `size`,
    then also the result must be of type `size`. If error propagation is not
    possible because both operands are of type `index` then the result must also
    be of type `index`.
  }];

  let results = (outs Shape_SizeOrIndexType:$result);
  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);

  let hasFolder = 1;
  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];
}
345
// `shape.num_elements`: product of all extents of a shape or extent tensor,
// returned as `size` or `index`.
def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
  let summary = "Returns the number of elements for a given shape";
  let description = [{
    Returns the number of elements for a given shape which is the product of its
    extents. If the argument is of type `shape` then the result will be of type
    `size` and potential errors will be propagated. Otherwise, if the argument
    is an extent tensor `tensor<?xindex>` then the result will be of type
    `index`.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
  let results = (outs Shape_SizeOrIndexType:$result);

  // Extra builder that takes only the shape operand (no explicit result type).
  let builders = [OpBuilderDAG<(ins "Value":$shape)>];

  let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";

  let hasFolder = 1;
  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}
366
// `shape.reduce`: folds a single-block region over the extents of a shape or
// extent tensor, starting from `initVals`. The region is implicitly
// terminated by `shape.yield`.
def Shape_ReduceOp : Shape_Op<"reduce",
    [SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Returns an expression reduced over a shape or extent tensor";
  let description = [{
    An operation that takes as input a shape or extent tensor, and a number of
    initial values. This operation has a region/function that is applied
    repeatedly for every extent of the input. Starting with the initial values,
    the individual extents are then aggregated as defined by the associated
    region.

    Conceptually this op performs the following reduction:

    ```
    res[] = init;
    for (int i = 0; i < shape.rank(); i++) {
      res = fn(i, shape[i], res[0], ..., res[n]);
    }
    ```

    Where `fn` is provided by the user and the result of the reduce op is the
    last computed output of the reduce function. As an example, computing the
    number of elements can be defined as follows:

    ```mlir
    func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
      %num_elements = shape.reduce(%shape, %init) -> !shape.size  {
        ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
          %updated_acc = "shape.mul"(%acc, %dim) :
            (!shape.size, !shape.size) -> !shape.size
          shape.yield %updated_acc : !shape.size
      }
      return %num_elements : !shape.size
    }
    ```
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
                       Variadic<AnyType>:$initVals);
  let results = (outs Variadic<AnyType>:$result);
  let regions = (region SizedRegion<1>:$region);

  let builders = [OpBuilderDAG<(ins "Value":$shape, "ValueRange":$initVals)>];

  // Verification and assembly are implemented by hand in C++.
  let verifier = [{ return ::verify(*this); }];
  let printer = [{ return ::print(p, *this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
}
414
// `shape.shape_of`: yields the shape (or extent tensor) of a shaped value or
// of a `value_shape`.
def Shape_ShapeOfOp
    : Shape_Op<"shape_of", [NoSideEffect]> {
  let summary = "Returns shape of a value or shaped type operand";

  let description = [{
    The operation takes a value or a shaped operand as an argument and it
    returns a shape or extent tensor.
  }];

  let results = (outs Shape_ShapeOrExtentTensorType:$result);
  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);

  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];

  let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";

  // Extra builder that takes only the operand.
  let builders = [OpBuilderDAG<(ins "Value":$arg)>];
}
434
// `shape.size_to_index`: converts a `size` (or `index`) operand into a
// standard `index`.
def Shape_SizeToIndexOp
    : Shape_Op<"size_to_index", [NoSideEffect]> {
  let summary = "Casts between index types of the shape and standard dialect";
  let description = [{
    Converts a `shape.size` to a standard index. This operation and its
    inverse, `index_to_size`, facilitate index conversion between the standard
    and the shape dialect. The behavior is undefined for unknown and invalid
    arguments.
  }];

  let results = (outs Index:$result);
  let arguments = (ins Shape_SizeOrIndexType:$arg);

  let hasCanonicalizer = 1;
  let hasFolder = 1;

  let assemblyFormat = "$arg attr-dict `:` type($arg)";
}
452
// `shape.with_shape`: pairs the value of `operand` with the given `shape`,
// producing a new `value_shape`.
def Shape_WithOp
    : Shape_Op<"with_shape", [NoSideEffect]> {
  let summary = "Returns ValueShape with given shape";
  let description = [{
    Returns ValueShape with the shape updated to match the shape operand. That
    is a new ValueShape tuple is created with value equal to `operand`'s
    value and shape equal to `shape`. If the ValueShape and given `shape` are
    non-conformant, then the returned ValueShape will represent an error of
    this mismatch. Similarly if either inputs are in an error state, then an
    error is propagated.

    Usage:
      %0 = shape.with_shape %1, %2 : tensor<...>, !shape.shape

    This is used, for example, where one combines shape function calculations
    and/or call one shape function from another. E.g.,

    ```mlir
    func @shape_foobah(%a: !shape.value_shape,
                       %b: !shape.value_shape,
                       %c: !shape.value_shape) -> !shape.shape {
      %0 = call @shape_foo(%a, %b) :
        (!shape.value_shape, !shape.value_shape) -> !shape.shape
      %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
      %2 = call @shape_bah(%c, %1) :
        (!shape.value_shape, !shape.value_shape) -> !shape.shape
      return %2 : !shape.shape
    }
    ```

    This op need not be a refinement of the shape. In non-error cases the input
    ValueShape's value and shape are conformant and so too for the output, but
    the result may be less specified than `operand`'s shape as `shape` is
    merely used to construct the new ValueShape. If join behavior is desired
    then a join op should be used.
  }];

  let results = (outs Shape_ValueShapeType:$result);
  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
                       Shape_ShapeType:$shape);

  let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
}
495
// `shape.yield`: terminator that returns its operands to the parent op.
def Shape_YieldOp
    : Shape_Op<"yield", [HasParent<"ReduceOp, FunctionLibraryOp">,
                         NoSideEffect, ReturnLike, Terminator]> {
  let summary = "Returns the value to parent op";

  let arguments = (ins Variadic<AnyType>:$operands);

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
  let verifier = [{ return ::verify(*this); }];

  // Zero-argument builder: delegates to the operand-taking `build` with
  // `llvm::None` as the operand list.
  let builders = [
    OpBuilderDAG<(ins), [{ build($_builder, $_state, llvm::None); }]>
  ];
}
512
513// TODO: Add Ops: if_static, if_ranked
514
// Testing/debugging aid: `shape.debug_print` prints its input and passes it
// through as the result.
def Shape_DebugPrintOp
    : Shape_Op<"debug_print", []> {
  let summary = "Prints the input shape or size";
  let description = [{
    Prints the input dim or shape and passes through input.

    Note: This is intended for testing and debugging only.
  }];

  let results = (outs Shape_ShapeOrSizeType:$output);
  let arguments = (ins Shape_ShapeOrSizeType:$input);
}
527
// `shape.split_at`: splits a shape at `index` into a `head` and a `tail`
// shape.
def Shape_SplitAtOp
    : Shape_Op<"split_at", []> {
  let summary = "Splits a shape at a given index";
  let description = [{
    Splits a shape at a given dimension `index`, returning two shapes.
    If `index` is negative, it is treated as indexing from the back of the
    shape. This negative-handling behavior is important when handling unranked
    shapes, where the positive index is not necessarily knowable due to a
    dynamic number of leading dimensions.

    Examples:
    - split_at([4,5,6], index=0) -> [], [4,5,6]
    - split_at([4,5,6], index=1) -> [4], [5,6]
    - split_at([4,5,6], index=2) -> [4,5], [6]
    - split_at([4,5,6], index=3) -> [4,5,6], []
    - split_at([4,5,6], index=4) -> error
    - split_at([4,5,6], index=-1) -> [4,5], [6]
    - split_at([4,5,6], index=-2) -> [4], [5,6]
    - split_at([4,5,6], index=-3) -> [], [4,5,6]
    - split_at([4,5,6], index=-4) -> error

    Requires:
    - `index` is in the range [-rank(operand),rank(operand)]
  }];

  let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
  let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, I32:$index);

  let hasFolder = 1;
}
556
// `shape.concat`: appends the dimensions of `rhs` to those of `lhs`.
def Shape_ConcatOp
    : Shape_Op<"concat", []> {
  let summary = "Concatenates two shapes";
  let description = [{
    Creates a shape whose dimensions consist of first the dimensions from `lhs`
    followed by the dimensions of `rhs`.

    Example:
    concat([2,3], [4,5]) -> [2,3,4,5]
    concat([], []) -> []
    concat([], [4,5,6]) -> [4,5,6]
  }];

  let results = (outs Shape_ShapeType:$result);
  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);

  let hasFolder = 1;
  let assemblyFormat = "$lhs `,` $rhs attr-dict";
}
575
576//===----------------------------------------------------------------------===//
577// Shape constraint related ops.
578//===----------------------------------------------------------------------===//
579
580// TODO: Move the code below and witnesses to a different file.
// `shape.any`: returns some combination of the dimensions of its variadic
// shape / extent-tensor inputs.
def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
  let summary = "Return any combination of the input shapes";
  let description = [{
    This operation takes multiple input shapes or extent tensors and returns
    some combination of their dimensions. This can be best seen with examples
    below.

    The result is undefined, but still side-effect free, in cases where the
    inputs have differing ranks or differ in extents of shared dimensions.

    Example:
    ```mlir
    %s0 = shape.any [2,?], [?,3] // [2,3]
    %s1 = shape.any [?,?], [1,2] // [1,2]
    ```
  }];

  let results = (outs Shape_ShapeOrExtentTensorType:$result);
  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);

  let hasFolder = 1;

  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `->` type($result)";
}
606
// `shape.assuming_all`: combines a variadic list of witnesses into a single
// witness.
def Shape_AssumingAllOp
    : Shape_Op<"assuming_all", [Commutative, NoSideEffect]> {
  let summary = "Return a logical AND of all witnesses";
  let description = [{
    Used to simplify constraints as any single failing precondition is enough
    to prevent execution.

    "assuming" operations represent an execution order restriction to the
    compiler, information for dependent code to rely on (by assuming), and
    nothing else. They should not exist after a program is fully lowered and
    ready to execute.

    Example:
    ```mlir
    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
    %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
    %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
    %wf = shape.assuming_all %w0, %w1 // Failure
    %wt = shape.assuming_all %w0, %w2 // Passing
    ```
  }];

  let results = (outs Shape_WitnessType:$result);
  let arguments = (ins Variadic<Shape_WitnessType>:$inputs);

  let verifier = [{ return ::verify(*this); }];

  let hasCanonicalizer = 1;
  let hasFolder = 1;

  let assemblyFormat = "$inputs attr-dict";
}
638
// `shape.assuming`: region op guarded by a witness. Its single block is
// implicitly terminated by `shape.assuming_yield`.
def Shape_AssumingOp
    : Shape_Op<"assuming",
               [SingleBlockImplicitTerminator<"AssumingYieldOp">,
                DeclareOpInterfaceMethods<RegionBranchOpInterface>,
                RecursiveSideEffects]> {
  let summary = "Execute the region";
  let description = [{
    Executes the region assuming all witnesses are true.

    "assuming" operations represent an execution order restriction to the
    compiler, information for dependent code to rely on (by assuming), and
    nothing else. They should not exist after a program is fully lowered and
    ready to execute.
  }];

  let arguments = (ins Shape_WitnessType:$witness);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$doRegion);

  let hasCanonicalizer = 1;

  // Type checking is delegated to the region-branch interface; assembly is
  // implemented by hand in C++.
  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let printer = [{ return ::print(p, *this); }];

  let extraClassDeclaration = [{
    // Inline the region into the region containing the AssumingOp and delete
    // the AssumingOp.
    //
    // This does no checks on the inputs to the AssumingOp.
    static void inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter);
  }];
}
670
// `shape.assuming_yield`: terminator used inside the region of
// `shape.assuming`; it forwards its operands as the results of that op.
def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
                                     [NoSideEffect, ReturnLike, Terminator]> {
  let summary = "Yield operation";
  let description = [{
    This yield operation represents a return operation within the
    `shape.assuming` region. The operation takes variable number of operands
    and produces no results. The operand number and types must match the
    return signature of the region that contains the operation.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);

  // Zero-argument builder: the op is created with no operands.
  let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>];

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
687
// `shape.cstr_broadcastable`: witness-producing counterpart of the
// broadcastability check.
def Shape_CstrBroadcastableOp
    : Shape_Op<"cstr_broadcastable", [Commutative]> {
  let summary = "Determines if 2 shapes can be successfully broadcasted";
  let description = [{
    Given two input shapes or extent tensors, return a witness specifying if
    they are broadcastable. This broadcastable follows the same logic as what
    shape.broadcast documents.

    "cstr" operations represent runtime assertions.

    Example:
    ```mlir
    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
    %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
    ```
  }];

  let results = (outs Shape_WitnessType:$result);
  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs);

  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
}
713
// `shape.cstr_eq`: witness that all of the variadic input shapes are equal.
def Shape_CstrEqOp
    : Shape_Op<"cstr_eq", [Commutative]> {
  let summary = "Determines if all input shapes are equal";
  let description = [{
    Given 1 or more input shapes, determine if all shapes are the exact same.

    "cstr" operations represent runtime assertions.

    Example:
    ```mlir
    %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
    %w1 = shape.cstr_eq [2,2], [1,2] // Failure
    ```
  }];

  let results = (outs Shape_WitnessType:$result);
  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);

  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let assemblyFormat = "$inputs attr-dict";
}
735
// `shape.const_witness`: witness whose outcome is statically known and stored
// in the boolean `passing` attribute.
def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
  let summary = "An operation that returns a statically known witness value";
  let description = [{
    This operation represents a statically known witness result. This can be
    often used to canonicalize/fold constraint and assuming code that will always
    pass.

    ```mlir
    %0 = shape.const_shape [1,2,3]
    %1 = shape.const_shape [1, 2, 3]
    %w0 = shape.cstr_eq(%0, %1) // Can be folded to "const_witness true"
    %w1 = shape.const_witness true
    %w2 = shape.assuming_all(%w0, %w1) // Can be folded to "const_witness true"
    ```
  }];
  let arguments = (ins BoolAttr:$passing);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$passing attr-dict";

  let hasFolder = 1;
}
758
// `shape.cstr_require`: turns an `i1` predicate plus a message string into a
// witness.
def Shape_CstrRequireOp
    : Shape_Op<"cstr_require", []> {
  let summary = "Represents a runtime assertion that an i1 is `true`";
  let description = [{
    Represents a runtime assertion that an i1 is true. It returns a
    !shape.witness to order this assertion.

    For simplicity, prefer using other cstr_* ops if they are available for a
    given constraint.

    Example:
    ```mlir
    %bool = ...
    %w0 = shape.cstr_require %bool, "msg" // Passing if `%bool` is true.
    ```

    Since this op can be used to express many different possible assertions
    (depending on whatever computation calculated `pred`), the `msg`
    should clarify the nature of the assertion for users.
  }];

  let results = (outs Shape_WitnessType:$result);
  let arguments = (ins I1:$pred, StrAttr:$msg);

  let hasFolder = 1;

  let assemblyFormat = "$pred `,` $msg attr-dict";
}
785
786//===----------------------------------------------------------------------===//
787// Shape collection ops.
788//===----------------------------------------------------------------------===//
789
// `shape.function_library`: symbol-table op holding shape functions together
// with a `mapping` dictionary from op names to those functions. Its body is
// implicitly terminated by `ShapeFunctionLibraryTerminatorOp`.
def Shape_FunctionLibraryOp : Shape_Op<"function_library",
    [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
     SingleBlockImplicitTerminator<"ShapeFunctionLibraryTerminatorOp">]> {
  let summary = "Represents shape functions and corresponding ops";
  let description = [{
    Represents a list of shape functions and the ops whose shape transfer
    functions they represent.

    Example:

    ```mlir
    shape.function_library {
      func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
        %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
        return %0 : !shape.shape
      }
    } mapping {
      std.atan = @same_result_shape
    }
    ```
  }];

  // All attributes must live in ONE `arguments` list: a second
  // `let arguments = ...` replaces the first instead of appending to it,
  // which previously dropped `sym_name` and `sym_visibility`.
  let arguments = (ins SymbolNameAttr:$sym_name,
                       OptionalAttr<StrAttr>:$sym_visibility,
                       DictionaryAttr:$mapping);
  let regions = (region AnyRegion:$body);

  let extraClassDeclaration = [{
    /// Returns an associated shape function for an operation if defined.
    FuncOp getShapeFunction(Operation *op);
  }];

  // Only the hand-written name-taking builder is exposed.
  let builders = [OpBuilderDAG<(ins "StringRef":$name)>];
  let skipDefaultBuilders = 1;

  let printer = [{ ::print(p, *this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
}
828
829//===----------------------------------------------------------------------===//
830// ShapeFunctionLibraryTerminatorOp
831//===----------------------------------------------------------------------===//
832
// `shape.fn_lib_terminator`: terminator that is only valid directly inside a
// `FunctionLibraryOp`.
def ShapeFunctionLibraryTerminatorOp : Shape_Op<"fn_lib_terminator",
    [Terminator, HasParent<"FunctionLibraryOp">]> {
  let summary = "A pseudo op that marks the end of a shape function library";
  let description = [{
    `shape.fn_lib_terminator` is a special pseudo terminator operation for the
    shape function library. It has no semantic meaning beyond keeping the body
    well-formed.
  }];
  let assemblyFormat = "attr-dict";
}
843
844#endif // SHAPE_OPS
845