1// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s
2
// Gather whose start index %arg1 is a rank-0 (scalar) tensor. The expected
// rewrite is a single mhlo.dynamic-slice that uses %arg1 directly as the dim-0
// start index and a zero constant for dims 1 and 2, followed by an
// mhlo.reshape of that slice whose result replaces the gather in the return.
3// CHECK-LABEL: @gather_is_slice_no_rank
4func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
5  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
6  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
7  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
  // slice_sizes [1, 1, 2] take one element along dim 0 and the full extent of
  // dims 1 and 2 of the 2x1x2 operand; collapsing dim 0 leaves the 1x2 result.
8   %res = "mhlo.gather"(%arg0, %arg1) {
9    dimension_numbers = {
10      collapsed_slice_dims = dense<0> : tensor<1xi64>,
11      index_vector_dim = 0 : i64,
12      offset_dims = dense<[0, 1]> : tensor<2xi64>,
13      start_index_map = dense<0> : tensor<1xi64>
14    },
15    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
16  } : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>
17
18  // CHECK: return [[RESHAPE]]
19  return %res : tensor<1x2xi32>
20}
21
// Gather whose start index %arg1 is a one-element rank-1 tensor. The expected
// rewrite first reshapes %arg1 (presumably to a scalar; the expected output
// does not pin the type), uses it as the dim-0 start index of an
// mhlo.dynamic-slice with zero starts for dims 1 and 2, and then reshapes the
// slice to produce the returned value.
22// CHECK-LABEL: @gather_is_slice
23func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
24   // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
25   // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
26   // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
27   // CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])
28
  // index_vector_dim = 0 with a tensor<1xi64> index: the single element of
  // %arg1 is the whole index vector, mapped to operand dim 0.
29   %res = "mhlo.gather"(%arg0, %arg1) {
30    dimension_numbers = {
31      collapsed_slice_dims = dense<0> : tensor<1xi64>,
32      index_vector_dim = 0 : i64,
33      offset_dims = dense<[0, 1]> : tensor<2xi64>,
34      start_index_map = dense<0> : tensor<1xi64>
35    },
36    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
37  } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
38
39  // CHECK: return [[RES]]
40  return %res : tensor<1x2xi32>
41}
42
// Gather whose start-index vector %arg1 holds two indices. The expected
// rewrite extracts each element with an mhlo.slice ([0:1] and [1:2]), reshapes
// each one, and uses them as the dim-0 and dim-1 start indices of an
// mhlo.dynamic-slice (dim 2 starts at a zero constant); that slice is then
// reshaped and returned. DAG matching is used so the test does not fix the
// relative order in which the rewritten ops are emitted.
43// CHECK-LABEL: @gather_is_slice_multiple_start_indices
44func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
45  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
46  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
47  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
48  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
49  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
50  // CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
51  // CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
  // NOTE(review): start_index_map below lists only dim 0, yet %arg1 carries
  // two indices and the expected dynamic-slice uses the second one for dim 1.
  // Under XLA gather semantics start_index_map normally has one entry per
  // index-vector element ([0, 1] here). Confirm this mismatch is intentional
  // (e.g. it mirrors what the rewrite currently accepts) before changing it.
52   %res = "mhlo.gather"(%arg0, %arg1) {
53    dimension_numbers = {
54      collapsed_slice_dims = dense<0> : tensor<1xi64>,
55      index_vector_dim = 0 : i64,
56      offset_dims = dense<[0, 1]> : tensor<2xi64>,
57      start_index_map = dense<0> : tensor<1xi64>
58    },
59    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
60  } : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>
61
62  // CHECK: return [[RES]]
63  return %res : tensor<1x2xi32>
64}
65