1// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s
2
// Lower `shape.add` and `shape.mul` to their `std` arithmetic equivalents
// when they operate on `index` values.
// CHECK-LABEL: @binary_ops
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
func @binary_ops(%lhs : index, %rhs : index) {
  // CHECK: addi %[[LHS]], %[[RHS]] : index
  %sum = shape.add %lhs, %rhs : index, index -> index
  // CHECK: muli %[[LHS]], %[[RHS]] : index
  %product = shape.mul %lhs, %rhs : index, index -> index
  return
}
13
14// -----
15
// Don't lower binary ops when they operate on `!shape.size` operands; these
// must remain in the shape dialect.
// CHECK-LABEL: @binary_ops_on_size
// CHECK-SAME: (%[[LHS:.*]]: !shape.size, %[[RHS:.*]]: !shape.size)
func @binary_ops_on_size(%lhs : !shape.size, %rhs : !shape.size) {
  // CHECK: shape.add %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
  // CHECK: shape.mul %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
  %sum = shape.add %lhs, %rhs : !shape.size, !shape.size -> !shape.size
  %prod = shape.mul %lhs, %rhs : !shape.size, !shape.size -> !shape.size
  return
}
26
27// -----
28
// Convert `rank` on an extent tensor to `dim` of its first (only) dimension.
// CHECK-LABEL: @rank
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
func @rank(%shape : tensor<?xindex>) -> index {
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
  // CHECK: return %[[RESULT]] : index
  %rank = shape.rank %shape : tensor<?xindex> -> index
  return %rank : index
}
39
40// -----
41
// Don't lower `get_extent` when the index and result are of type
// `!shape.size`.
// CHECK-LABEL: @get_extent
func @get_extent(%shape : tensor<?xindex>, %idx : !shape.size) -> !shape.size {
  // CHECK: shape.get_extent
  %result = shape.get_extent %shape, %idx
      : tensor<?xindex>, !shape.size -> !shape.size
  return %result : !shape.size
}
50
51// -----
52
// Don't lower `rank` when the operand is of type `!shape.shape`, which is
// not guaranteed to be error-free.
// CHECK-LABEL: @rank
func @rank(%shape : !shape.shape) {
  // CHECK: shape.rank
  %rank = shape.rank %shape : !shape.shape -> !shape.size
  return
}
60
61// -----
62
// Express `get_extent` as `std.dim` when it relies directly on the outcome of
// a `shape_of` operation, skipping the intermediate extent tensor.
// CHECK-LABEL: @get_extent_shape_of
// CHECK-SAME:  (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
  // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
  // CHECK: return %[[RESULT]] : index
  %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
  %result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
  return %result : index
}
74
75// -----
76
// Express `get_extent` on an extent tensor as `std.extract_element`.
// CHECK-LABEL: @get_extent_from_extent_tensor
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
    -> index {
  // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
  // CHECK: return %[[RESULT]] : index
  %result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
  return %result : index
}
87
88// -----
89
// Lower `const_shape` to `tensor_from_elements` over constant extents,
// followed by a cast to the dynamically-shaped result type.
// CHECK-LABEL: @const_shape
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape() -> tensor<?xindex> {
  // CHECK: %[[C1:.*]] = constant 1 : index
  // CHECK: %[[C2:.*]] = constant 2 : index
  // CHECK: %[[C3:.*]] = constant 3 : index
  // CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]]
  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
  // CHECK: return %[[RESULT]] : tensor<?xindex>
  %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
  return %shape : tensor<?xindex>
}
103
104// -----
105
// Lower `const_shape` in the rank-0 case to an empty `tensor_from_elements`.
// CHECK-LABEL: func @const_shape_zero_elements
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape_zero_elements() -> tensor<?xindex> {
  // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
  // CHECK: return %[[RESULT]] : tensor<?xindex>
  %shape = shape.const_shape [] : tensor<?xindex>
  return %shape : tensor<?xindex>
}
116
117// -----
118
// Lower `any` to its first operand (any operand is a valid result).
// CHECK-LABEL: @any_of_three
// CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_three(%a : tensor<?xindex>,
                   %b : tensor<?xindex>,
                   %c : tensor<?xindex>) -> tensor<?xindex> {
  // CHECK: return %[[A]] : tensor<?xindex>
  %result = "shape.any"(%a, %b, %c) : (tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
  return %result : tensor<?xindex>
}
129
130// -----
131
// Lower single-operand `any` to that operand.
// CHECK-LABEL: @any_of_one
// CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
  // CHECK: return %[[A]] : tensor<?xindex>
  %result = "shape.any"(%a) : (tensor<?xindex>) -> tensor<?xindex>
  return %result : tensor<?xindex>
}
140
141// -----
142
// Lower `const_size` to `std.constant`.
// CHECK-LABEL: @const_size
func @const_size() -> index {
  // CHECK: %[[RES:.*]] = constant 42 : index
  %size = shape.const_size 42
  %result = shape.size_to_index %size : !shape.size
  // CHECK: return %[[RES]]
  return %result : index
}
152
153// -----
154
// Lower `to_extent_tensor` to `std.tensor_cast`; when the operand is already
// an extent tensor only the cast to the static result type remains.
// CHECK-LABEL: @to_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
  // CHECK-NOT: to_extent_tensor
  // CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor<?xindex> to tensor<3xindex>
  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
  // CHECK: return %[[RES]]
  return %casted : tensor<3xindex>
}
166
// -----

