// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
// RUN: | FILECHECK_OPTS="" FileCheck %s

// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.exponential"(%operand)
      {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
  return %result : tensor<2x2xf32>
}

// -----

func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  return %arg0 : tensor<4xf32>
}
// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: return %[[ARG0]]

// -----

// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
  %2 = mhlo.add %arg0, %1 : tensor<4xf32>
  %3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
  %4 = mhlo.subtract %arg1, %3 : tensor<4xf32>
  %5 = mhlo.multiply %2, %4 : tensor<4xf32>
  return %5 : tensor<4xf32>
}
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>

// -----

// CHECK-LABEL: func @fusion
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
             %summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
  // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
  %sum = "mhlo.add"(%summand_1, %summand_2)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
  // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
  %result = "mhlo.multiply"(%sum, %multiplier)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
  // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
  // CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32>
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @copy
func @copy(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.copy"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // TODO(herhut): An explicit copy should not be removed.
75 // TODO-CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) 76 return %result : tensor<2x2xf32> 77} 78 79// ----- 80 81// CHECK-LABEL: func @exp 82func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { 83 %result = "mhlo.exponential"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> 84 // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) 85 return %result : tensor<2x2xf32> 86} 87 88// ----- 89 90// CHECK-LABEL: func @expm1 91func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { 92 %result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> 93 // CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}}) 94 return %result : tensor<2x2xf32> 95} 96 97// ----- 98 99// CHECK-LABEL: func @log 100func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { 101 %result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> 102 // CHECK: "lmhlo.log"(%{{.*}}, %{{.*}}) 103 return %result : tensor<2x2xf32> 104} 105 106// ----- 107 108// CHECK-LABEL: func @select 109func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, 110 %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { 111 %result = "mhlo.select"(%pred, %lhs, %rhs) 112 : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> 113 // CHECK: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) 114 return %result : tensor<2x2xf32> 115} 116 117// ----- 118 119// CHECK-LABEL: func @compare 120func @compare(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xi1> { 121 %result = "mhlo.compare"(%lhs, %rhs) 122 {comparison_direction = "EQ"} 123 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> 124 // CHECK: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} 125 return %result : tensor<2x2xi1> 126} 127 128// ----- 129 130// CHECK-LABEL: func @broadcast 131func @broadcast(%operand: tensor<5xf32>) -> tensor<10x5xf32> { 132 %result = "mhlo.broadcast_in_dim"(%operand) 133 {broadcast_dimensions = dense<1> : tensor<1xi64>} 134 : (tensor<5xf32>) -> tensor<10x5xf32> 135 // 
CHECK: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} 136 return %result : tensor<10x5xf32> 137} 138 139// ----- 140 141// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)> 142 143// CHECK-LABEL: func @dyn_broadcast 144func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> { 145 // CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32> 146 %c1 = constant 1 : i64 147 %shape = tensor.from_elements %c1, %c1, %c1 : tensor<3xi64> 148 %result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) { 149 broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> 150 } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32> 151 return %result : tensor<?x?x?xf32> 152} 153// CHECK: %[[SHAPE:.*]] = tensor.from_elements 154 155// CHECK: %[[C0:.*]] = constant 0 : index 156// CHECK: %[[C1:.*]] = constant 1 : index 157// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32> 158// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index 159// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32> 160 161// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64> 162// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index 163// CHECK: %[[EL1:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64> 164 165// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index 166// CHECK: %[[EXPAND_1:.*]] = cmpi slt, %[[OPER_DIM_0]], %[[SIZE_1]] : index 167// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index 168 169// CHECK: %[[C2:.*]] = constant 2 : index 170// CHECK: %[[EL2:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64> 171// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index 172// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index 173// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index 174 175// CHECK: %[[TRANSFORMED_MEMREF:.*]] = 
memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map> 176 177// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32> 178 179// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> () 180// CHECK: return %[[RESULT]] : memref<?x?x?xf32> 181 182// ----- 183 184// CHECK-LABEL: func @complex 185func @complex(%real: tensor<2x2xf32>, %imag: tensor<2x2xf32>) 186 -> tensor<2x2xcomplex<f32>> { 187 %result = "mhlo.complex"(%real, %imag) 188 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>> 189 // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}}) 190 return %result : tensor<2x2xcomplex<f32>> 191} 192 193// ----- 194 195// CHECK-LABEL: func @complex_dyn 196func @complex_dyn(%real: tensor<?xf32>, %imag: tensor<?xf32>) 197 -> tensor<?xcomplex<f32>> { 198 %result = "mhlo.complex"(%real, %imag) 199 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>> 200 // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}}) 201 return %result : tensor<?xcomplex<f32>> 202} 203 204// ----- 205 206// CHECK-LABEL: func @real 207func @real(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> { 208 %result = "mhlo.real"(%operand) 209 : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> 210 // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}}) 211 return %result : tensor<2x2xf32> 212} 213 214// ----- 215 216// CHECK-LABEL: func @real_dyn 217func @real_dyn(%operand: tensor<?xcomplex<f32>>) -> tensor<?xf32> { 218 %result = "mhlo.real"(%operand) 219 : (tensor<?xcomplex<f32>>) -> tensor<?xf32> 220 // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}}) 221 return %result : tensor<?xf32> 222} 223 224// ----- 225 226// CHECK-LABEL: func @imag 227func @imag(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> { 228 %result = "mhlo.imag"(%operand) 229 : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> 230 
// CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}}) 231 return %result : tensor<2x2xf32> 232} 233 234// ----- 235 236// CHECK-LABEL: func @gather 237func @gather(%operand: tensor<13x7xf32>, %idxs: tensor<5xi32>) 238 -> tensor<5x7xf32> { 239 %result = 240 "mhlo.gather"(%operand, %idxs) 241 { dimension_numbers = 242 { collapsed_slice_dims = dense<0> : tensor<1xi64> 243 , index_vector_dim = 1 : i64 244 , offset_dims = dense<1> : tensor<1xi64> 245 , start_index_map = dense<0> : tensor<1xi64> } 246 , indices_are_sorted = false 247 , name = "gather.71" 248 , slice_sizes = dense<[1, 7]> : tensor<2xi64> } 249 : (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32> 250 // CHECK: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}}) 251 return %result : tensor<5x7xf32> 252} 253 254// ----- 255 256// CHECK-LABEL: func @imag_dyn 257func @imag_dyn(%operand: tensor<?xcomplex<f32>>) -> tensor<?xf32> { 258 %result = "mhlo.imag"(%operand) 259 : (tensor<?xcomplex<f32>>) -> tensor<?xf32> 260 // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}}) 261 return %result : tensor<?xf32> 262} 263 264// ----- 265 266// CHECK-LABEL: func @iota 267// TODO(herhut): Dummy should not be required here. 
func @iota(%dummy: tensor<?xf32>) -> tensor<10xi32> {
  %result = "mhlo.iota"()
      {iota_dimension = 0 : i64} : () -> tensor<10xi32>
  // CHECK: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
  return %result : tensor<10xi32>
}

