1// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2
// Verify that ops with the broadcastable trait verify operand and result type
// combinations and emit an error for invalid combinations.
5
func @broadcast_scalar_scalar_scalar(%lhs: tensor<i32>, %rhs: tensor<i32>) -> tensor<i32> {
  // Two rank-0 operands with a rank-0 result: accepted, no diagnostic expected.
  %out = "test.broadcastable"(%lhs, %rhs) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  return %out : tensor<i32>
}
11
12// -----
13
func @broadcast_tensor_scalar_tensor(%vec: tensor<4xi32>, %scalar: tensor<i32>) -> tensor<4xi32> {
  // A rank-1 operand paired with a rank-0 operand; the 4-element result is
  // accepted without a diagnostic.
  %out = "test.broadcastable"(%vec, %scalar) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
  return %out : tensor<4xi32>
}
19
20// -----
21
// Check broadcasting when exactly one operand dimension has size 1
func @broadcast_tensor_tensor_tensor(%big: tensor<4x3x2xi32>, %small: tensor<3x1xi32>) -> tensor<4x3x2xi32> {
  // The trailing size-1 dim of the 3x1 operand stretches to 2, and its
  // missing leading dim is taken from the 4x3x2 operand.
  %out = "test.broadcastable"(%big, %small) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32>
  return %out : tensor<4x3x2xi32>
}
28
29// -----
30
// Check broadcasting when several operand dimensions have size 1
func @broadcast_tensor_tensor_tensor(%a: tensor<8x1x6x1xi32>, %b: tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
  // Size-1 dims on either side pair with the other operand's extent:
  // 8x1x6x1 and 7x1x5 combine to 8x7x6x5.
  %out = "test.broadcastable"(%a, %b) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
  return %out : tensor<8x7x6x5xi32>
}
37
38// -----
39
40// Check leading unknown dimension
func @broadcast_tensor_tensor_tensor(%a: tensor<?x1x6x1xi32>, %b: tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
  // The dynamic leading dim of %a carries through to a dynamic leading
  // result dim; no diagnostic expected.
  %out = "test.broadcastable"(%a, %b) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
  return %out : tensor<?x7x6x5xi32>
}
46
47// -----
48
49// Check unknown dimension in the middle
func @broadcast_tensor_tensor_tensor(%a: tensor<8x1x?x1xi32>, %b: tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32> {
  // A dynamic dim in the middle of %a pairs with a static 1 in %b and stays
  // dynamic in the result; no diagnostic expected.
  %out = "test.broadcastable"(%a, %b) : (tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32>
  return %out : tensor<8x7x?x5xi32>
}
55
56// -----
57
58// Check incompatible vector and tensor result type
func @broadcast_scalar_vector_vector(%lhs: tensor<4xf32>, %rhs: tensor<4xf32>) -> vector<4xf32> {
  // Both operands are tensors but the declared result is a vector, so the
  // verifier must reject the op.
  // expected-error @+1 {{cannot broadcast vector with tensor}}
  %out = "test.broadcastable"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
  return %out : vector<4xf32>
}
65
66// -----
67
68// Check incompatible operand types with known dimension
func @broadcast_tensor_tensor_tensor(%big: tensor<4x3x2xi32>, %bad: tensor<3x3xi32>) -> tensor<4x3x2xi32> {
  // Trailing dims 2 and 3 are both static and neither is 1, so the operand
  // shapes cannot be broadcast together.
  // expected-error @+1 {{operands don't have broadcast-compatible shapes}}
  %out = "test.broadcastable"(%big, %bad) : (tensor<4x3x2xi32>, tensor<3x3xi32>) -> tensor<4x3x2xi32>
  return %out : tensor<4x3x2xi32>
}
75
76// -----
77
78// Check incompatible result type with known dimension
func @broadcast_tensor_tensor_tensor(%big: tensor<4x3x2xi32>, %small: tensor<3x1xi32>) -> tensor<4x3x3xi32> {
  // The operands broadcast to 4x3x2, which disagrees with the declared
  // 4x3x3 result in the last dim.
  // expected-error @+1 {{op result type '4x3x3' not broadcast compatible with broadcasted operands's shapes '4x3x2'}}
  %out = "test.broadcastable"(%big, %small) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
  return %out : tensor<4x3x3xi32>
}
85
86// -----
87
// Check incompatible result type when the operands have multiple size-1 dimensions
func @broadcast_tensor_tensor_tensor(%a: tensor<8x1x6x1xi32>, %b: tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
  // The operands broadcast to 8x7x6x5; the declared result keeps a 1 in the
  // last dim, which the verifier must reject.
  // expected-error @+1 {{op result type '8x7x6x1' not broadcast compatible with broadcasted operands's shapes '8x7x6x5'}}
  %out = "test.broadcastable"(%a, %b) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
  return %out : tensor<8x7x6x1xi32>
}
95
96// -----
97
func @broadcast_tensor_tensor_tensor(%lhs: tensor<2xi32>, %rhs: tensor<2xi32>) -> tensor<*xi32> {
  // Ranked operands with an unranked result: accepted, no diagnostic expected.
  %out = "test.broadcastable"(%lhs, %rhs) : (tensor<2xi32>, tensor<2xi32>) -> tensor<*xi32>
  return %out : tensor<*xi32>
}
103
104// -----
105
func @broadcast_tensor_tensor_tensor(%big: tensor<4x3x2xi32>, %dyn: tensor<?xi32>) -> tensor<4x3x2xi32> {
  // The single dynamic dim of %dyn lines up with the trailing 2 of %big;
  // no diagnostic expected.
  %out = "test.broadcastable"(%big, %dyn) : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
  return %out : tensor<4x3x2xi32>
}
111
112// -----
113
114// Unranked operands but ranked result
func @broadcast_tensor_tensor_tensor(%lhs: tensor<*xi32>, %rhs: tensor<*xi32>) -> tensor<2xi32> {
  // Both operands are unranked while the result is ranked; no diagnostic expected.
  %out = "test.broadcastable"(%lhs, %rhs) : (tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32>
  return %out : tensor<2xi32>
}
120
121// -----
122
123// Unranked operand and compatible ranked result
func @broadcast_tensor_tensor_tensor(%ranked: tensor<3x2xi32>, %unranked: tensor<*xi32>) -> tensor<4x3x2xi32> {
  // Three operands: %ranked is passed twice alongside an unranked operand.
  // The 4x3x2 result is accepted without a diagnostic.
  %out = "test.broadcastable"(%ranked, %ranked, %unranked) : (tensor<3x2xi32>, tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
  return %out : tensor<4x3x2xi32>
}
129
130// -----
131
func @broadcast_tensor_tensor_tensor(%ranked: tensor<3x2xi32>, %unranked: tensor<*xi32>) -> tensor<2xi32> {
  // The ranked operand alone already implies shape 3x2, so a rank-1 result
  // of shape 2 is rejected.
  // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '3x2'}}
  %out = "test.broadcastable"(%ranked, %unranked) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32>
  return %out : tensor<2xi32>
}
138
139// -----
140
func @broadcast_tensor_tensor_tensor(%a: tensor<?x1x6x1xi32>, %b: tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
  // A dynamic leading operand dim paired with a static 8 in the result is
  // accepted; no diagnostic expected.
  %out = "test.broadcastable"(%a, %b) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
  return %out : tensor<8x7x6x5xi32>
}
146
147// -----
148
func @broadcastDifferentResultType(%lhs: tensor<4xi32>, %rhs: tensor<4xi32>) -> tensor<4xi1> {
  // The result element type (i1) differs from the operand element type (i32);
  // no diagnostic is expected for this combination.
  %out = "test.broadcastable"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  return %out : tensor<4xi1>
}
154