// Lower `shape.reduce` over an extent tensor to an `scf.for` loop carrying
// the accumulator through `iter_args`.
// CHECK-LABEL: @shape_reduce
// CHECK-SAME:  (%[[SHAPE:.*]]: tensor<?xindex>) -> index
func @shape_reduce(%shape : tensor<?xindex>) -> index {
  %init = constant 1 : index
  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
    ^bb0(%index : index, %extent : index, %acc: index):
      %new_acc = muli %acc, %extent : index
      shape.yield %new_acc : index
  }
  return %num_elements : index
}
// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
// CHECK-NEXT:   %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
// CHECK-NEXT:   %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return %[[RESULT]] : index
188
189// -----
190
// Don't lower `shape_of` for result type of `shape.shape`.
// CHECK-LABEL: @shape_of
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of(%arg : tensor<*xf32>) {
  // The op must survive the pass unchanged; check for the full op name so the
  // pattern cannot accidentally match the `!shape.shape` type instead.
  // CHECK: shape.shape_of
  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
  return
}
199
200// -----
201
// Lower `shape_of` for unranked tensors to `dynamic_tensor_from_elements`
// that reads each extent with `dim`.
// CHECK-LABEL: @shape_of_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of_unranked(%arg : tensor<*xf32>) {
  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
  // CHECK: %[[SHAPE:.*]] = dynamic_tensor_from_elements %[[RANK]] {
  // CHECK: ^bb0(%[[I:.*]]: index):
  // CHECK:   %[[EXTENT:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
  // CHECK:   yield %[[EXTENT]] : index
  // CHECK: } : tensor<?xindex>
  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  return
}
215
216// -----
217
// Don't lower `shape_of` with `shape.shape` result type.
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
  return
}
226
227// -----
228
// Lower `shape_of` for a statically shaped tensor to constant extents.
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex>
  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
  return
}
240
241// -----
242
// Lower `shape_of` for a 0-D tensor to an empty `tensor_from_elements`.
// CHECK-LABEL: @shape_of_zero_d
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func @shape_of_zero_d(%arg : tensor<f32>) {
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex>
  %shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
  return
}
251
252// -----
253
// Lower `shape_of` for a dynamically shaped tensor: static extents become
// constants, dynamic ones are read with `dim`.
// CHECK-LABEL: @shape_of_dyn
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
  // CHECK-DAG: %[[C5:.*]] = constant 5 : index
  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
  // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex>
  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
  return
}
266
267// -----
268
// Lower `shape_eq` of extent tensors to a rank comparison guarding an
// element-wise extent comparison loop.
// CHECK-LABEL:  @shape_eq
// CHECK-SAME:   (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
  // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
  // CHECK:   %[[C1:.*]] = constant 1 : index
  // CHECK:   %[[INIT:.*]] = constant true
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
  // CHECK:     %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
  // CHECK:   }
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: } else {
  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
  // CHECK: }
  // CHECK: return %[[SHAPE_EQ]] : i1
  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
  return %result : i1
}
295
296// -----
297
// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
// CHECK-LABEL: @broadcast
func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
  // CHECK: shape.broadcast
  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
  return %c : !shape.shape
}
305
306// -----
307
// Lower `shape.broadcast` of extent tensors with unknown extents to a
// rank-generic `dynamic_tensor_from_elements` computation.
// CHECK-LABEL:   func @broadcast_unknown_extents(
// CHECK-SAME:                                    %[[LHS:.*]]: tensor<?xindex>,
// CHECK-SAME:                                    %[[RHS:.*]]: tensor<?xindex>) {
func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
  // CHECK:           %[[C0:.*]] = constant 0 : index
  // CHECK:           %[[C1:.*]] = constant 1 : index
  // CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
  // CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
  // CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
  // CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
  // CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
  // CHECK:           %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
  // CHECK:           %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
  // CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
  // CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
  // CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
  // CHECK:           %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
  // CHECK:           ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
  // CHECK:             %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
  // CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
  // CHECK:             %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
  // CHECK:               scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
  // CHECK:             } else {
  // CHECK:               %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
  // CHECK:               %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
  // CHECK:               %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
  // CHECK:               %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
  // CHECK:               scf.yield %[[BROADCASTED_EXTENT]] : index
  // CHECK:             }
  // CHECK:             yield %[[OUTPUT_EXTENT:.*]] : index
  // CHECK:           } : tensor<?xindex>
  // CHECK:           return
  // CHECK:         }
  %0 = shape.broadcast %a, %b
      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return
}
345
346// -----
347
// Lower `shape.broadcast` when the operand ranks are statically known but
// differ; the lowering is the same rank-generic form after rank erasure.
// CHECK-LABEL:   func @broadcast_known_different_extents(
// CHECK-SAME:                                            %[[LHS:.*]]: tensor<2xindex>,
// CHECK-SAME:                                            %[[RHS:.*]]: tensor<3xindex>) {
func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
  // CHECK:           %[[C0:.*]] = constant 0 : index
  // CHECK:           %[[C1:.*]] = constant 1 : index
  // CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
  // CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
  // CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
  // CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
  // CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
  // CHECK:           %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
  // CHECK:           %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
  // CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
  // CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
  // CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
  // CHECK:           %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
  // CHECK:           ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
  // CHECK:             %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
  // CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
  // CHECK:             %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
  // CHECK:               scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
  // CHECK:             } else {
  // CHECK:               %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
  // CHECK:               %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
  // CHECK:               %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
  // CHECK:               %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
  // CHECK:               scf.yield %[[BROADCASTED_EXTENT]] : index
  // CHECK:             }
  // CHECK:             yield %[[OUTPUT_EXTENT:.*]] : index
  // CHECK:           } : tensor<?xindex>
  // CHECK:           return
  // CHECK:         }
  %0 = shape.broadcast %a, %b
      : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
  return
}
385
386// -----
387
// Lower `shape.is_broadcastable` to a loop that checks, for each overlapping
// dimension, that the extents are equal or one of them is 1.
func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
  %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
  return %0 : i1
}

