// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s

// Tests for the mhlo-test-optimize pass on "mhlo.gather" ops that select a
// single window. Each case expects the gather to be replaced by an
// "mhlo.dynamic-slice" whose result is passed through "mhlo.reshape" and
// returned.
// All three cases gather a 1x1x2 window (slice_sizes = [1, 1, 2]) from a
// tensor<2x1x2xi32> operand with index_vector_dim = 0 and
// collapsed_slice_dims = [0], producing a tensor<1x2xi32>. They differ only in
// the shape of the start-index operand %arg1.

// Scalar start index (tensor<i64>). The expected output uses %arg1 directly as
// the dynamic-slice offset for operand dimension 0. A single zero constant is
// reused as the offset for dimensions 1 and 2. The 1x1x2 slice is then
// reshaped to the tensor<1x2xi32> result.
// CHECK-LABEL: @gather_is_slice_no_rank
func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>

  // CHECK: return [[RESHAPE]]
  return %res : tensor<1x2xi32>
}

// Single-element start-index vector (tensor<1xi64>). The expected output first
// reshapes %arg1 and uses the result as the offset for operand dimension 0.
// The reshape's result type is not pinned here; presumably it is a 0-d
// tensor<i64> matching the zero constant. Dimensions 1 and 2 get the zero
// constant, and the slice is reshaped to the result type as in the scalar case.
// CHECK-LABEL: @gather_is_slice
func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])

  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>

  // CHECK: return [[RES]]
  return %res : tensor<1x2xi32>
}

// Two-element start-index vector (tensor<2xi64>). The expected output extracts
// each element with an "mhlo.slice" (ranges [0, 1) and [1, 2)) and reshapes
// each one. The two values become the offsets for operand dimensions 0 and 1,
// and dimension 2 gets the zero constant.
// The directives for the constant, slices, reshapes and dynamic-slice are
// DAG-style, so the relative order in which they appear in the output is not
// constrained. Only the final return is ordered after them.
// NOTE(review): start_index_map below is [0], even though %arg1 carries two
// start indices and the expected output applies the second one to operand
// dimension 1. Per the XLA gather semantics, start_index_map normally has one
// entry per index-vector element (here [0, 1]). Confirm whether the rewrite
// intentionally ignores start_index_map before changing it, because the
// expected output below depends on the current behaviour.
// CHECK-LABEL: @gather_is_slice_multiple_start_indices
func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
  // CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>

  // CHECK: return [[RES]]
  return %res : tensor<1x2xi32>
}