// -----

// CHECK-LABEL: func @abs
func @abs(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.abs"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.abs"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @and
func @and(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.and"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @ceil
func @ceil(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.ceil"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.ceil"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @convert
func @convert(%operand: tensor<2x2xf32>) -> tensor<2x2xi32> {
  %result = "mhlo.convert"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.convert"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @cos
func @cos(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.cosine"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.cosine"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @floor
func @floor(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.floor"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.floor"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @neg
func @neg(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.negate"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.negate"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @not
func @not(%operand: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %result = "mhlo.not"(%operand)
      : (tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.not"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @or
func @or(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.or"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.rsqrt"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @sign
func @sign(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.sign"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.sign"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @sqrt
func @sqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.sqrt"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @shift_left
func @shift_left(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_left"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @shift_right_arithmetic
func @shift_right_arithmetic(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @shift_right_logical
func @shift_right_logical(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_right_logical"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @tanh
func @tanh(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.tanh"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.tanh"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @remainder
func @remainder(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
    -> tensor<2x2xf32> {
  %result = "mhlo.remainder"(%lhs, %rhs)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @xor
func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.xor"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}

// -----

// Dynamic shape binary element-wise operation.
// CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %result = "mhlo.add"(%lhs, %rhs)
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
  // CHECK: %[[C1:.*]] = constant 1 : index
  // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
  // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %result : tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
}

// -----

// Dynamic shape unary element-wise operation.
// CHECK-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %result = "mhlo.tanh"(%arg0)
      : (tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
  // CHECK: %[[C1:.*]] = constant 1 : index
  // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
  // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %result : tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
}

// -----

// CHECK-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// dot_dimension_numbers = {
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
// lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
// rhs_batching_dimensions = dense<> : tensor<0xi64>,
// rhs_contracting_dimensions = dense<0> : tensor<1xi64>}}
// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
  %dot = "mhlo.dot"(%arg0, %arg0)
      : (tensor<1024x1024xf32>, tensor<1024x1024xf32>)
      -> tensor<1024x1024xf32>
// CHECK: return %[[ALLOC]]
  return %dot : tensor<1024x1024xf32>
}

// -----

// CHECK-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>,
           %filter : tensor<2x2x3x4xf32>)
    -> tensor<3x5x5x4xf32> {
  %c0 = constant 0 : index
  // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
  // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
  // CHECK-SAME: padding = dense<[
  // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
  // CHECK-SAME: rhs_dilation = dense<[1, 2]>
  // CHECK-SAME: window_strides = dense<[2, 1]>
  %out = "mhlo.convolution"(%filter, %input) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 3 : i64,
      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
      kernel_input_feature_dimension = 2 : i64,
      kernel_output_feature_dimension = 3 : i64,
      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 3 : i64,
      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
    },
    feature_group_count = 1 : i64,
    padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
    rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
    window_strides = dense<[2, 1]> : tensor<2xi64>
  } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
  return %out : tensor<3x5x5x4xf32>
}

// -----

// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
  // CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
  // CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
  // CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
  // CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
  // CHECK: %[[TMP:.*]] = alloc() : memref<f32>
  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
  // CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
  // CHECK: "lmhlo.terminator"() : () -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>}
  // CHECK-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (tensor<1x8xf32>, tensor<f32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}

// -----

// CHECK-LABEL: func @transpose
func @transpose(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.transpose"(%operand) {permutation = dense<[1, 0]> : tensor<2xi64>}
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
  return %result : tensor<2x2xf32>
}

// -----

// CHECK-LABEL: func @custom_call
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>)
func @custom_call(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> {
  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
  %result = "mhlo.custom_call"(%arg0, %arg1)
      {backend_config = "", call_target_name = "foo", has_side_effect = false}
      : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
  return %result : tensor<4x4xf16>
}

// -----

// CHECK-LABEL: func @custom_call_multiout
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>)
func @custom_call_multiout(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> {
  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
  %temp:2 = "mhlo.custom_call"(%arg0, %arg1)
      {backend_config = "", call_target_name = "foo", has_side_effect = false}
      : (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
  %result = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
  return %result : tensor<4x4xf16>
}

// -----

// CHECK-LABEL: func @isfinite
func @isfinite(%arg0: tensor<2x2xf32>) -> tensor<2x2xi1> {
  // CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
  %result = "mhlo.is_finite"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xi1>
  return %result : tensor<2x2xi1>
}

// -----

// Test that assuming ops propagate tensor types.
// CHECK-LABEL: func @shape_assuming_tensor
func @shape_assuming_tensor(%arg0: tensor<?xf16>) -> tensor<?xf16> {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
  %1 = shape.const_witness true
  // CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
  %2 = shape.assuming %1 -> (tensor<?xf16>) {
    %3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
    %4 = tensor.cast %3 : tensor<?xindex> to tensor<1xindex>
    %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
    %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
    // CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
    %7 = mhlo.maximum %5, %6 : tensor<?xf16>
    // CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
    shape.assuming_yield %7 : tensor<?xf16>
  }
  return %2 : tensor<?xf16>
}