//===- Shape.td - Shape operations definition --------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for Shape dialect operations.
//
//===----------------------------------------------------------------------===//

#ifndef SHAPE_OPS
#define SHAPE_OPS

include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
// Shape op definitions
//===----------------------------------------------------------------------===//

// Base class for the operations in this dialect. Every op below derives from
// it so that it is registered under `ShapeDialect`.
class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<ShapeDialect, mnemonic, traits>;

def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
  let summary = "Addition of sizes and indices";
  let description = [{
    Adds two sizes or indices. If either operand is an error it will be
    propagated to the result. The operands can be of type `size` or `index`. If
    at least one of the operands can hold an error, i.e. if it is of type `size`,
    then also the result must be of type `size`.
  }];

  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
  let results = (outs Shape_SizeOrIndexType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  // Qualified with `::` like every other user of this helper in this file.
  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}

def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
  let summary = "Returns the broadcasted output shape of two inputs";
  let description = [{
    Returns the broadcasted shape for two input shapes or extent tensors. Both
    operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
    type `shape.shape` and, if both operands are tensors, may be of type
    `tensor<?xindex>`.

    If the two operand shapes are of different rank the smaller one is padded
    with 1's from the left. The resulting broadcasted shape is then defined as

        result[i] = lhs[i] if lhs[i] == rhs[i]
                  = lhs[i] if rhs[i] == 1
                  = rhs[i] if lhs[i] == 1.

    In case the resulting shape is undefined, i.e. if corresponding extents are
    different from each other but none is 1, the result is an error shape.
    Likewise error values are propagated if any of the operands holds an error
    value. If the result type is an extent tensor (and can therefore not hold
    the error value) the behavior may be undefined. The optional string
    attribute can be used to describe the error case.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs,
                       OptionalAttr<StrAttr>:$error);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  let hasFolder = 1;

  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
}