// CHECK-LABEL:   func @try_is_broadcastable(
// CHECK-SAME:        %[[LHS:.*]]: tensor<3xindex>,
// CHECK-SAME:        %[[RHS:.*]]: tensor<?xindex>) -> i1 {
// CHECK:           %[[C0:.*]] = constant 0 : index
// CHECK:           %[[C1:.*]] = constant 1 : index
// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<3xindex>
// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
// CHECK:           %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK:           %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK:           %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK:           %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK:           %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK:           %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
// CHECK:           %[[TRUE:.*]] = constant true
// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK:             %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
// CHECK:             %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
// CHECK:             %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
// CHECK:             %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
// CHECK:             %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
// CHECK:             %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
// CHECK:             %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
// CHECK:             %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
// CHECK:             %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
// CHECK:             scf.yield %[[NEW_ALL_SO_FAR]] : i1
// CHECK:           }
// CHECK:           return %[[ALL_RESULT]] : i1
// CHECK:         }
423
424// -----
425
// Lower `shape.cstr_broadcastable` to the broadcastability check followed by
// `shape.cstr_require` on its result.
func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
  %0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
  return %0 : !shape.witness
}

// CHECK-LABEL:   func @broadcast(
// CHECK-SAME:                    %[[LHS:.*]]: tensor<?xindex>,
// CHECK-SAME:                    %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
// CHECK:           %[[C0:.*]] = constant 0 : index
// CHECK:           %[[C1:.*]] = constant 1 : index
// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
// CHECK:           %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK:           %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK:           %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK:           %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK:           %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK:           %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
// CHECK:           %[[TRUE:.*]] = constant true
// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK:             %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
// CHECK:             %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
// CHECK:             %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
// CHECK:             %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
// CHECK:             %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
// CHECK:             %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
// CHECK:             %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
// CHECK:             %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
// CHECK:             %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
// CHECK:             scf.yield %[[NEW_ALL_SO_FAR]] : i1
// CHECK:           }
// CHECK:           %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
// CHECK:           return %[[RESULT]] : !shape.witness
// CHECK:         }
462