// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
// RUN: | FILECHECK_OPTS="" FileCheck %s
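
// The pipeline above bufferizes the input: `hlo-legalize-to-lhlo` rewrites
// mhlo ops on tensors into lmhlo ops on memrefs, `buffer-hoisting` moves the
// inserted allocs up to their earliest legal position, and
// `buffer-deallocation` pairs each alloc with a dealloc after the buffer's
// last use. As a rough sketch (op name is illustrative, not from this file),
// a value-based op such as
//   %res = "mhlo.foo"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// becomes a destination-passing call on buffers:
//   %buf = alloc() : memref<2x2xf32>
//   "lmhlo.foo"(%operand, %buf) : (memref<2x2xf32>, memref<2x2xf32>) -> ()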

// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.exponential"(%operand)
      {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
  return %result : tensor<2x2xf32>
}

// -----

func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  return %arg0 : tensor<4xf32>
}
//      CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: return %[[ARG0]]

// -----

// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
  %2 = mhlo.add %arg0, %1 : tensor<4xf32>
  %3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
  %4 = mhlo.subtract %arg1, %3 : tensor<4xf32>
  %5 = mhlo.multiply %2, %4 : tensor<4xf32>
  return %5 : tensor<4xf32>
}
//       CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
//  CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
//  CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
//  CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
//  CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
//  CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
//  CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
//  CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
//  CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
//  CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
//  CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
//  CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
//  CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
//  CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
//  CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
//  CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
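// Note: buffer-deallocation frees each temporary right after its last use
// (e.g. %[[MAX_RESULT]] dies once the add has consumed it), while the buffer
// returned to the caller deliberately escapes without a dealloc.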

// -----

// CHECK-LABEL: func @fusion
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
             %summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
  %sum = "mhlo.add"(%summand_1, %summand_2)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
  %result = "mhlo.multiply"(%sum, %multiplier)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
  // CHECK-NEXT:  dealloc %[[ADD_RESULT]] : memref<2x2xf32>
  // CHECK-NEXT:  return %[[MUL_RESULT]] : memref<2x2xf32>
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @copy
func @copy(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.copy"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // TODO(herhut): An explicit copy should not be removed.
  // TODO-CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @exp
func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.exponential"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @expm1
func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @log
func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.log"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
             %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.select"(%pred, %lhs, %rhs)
      : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @compare
func @compare(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xi1> {
  %result = "mhlo.compare"(%lhs, %rhs)
      {comparison_direction = "EQ"}
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
  // CHECK: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
  return %result : tensor<2x2xi1>
}

// -----

// CHECK-LABEL: func @broadcast
func @broadcast(%operand: tensor<5xf32>) -> tensor<10x5xf32> {
  %result = "mhlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
        : (tensor<5xf32>) -> tensor<10x5xf32>
  // CHECK: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
  return %result : tensor<10x5xf32>
}

// -----

// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>

// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
  %c1 = constant 1 : i64
  %shape = tensor.from_elements %c1, %c1, %c1 : tensor<3xi64>
  %result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
  } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
  return %result : tensor<?x?x?xf32>
}
// CHECK: %[[SHAPE:.*]] = tensor.from_elements

// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>

// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[EL1:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>

// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi slt, %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index

// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index

// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #[[MAP]]>

// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>

// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// CHECK: return %[[RESULT]] : memref<?x?x?xf32>
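
// Note: the reinterpret_cast implements the broadcast without moving data. A
// dimension being expanded (operand extent smaller than the requested output
// extent) gets stride 0, so every index along it reads the same element; the
// trailing lmhlo.copy then materializes the broadcast into a freshly
// allocated buffer.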

// -----

// CHECK-LABEL: func @complex
func @complex(%real: tensor<2x2xf32>, %imag: tensor<2x2xf32>)
    -> tensor<2x2xcomplex<f32>> {
  %result = "mhlo.complex"(%real, %imag)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
  // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xcomplex<f32>>
}

// -----

// CHECK-LABEL: func @complex_dyn
func @complex_dyn(%real: tensor<?xf32>, %imag: tensor<?xf32>)
    -> tensor<?xcomplex<f32>> {
  %result = "mhlo.complex"(%real, %imag)
      : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>>
  // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
  return %result : tensor<?xcomplex<f32>>
}

// -----

// CHECK-LABEL: func @real
func @real(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> {
  %result = "mhlo.real"(%operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @real_dyn
func @real_dyn(%operand: tensor<?xcomplex<f32>>) -> tensor<?xf32> {
  %result = "mhlo.real"(%operand)
      : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
  // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
  return %result : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @imag
func @imag(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> {
  %result = "mhlo.imag"(%operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @gather
func @gather(%operand: tensor<13x7xf32>, %idxs: tensor<5xi32>)
    -> tensor<5x7xf32> {
  %result =
    "mhlo.gather"(%operand, %idxs)
      { dimension_numbers =
        { collapsed_slice_dims = dense<0> : tensor<1xi64>
        , index_vector_dim = 1 : i64
        , offset_dims = dense<1> : tensor<1xi64>
        , start_index_map = dense<0> : tensor<1xi64> }
      , indices_are_sorted = false
      , name = "gather.71"
      , slice_sizes = dense<[1, 7]> : tensor<2xi64> }
      : (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
  // CHECK: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<5x7xf32>
}

// -----

// CHECK-LABEL: func @imag_dyn
func @imag_dyn(%operand: tensor<?xcomplex<f32>>) -> tensor<?xf32> {
  %result = "mhlo.imag"(%operand)
      : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
  // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
  return %result : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @iota
// TODO(herhut): Dummy should not be required here.
func @iota(%dummy: tensor<?xf32>) -> tensor<10xi32> {
  %result = "mhlo.iota"()
      {iota_dimension = 0 : i64} : () -> tensor<10xi32>
  // CHECK: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
  return %result : tensor<10xi32>
}

// -----

// CHECK-LABEL: func @abs
func @abs(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.abs"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.abs"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @and
func @and(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.and"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @ceil
func @ceil(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.ceil"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.ceil"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @convert
func @convert(%operand: tensor<2x2xf32>) -> tensor<2x2xi32> {
  %result = "mhlo.convert"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.convert"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @cos
func @cos(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.cosine"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.cosine"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @floor
func @floor(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.floor"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.floor"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @neg
func @neg(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.negate"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.negate"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @not
func @not(%operand: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %result = "mhlo.not"(%operand)
      : (tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.not"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @or
func @or(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.or"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.rsqrt"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @sign
func @sign(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.sign"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.sign"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @sqrt
func @sqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.sqrt"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @shift_left
func @shift_left(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_left"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @shift_right_arithmetic
func @shift_right_arithmetic(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @shift_right_logical
func @shift_right_logical(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_right_logical"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @tanh
func @tanh(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.tanh"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.tanh"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @remainder
func @remainder(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
    -> tensor<2x2xf32> {
  %result = "mhlo.remainder"(%lhs, %rhs)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @xor
func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.xor"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// Dynamic shape binary element-wise operation.
// CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %result = "mhlo.add"(%lhs, %rhs)
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
  // CHECK: %[[C1:.*]] = constant 1 : index
  // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
  // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %result : tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
}

// -----

// Dynamic shape unary element-wise operation.
// CHECK-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %result = "mhlo.tanh"(%arg0)
      : (tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
  // CHECK: %[[C1:.*]] = constant 1 : index
  // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
  // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %result : tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
}

// -----

// CHECK-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
//      CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
//        dot_dimension_numbers = {
//          lhs_batching_dimensions = dense<> : tensor<0xi64>,
//          lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
//          rhs_batching_dimensions = dense<> : tensor<0xi64>,
//          rhs_contracting_dimensions = dense<0> : tensor<1xi64>}}
//        : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
  %dot = "mhlo.dot"(%arg0, %arg0)
          : (tensor<1024x1024xf32>, tensor<1024x1024xf32>)
              -> tensor<1024x1024xf32>
// CHECK: return %[[ALLOC]]
  return %dot : tensor<1024x1024xf32>
}

// -----

// CHECK-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
    -> tensor<3x5x5x4xf32> {
  %c0 = constant 0 : index
  // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
  // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
  // CHECK-SAME: padding = dense<[
  // CHECK-SAME:                  [0, 1], [0, 1]]> : tensor<2x2xi64>
  // CHECK-SAME: rhs_dilation = dense<[1, 2]>
  // CHECK-SAME: window_strides = dense<[2, 1]>
  %out = "mhlo.convolution"(%filter, %input) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 3 : i64,
      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
      kernel_input_feature_dimension = 2 : i64,
      kernel_output_feature_dimension = 3 : i64,
      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 3 : i64,
      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
    },
    feature_group_count = 1 : i64,
    padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
    rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
    window_strides = dense<[2, 1]> : tensor<2xi64>
  } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
  return %out : tensor<3x5x5x4xf32>
}

// -----

// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
  // CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
  // CHECK:  "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
  // CHECK:  ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
  // CHECK-SAME:  %[[ARG3:.*]]: memref<f32>):
  // CHECK:    %[[TMP:.*]] = alloc() : memref<f32>
  // CHECK:    "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
  // CHECK:    "lmhlo.copy"(%[[TMP]], %[[ARG3]])
  // CHECK:    "lmhlo.terminator"() : () -> ()
  // CHECK:  }) {dimensions = dense<1> : tensor<1xi64>}
  // CHECK-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (tensor<1x8xf32>, tensor<f32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}
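
// Note: the reduction body is bufferized as well. Its block arguments become
// memrefs, it gains an extra output argument (%[[ARG3]] above), and the
// mhlo.return is lowered to an lmhlo.copy into that output followed by
// lmhlo.terminator.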

// -----

// CHECK-LABEL: func @transpose
func @transpose(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.transpose"(%operand) {permutation = dense<[1, 0]> : tensor<2xi64>}
              : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @custom_call
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>)
func @custom_call(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> {
  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
  %result = "mhlo.custom_call"(%arg0, %arg1)
              {backend_config = "", call_target_name = "foo", has_side_effect = false}
              : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
  return %result : tensor<4x4xf16>
}

// -----

// CHECK-LABEL: func @custom_call_multiout
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>)
func @custom_call_multiout(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> {
  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
  %temp:2 = "mhlo.custom_call"(%arg0, %arg1)
                   {backend_config = "", call_target_name = "foo", has_side_effect = false}
                   : (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
  %result = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
  return %result : tensor<4x4xf16>
}
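
// Note: operand_segment_sizes splits lmhlo.custom_call's flat operand list
// into input and output buffers: dense<[2, 1]> above means two inputs and one
// result buffer, while the splat dense<2> means two of each.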

// -----

// CHECK-LABEL: func @isfinite
func @isfinite(%arg0: tensor<2x2xf32>) -> tensor<2x2xi1> {
  // CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
  %result = "mhlo.is_finite"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xi1>
  return %result : tensor<2x2xi1>
}

// -----

// Test that assuming ops propagate memref types.
// CHECK-LABEL: func @shape_assuming_tensor
func @shape_assuming_tensor(%arg0: tensor<?xf16>) -> tensor<?xf16> {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
  %1 = shape.const_witness true
  // CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
  %2 = shape.assuming %1 -> (tensor<?xf16>) {
    %3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
    %4 = tensor.cast %3 : tensor<?xindex> to tensor<1xindex>
    %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
    %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
    // CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
    %7 = mhlo.maximum %5, %6 : tensor<?xf16>
    // CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
    shape.assuming_yield %7 : tensor<?xf16>
  }
  return %2 : tensor<?xf16>
}