def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
  let summary = "Creates a constant shape or extent tensor";
  let description = [{
    Creates a constant shape or extent tensor. The individual extents are given
    as the `shape` attribute. The number of these values equals the shape's
    rank.

    ```mlir
    %0 = shape.const_shape [] : !shape.shape
    %1 = shape.const_shape [1, 2, 3] : !shape.shape
    %2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
    ```
  }];
  let arguments = (ins IndexElementsAttr:$shape);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  // TODO: Move this to main so that all shape ops implement these.
  let printer = [{ return ::print(p, *this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Shape_ConstSizeOp : Shape_Op<"const_size", [
    ConstantLike,
    NoSideEffect,
    DeclareOpInterfaceMethods<OpAsmOpInterface>
  ]> {
  let summary = "Creates a constant of type `shape.size`";
  let description = [{
    Creates a `shape.size` type representing the constant size given by `value`.

    ```mlir
    %x = shape.const_size 10
    ```
  }];

  let arguments = (ins IndexAttr:$value);
  let results = (outs Shape_SizeType:$result);

  let builders = [OpBuilderDAG<(ins "int64_t":$value)>];

  let assemblyFormat = "$value attr-dict";
  let hasFolder = 1;
}

def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
  let summary = "Returns whether the input shapes or extent tensors are equal";
  let description = [{
    Takes two shape or extent tensor operands and determines whether they are
    equal. When extent tensors are compared to shapes they are regarded as their
    equivalent non-error shapes. Error shapes can be tested for equality like
    any other shape value, meaning that the error value is equal to itself.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs);
  let results = (outs I1:$result);

  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
  let hasFolder = 1;
}

def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
  let summary = "Creates a shape from extents";
  let description = [{
    Creates a shape from multiple SSA values representing the extents of
    the shape.

    ```mlir
    // Rank 2 shape.
    %s0 = shape.from_extents %a, %b
    // Rank 0 shape.
    %s1 = shape.from_extents
    ```
  }];
  let arguments = (ins Variadic<Index>:$extents);
  let results = (outs Shape_ShapeType:$shape);

  let assemblyFormat = "$extents attr-dict";

  let hasFolder = 1;
}

def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
  let summary = "Creates a shape from a tensor of extents";
  let description = [{
    Creates a shape from a 1D integral tensor of extents. The rank of the
    resulting shape equals the number of elements in the tensor, and the
    extents match the values of the elements.
  }];

  let arguments = (ins IndexTensor:$input);
  let results = (outs Shape_ShapeType:$result);

  let assemblyFormat = "$input attr-dict `:` type($input)";
}

def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
  let summary = "Determines if 2 shapes can be successfully broadcasted";
  let description = [{
    Given two input shapes or extent tensors, return a predicate specifying if
    they are broadcastable. This broadcastable follows the same logic as what
    shape.broadcast documents.

    Concretely, shape.is_broadcastable returning true implies that
    shape.broadcast will not give an error, and shape.cstr_broadcastable will
    not result in an assertion failure. Similarly, false implies an error or
    assertion failure.

    Example:
    ```mlir
    %true = shape.is_broadcastable [2,2], [3,1,2]
    %false = shape.is_broadcastable [2,2], [3,2]
    ```
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs);
  let results = (outs I1:$result);

  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
}

def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
  let summary = "Gets the rank of a shape";
  let description = [{
    Returns the rank of the shape or extent tensor, i.e. the number of extents.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
  let results = (outs Shape_SizeOrIndexType:$rank);

  let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($rank)";

  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}

def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
  let summary = "Creates a dimension tensor from a shape";
  let description = [{
    Converts a shape to a 1D integral tensor of extents. The number of elements
    in the tensor equals the rank of the shape, and the elements equal the
    extents of the shape.

    If the shape represents an error, this op's behavior is undefined.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$input);
  let results = (outs IndexTensor:$result);

  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";

  let hasFolder = 1;
}

def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
  let summary = "Gets the specified extent from a shape or extent tensor";
  let description = [{
    Gets the extent indexed by `dim` from the `shape` operand. If the shape is
    an error then it returns an error size.
  }];
  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
                       Shape_SizeOrIndexType:$dim);
  let results = (outs Shape_SizeOrIndexType:$extent);
  let assemblyFormat = "$shape `,` $dim attr-dict `:` type($shape) `,` type($dim) `->` "
                       "type($extent)";

  let builders = [
    // Builder that allows passing a constant dimension as a simple integer.
    OpBuilderDAG<(ins "Value":$shape, "int64_t":$dim)>
  ];

  let extraClassDeclaration = [{
    /// Get the `dim` value as integer if it is constant.
    Optional<int64_t> getConstantDim();
  }];

  let hasFolder = 1;
  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}

def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
  let summary = "Converts a standard index to a shape size";
  let description = [{
    Converts a standard index to a `shape.size`. This operation and its
    inverse, `size_to_index`, facilitate index conversion between the standard
    and the shape dialect.

    The behavior is undefined for negative indices.
  }];

  let arguments = (ins Index:$arg);
  let results = (outs Shape_SizeType:$result);

  let assemblyFormat = "$arg attr-dict";

  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
  let summary = "Returns the least general shape or size of its operands";
  let description = [{
    An operation that computes the least general shape of input operands.
    This effectively asserts that corresponding static dimensions are equal.
    The behavior is to match each element of the `shape.shape` and propagate the
    most restrictive information, returning an invalid shape if there are
    contradictory requirements. E.g., using pseudo code

    ```
    shape.join([*], [*]) -> [*]
    shape.join([*], [1, ?]) -> [1, ?]
    shape.join([1, 2], [1, ?]) -> [1, 2]
    shape.join([*], [1, 2]) -> [1, 2]
    shape.join([], []) -> []
    shape.join([], [*]) -> []
    shape.join([], [?, ?]) -> [invalid]
    shape.join([1, ?], [2, ?, ?]) -> [invalid]
    ```

    `shape.join` also allows specifying an optional error string, that may be
    used to return an error to the user upon mismatch of dimensions.

    ```mlir
    %c = shape.join %a, %b, error="<reason>" : !shape.shape
    ```
  }];

  let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
                       OptionalAttr<StrAttr>:$error);
  let results = (outs Shape_ShapeOrSizeType:$result);
}

def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
  let summary = "Multiplication of sizes and indices";
  let description = [{
    Multiplies two sizes or indices. If either operand is an error it will be
    propagated to the result. The operands can be of type `size` or `index`. If
    at least one of the operands can hold an error, i.e. if it is of type `size`,
    then also the result must be of type `size`. If error propagation is not
    possible because both operands are of type `index` then the result must also
    be of type `index`.
  }];

  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
  let results = (outs Shape_SizeOrIndexType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
  let hasFolder = 1;
}

def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
  let summary = "Returns the number of elements for a given shape";
  let description = [{
    Returns the number of elements for a given shape which is the product of its
    extents. If the argument is of type `shape` then the result will be of type
    `size` and potential errors will be propagated. Otherwise, if the argument
    is an extent tensor `tensor<?xindex>` then the result will be of type
    `index`.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
  let results = (outs Shape_SizeOrIndexType:$result);

  let builders = [OpBuilderDAG<(ins "Value":$shape)>];

  let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";

  let hasFolder = 1;
  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}

def Shape_ReduceOp : Shape_Op<"reduce",
    [SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Returns an expression reduced over a shape or extent tensor";
  let description = [{
    An operation that takes as input a shape or extent tensor, and a number of
    initial values. This operation has a region/function that is applied
    repeatedly for every extent of the input. Starting with the initial values,
    the individual extents are then aggregated as defined by the associated
    region.

    Conceptually this op performs the following reduction:

    ```
    res[] = init;
    for (int i = 0; i < shape.rank(); i++) {
      res = fn(i, shape[i], res[0], ..., res[n]);
    }
    ```

    Where `fn` is provided by the user and the result of the reduce op is the
    last computed output of the reduce function. As an example, computing the
    number of elements can be defined as follows:

    ```mlir
    func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
      %num_elements = shape.reduce(%shape, %init) -> !shape.size {
        ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
          %updated_acc = "shape.mul"(%acc, %dim) :
            (!shape.size, !shape.size) -> !shape.size
          shape.yield %updated_acc : !shape.size
      }
      return %num_elements : !shape.size
    }
    ```
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
                       Variadic<AnyType>:$initVals);
  let results = (outs Variadic<AnyType>:$result);
  let regions = (region SizedRegion<1>:$region);

  let builders = [OpBuilderDAG<(ins "Value":$shape, "ValueRange":$initVals)>];

  let verifier = [{ return ::verify(*this); }];
  let printer = [{ return ::print(p, *this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
}

def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
  let summary = "Returns shape of a value or shaped type operand";

  let description = [{
    The operation takes a value or a shaped operand as an argument and it
    returns a shape or extent tensor.
  }];

  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";

  let builders = [OpBuilderDAG<(ins "Value":$arg)>];

  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
  let summary = "Casts between index types of the shape and standard dialect";
  let description = [{
    Converts a `shape.size` to a standard index. This operation and its
    inverse, `index_to_size`, facilitate index conversion between the standard
    and the shape dialect. The behavior is undefined for unknown and invalid
    arguments.
  }];

  let arguments = (ins Shape_SizeOrIndexType:$arg);
  let results = (outs Index:$result);

  let assemblyFormat = "$arg attr-dict `:` type($arg)";

  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> {
  let summary = "Returns ValueShape with given shape";
  let description = [{
    Returns ValueShape with the shape updated to match the shape operand. That
    is a new ValueShape tuple is created with value equal to `operand`'s
    value and shape equal to `shape`. If the ValueShape and given `shape` are
    non-conformant, then the returned ValueShape will represent an error of
    this mismatch. Similarly if either input is in an error state, then an
    error is propagated.

    Usage:
      %0 = shape.with_shape %1, %2 : tensor<...>, !shape.shape

    This is used, for example, where one combines shape function calculations
    and/or call one shape function from another. E.g.,

    ```mlir
    func @shape_foobah(%a: !shape.value_shape,
                       %b: !shape.value_shape,
                       %c: !shape.value_shape) -> !shape.shape {
      %0 = call @shape_foo(%a, %b) :
        (!shape.value_shape, !shape.value_shape) -> !shape.shape
      %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
      %2 = call @shape_bah(%c, %1) :
        (!shape.value_shape, !shape.value_shape) -> !shape.shape
      return %2 : !shape.shape
    }
    ```

    This op need not be a refinement of the shape. In non-error cases the input
    ValueShape's value and shape are conformant and so too for the output, but
    the result may be less specified than `operand`'s shape as `shape` is
    merely used to construct the new ValueShape. If join behavior is desired
    then a join op should be used.
  }];

  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
                       Shape_ShapeType:$shape);
  let results = (outs Shape_ValueShapeType:$result);

  let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
}

def Shape_YieldOp : Shape_Op<"yield",
    [HasParent<"ReduceOp, FunctionLibraryOp">,
     NoSideEffect,
     ReturnLike,
     Terminator]> {
  let summary = "Returns the value to parent op";

  let arguments = (ins Variadic<AnyType>:$operands);

  // Convenience builder for a yield without operands.
  let builders = [OpBuilderDAG<(ins),
    [{ build($_builder, $_state, llvm::None); }]>
  ];

  let verifier = [{ return ::verify(*this); }];
  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

// TODO: Add Ops: if_static, if_ranked

// For testing usage.
def Shape_DebugPrintOp : Shape_Op<"debug_print", []> {
  let summary = "Prints the input shape or size";
  let description = [{
    Prints the input dim or shape and passes through input.

    Note: This is intended for testing and debugging only.
  }];

  let arguments = (ins Shape_ShapeOrSizeType:$input);
  let results = (outs Shape_ShapeOrSizeType:$output);
}

def Shape_SplitAtOp : Shape_Op<"split_at", []> {
  let summary = "Splits a shape at a given index";
  let description = [{
    Splits a shape at a given dimension `index`, returning two shapes.
    If `index` is negative, it is treated as indexing from the back of the
    shape. This negative-handling behavior is important when handling unranked
    shapes, where the positive index is not necessarily knowable due to a
    dynamic number of leading dimensions.

    Examples:
    - split_at([4,5,6], index=0) -> [], [4,5,6]
    - split_at([4,5,6], index=1) -> [4], [5,6]
    - split_at([4,5,6], index=2) -> [4,5], [6]
    - split_at([4,5,6], index=3) -> [4,5,6], []
    - split_at([4,5,6], index=4) -> error
    - split_at([4,5,6], index=-1) -> [4,5], [6]
    - split_at([4,5,6], index=-2) -> [4], [5,6]
    - split_at([4,5,6], index=-3) -> [], [4,5,6]
    - split_at([4,5,6], index=-4) -> error

    Requires:
    - `index` is in the range [-rank(operand),rank(operand)]
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, I32:$index);
  let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
  let hasFolder = 1;
}

def Shape_ConcatOp : Shape_Op<"concat", []> {
  let summary = "Concatenates two shapes";
  let description = [{
    Creates a shape whose dimensions consist of first the dimensions from `lhs`
    followed by the dimensions of `rhs`.

    Example:
    concat([2,3], [4,5]) -> [2,3,4,5]
    concat([], []) -> []
    concat([], [4,5,6]) -> [4,5,6]
  }];

  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
  let results = (outs Shape_ShapeType:$result);

  let assemblyFormat = "$lhs `,` $rhs attr-dict";
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// Shape constraint related ops.
//===----------------------------------------------------------------------===//

// TODO: Move the code below and witnesses to a different file.
def Shape_AnyOp : Shape_Op<"any", [Commutative,
                                   NoSideEffect]> {
  let summary = "Return any combination of the input shapes";
  let description = [{
    This operation takes multiple input shapes or extent tensors and returns
    some combination of their dimensions. This can be best seen with examples
    below.

    The result is undefined, but still side-effect free, in cases where the
    inputs have differing ranks or differ in extents of shared dimensions.

    Example:
    ```mlir
    %s0 = shape.any [2,?], [?,3] // [2,3]
    %s1 = shape.any [?,?], [1,2] // [1,2]
    ```
  }];

  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `->` type($result)";

  let hasFolder = 1;
}

def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]> {
  let summary = "Return a logical AND of all witnesses";
  let description = [{
    Used to simplify constraints as any single failing precondition is enough
    to prevent execution.

    "assuming" operations represent an execution order restriction to the
    compiler, information for dependent code to rely on (by assuming), and
    nothing else. They should not exist after a program is fully lowered and
    ready to execute.

    Example:
    ```mlir
    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
    %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
    %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
    %wf = shape.assuming_all %w0, %w1 // Failure
    %wt = shape.assuming_all %w0, %w2 // Passing
    ```
  }];

  let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$inputs attr-dict";

  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let verifier = [{ return ::verify(*this); }];
}

def Shape_AssumingOp : Shape_Op<"assuming",
                           [SingleBlockImplicitTerminator<"AssumingYieldOp">,
                            DeclareOpInterfaceMethods<RegionBranchOpInterface>,
                            RecursiveSideEffects]> {
  let summary = "Execute the region";
  let description = [{
    Executes the region assuming all witnesses are true.

    "assuming" operations represent an execution order restriction to the
    compiler, information for dependent code to rely on (by assuming), and
    nothing else. They should not exist after a program is fully lowered and
    ready to execute.
  }];
  let arguments = (ins Shape_WitnessType:$witness);
  let regions = (region SizedRegion<1>:$doRegion);
  let results = (outs Variadic<AnyType>:$results);

  let printer = [{ return ::print(p, *this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];

  let extraClassDeclaration = [{
    // Inline the region into the region containing the AssumingOp and delete
    // the AssumingOp.
    //
    // This does no checks on the inputs to the AssumingOp.
    static void inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter);
  }];

  let hasCanonicalizer = 1;
}

def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
                                [NoSideEffect, ReturnLike, Terminator]> {
  let summary = "Yield operation";
  let description = [{
    This yield operation represents a return operation within the
    `shape.assuming` region. The operation takes variable number of operands
    and produces no results. The operand number and types must match the
    return signature of the region that contains the operation.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);

  let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>];

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
  let summary = "Determines if 2 shapes can be successfully broadcasted";
  let description = [{
    Given two input shapes or extent tensors, return a witness specifying if
    they are broadcastable. This broadcastable follows the same logic as what
    shape.broadcast documents.

    "cstr" operations represent runtime assertions.

    Example:
    ```mlir
    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
    %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
    ```
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
  let summary = "Determines if all input shapes are equal";
  let description = [{
    Given 1 or more input shapes, determine if all shapes are the exact same.

    "cstr" operations represent runtime assertions.

    Example:
    ```mlir
    %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
    %w1 = shape.cstr_eq [2,2], [1,2] // Failure
    ```
  }];
  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$inputs attr-dict";

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
  let summary = "An operation that returns a statically known witness value";
  let description = [{
    This operation represents a statically known witness result. This can be
    often used to canonicalize/fold constraint and assuming code that will always
    pass.

    ```mlir
    %0 = shape.const_shape [1,2,3]
    %1 = shape.const_shape [1, 2, 3]
    %w0 = shape.cstr_eq(%0, %1) // Can be folded to "const_witness true"
    %w1 = shape.const_witness true
    %w2 = shape.assuming_all(%w0, %w1) // Can be folded to "const_witness true"
    ```
  }];
  let arguments = (ins BoolAttr:$passing);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$passing attr-dict";

  let hasFolder = 1;
}

def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
  let summary = "Represents a runtime assertion that an i1 is `true`";
  let description = [{
    Represents a runtime assertion that an i1 is true. It returns a
    !shape.witness to order this assertion.

    For simplicity, prefer using other cstr_* ops if they are available for a
    given constraint.

    Example:
    ```mlir
    %bool = ...
    %w0 = shape.cstr_require %bool, "msg" // Passing if `%bool` is true.
    ```

    Since this op can be used to express many different possible assertions
    (depending on whatever computation calculated `pred`), the `msg`
    should clarify the nature of the assertion for users.
  }];
  let arguments = (ins I1:$pred, StrAttr:$msg);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$pred `,` $msg attr-dict";

  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// Shape collection ops.
//===----------------------------------------------------------------------===//

def Shape_FunctionLibraryOp : Shape_Op<"function_library",
    [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
     SingleBlockImplicitTerminator<"ShapeFunctionLibraryTerminatorOp">]> {
  let summary = "Represents shape functions and corresponding ops";
  let description = [{
    Represents a list of shape functions and the ops whose shape transfer
    functions they represent.

    Example:

    ```mlir
    shape.function_library {
      func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
        %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
        return %0 : !shape.shape
      }
    } mapping {
      std.atan = @same_result_shape
    }
    ```
  }];

  // All attributes are declared in a single `ins` list: a second
  // `let arguments` would override the first and drop the symbol attributes.
  let arguments = (ins SymbolNameAttr:$sym_name,
                       OptionalAttr<StrAttr>:$sym_visibility,
                       DictionaryAttr:$mapping);
  let regions = (region AnyRegion:$body);

  let extraClassDeclaration = [{
    /// Returns an associated shape function for an operation if defined.
    FuncOp getShapeFunction(Operation *op);
  }];

  let builders = [OpBuilderDAG<(ins "StringRef":$name)>];
  let skipDefaultBuilders = 1;

  let printer = [{ ::print(p, *this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
}

//===----------------------------------------------------------------------===//
// ShapeFunctionLibraryTerminatorOp
//===----------------------------------------------------------------------===//

def ShapeFunctionLibraryTerminatorOp : Shape_Op<"fn_lib_terminator",
    [Terminator, HasParent<"FunctionLibraryOp">]> {
  let summary = "A pseudo op that marks the end of a shape function library";
  let description = [{
    `shape.fn_lib_terminator` is a special pseudo terminator operation for the
    shape function library. It has no semantic meaning beyond keeping the body
    well-formed.
  }];
  let assemblyFormat = "attr-dict";
}

#endif // SHAPE_OPS