/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

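// Lowers a gather along `axis` (or an N-d gather when `indices_are_nd` is
// true) to a single xla::Gather by translating the TensorFlow gather
// semantics into XLA's GatherDimensionNumbers. The degenerate empty-indices
// case is lowered to a broadcasted zero of the expected output shape.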
Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
                 const xla::XlaOp& indices, const TensorShape& indices_shape,
                 int64 axis, bool indices_are_nd, DataType dtype,
                 DataType index_type, xla::XlaBuilder* builder,
                 xla::XlaOp* gather_output) {
  // There is no deep reason why we need this precondition, but this is the only
  // combination that is used and tested today.
  CHECK(!indices_are_nd || axis == 0);

  // num_index_dims is the number of components in each index in the indices
  // tensor.
  //
  // num_indices is the total number of (n dimensional or scalar) indices in the
  // indices tensor.
  //
  // If the indices are N-dimensional, then the minor dimension of indices
  // should be of size N and correspond to the N indices.
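  //
  // For example (a hypothetical shape, for illustration): with
  // indices_are_nd = true and indices of shape [4, 5, 2], each index is a
  // 2-component coordinate, so num_index_dims = 2 and num_indices = 4 * 5 = 20.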
  int64 num_index_dims;
  int64 num_indices = 1;
  if (indices_are_nd) {
    CHECK_GE(indices_shape.dims(), 1);
    num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
    for (int64 i = 0, e = indices_shape.dims() - 1; i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  } else {
    num_index_dims = 1;
    for (int64 i = 0, e = indices_shape.dims(); i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  }

  // Degenerate case: empty indices.
  if (num_indices == 0) {
    TensorShape input_shape_pre_axis{input_shape};
    input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
    TensorShape input_shape_post_axis{input_shape};
    input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);

    TensorShape indices_shape_no_index_vectors{indices_shape};
    if (indices_are_nd) {
      indices_shape_no_index_vectors.RemoveLastDims(1);
    }

    TensorShape out_shape;
    out_shape.AppendShape(input_shape_pre_axis);
    out_shape.AppendShape(indices_shape_no_index_vectors);
    out_shape.AppendShape(input_shape_post_axis);

    *gather_output =
        xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes());
    return Status::OK();
  }

  for (int64 i = 0; i < num_index_dims; ++i) {
    if (input_shape.dim_size(axis + i) == 0) {
      return errors::InvalidArgument("Gather dimension ", axis + i,
                                     " is of size zero in tensor with shape ",
                                     input_shape.DebugString());
    }
  }

  // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a
  // tensor of shape [3,3].
  //
  //  operand = s32[3,3] parameter(0)
  //  indices = s32[2] parameter(1)
  //  gather = s32[3,2] gather(operand, indices),
  //       offset_dims={0},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3, 1}
  //
  // Example of an N-D gather pulling slices of shape [1,1,2] out of a
  // tensor of shape [3,3,2].
  //
  //  operand = s32[3,3,2] parameter(0)
  //  indices = s32[2,2] parameter(1)
  //  gather = s32[2,2] gather(operand, indices),
  //       offset_dims={1},
  //       collapsed_slice_dims={0,1},
  //       start_index_map={0,1},
  //       index_vector_dim=1,
  //       slice_sizes={1,1,2}

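  // Build the gather's dimension numbers and slice sizes: each dimension in
  // [axis, axis + num_index_dims) is indexed into (slice size 1, collapsed
  // away); every other dimension is carried through whole as an offset
  // dimension, shifted past the batch dimensions contributed by `indices`.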
  xla::GatherDimensionNumbers dim_numbers;
  std::vector<int64> slice_sizes;
  slice_sizes.reserve(input_shape.dims());
  for (int64 i = 0; i < input_shape.dims(); i++) {
    int64 window_bound;
    if (axis <= i && i < (axis + num_index_dims)) {
      dim_numbers.add_collapsed_slice_dims(i);
      window_bound = 1;
    } else {
      window_bound = input_shape.dim_size(i);
    }

    slice_sizes.push_back(window_bound);

    if (i < axis) {
      dim_numbers.add_offset_dims(i);
    } else if (i >= (axis + num_index_dims)) {
      int64 indices_rank =
          indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
      dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
    }
  }

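  // For N-d indices the index vector lives in the last dimension of
  // `indices`; for scalar indices, index_vector_dim is set to rank(indices),
  // which tells XLA to treat every element of `indices` as a standalone index.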
  dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
                                                  : indices_shape.dims());
  for (int64 i = axis; i < axis + num_index_dims; i++) {
    dim_numbers.add_start_index_map(i);
  }

  *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
  return Status::OK();
}

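// Shared lowering for the Gather and GatherV2 kernels: validates the optional
// constant `axis` input (present on GatherV2) and the `batch_dims` attribute,
// then lowers either to xla::TorchIndexSelect (when there are leading batch
// dimensions) or to XlaGather.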
Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
                                    const xla::XlaOp input,
                                    const TensorShape& input_shape,
                                    int batch_dims, xla::XlaOp* gather_output) {
  auto indices = context->Input(1);
  auto indices_shape = context->InputShape(1);

  absl::optional<int64> axis;
  if (context->num_inputs() == 3) {
    const TensorShape axis_shape = context->InputShape(2);
    if (!TensorShapeUtils::IsScalar(axis_shape)) {
      return errors::InvalidArgument("axis must be scalar");
    }
    DataType axis_type = context->input_type(2);
    if (axis_type != DT_INT32 && axis_type != DT_INT64) {
      return errors::InvalidArgument("axis must be int32 or int64");
    }

    int64 axis_input;
    TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));

    const auto params_dims = input_shape.dims();
    if (-params_dims > axis_input || axis_input >= params_dims) {
      return errors::InvalidArgument("Expected axis in the range [",
                                     -params_dims, ", ", params_dims,
                                     "), but got ", axis_input);
    }
    if (axis_input < 0) {
      axis_input += params_dims;
    }
    axis = axis_input;
  }

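  // Normalize a negative batch_dims, default `axis` to batch_dims when the op
  // carries no explicit axis input, and validate batch_dims against the ranks
  // of `indices` and `input`.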
  if (batch_dims != 0) {
    if (batch_dims < 0) {
      batch_dims = indices_shape.dims() + batch_dims;
    }

    axis = axis.value_or(batch_dims);

    if (batch_dims < -indices_shape.dims() ||
        batch_dims > indices_shape.dims()) {
      return errors::InvalidArgument(
          "Expected batch_dims in the range [", -indices_shape.dims(), ", ",
          indices_shape.dims(), "], but got ", batch_dims);
    }

    if (batch_dims >= input_shape.dims()) {
      return errors::InvalidArgument("batch_dims (", batch_dims,
                                     ") must be less than rank(input) (",
                                     input_shape.dims(), ").");
    }

    if (*axis < batch_dims) {
      return errors::InvalidArgument("batch_dims (", batch_dims,
                                     ") must be less than or equal to ",
                                     "axis (", *axis, ").");
    }
  }

  axis = axis.value_or(0);
  DataType index_type = context->input_type(1);
  if (index_type != DT_INT32 && index_type != DT_INT64) {
    return errors::InvalidArgument("indices must be int32 or int64");
  }

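  // With leading batch dimensions, xla::TorchIndexSelect performs the batched
  // gather along `axis` directly.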
  if (batch_dims > 0) {
    *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
  } else {
    // XlaGather() handles degenerate cases, like empty indices, which are
    // error conditions caught above when batch_dims is not 0.
    TF_RETURN_IF_ERROR(
        XlaGather(input, input_shape, indices, indices_shape, *axis,
                  /*indices_are_nd=*/false, context->expected_output_dtype(0),
                  index_type, context->builder(), gather_output));
  }
  return Status::OK();
}
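
// Lowers the TensorFlow Gather and GatherV2 ops via
// XlaGatherWithBatchDimsOpImpl.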
class GatherOp : public XlaOpKernel {
 public:
  explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
    // Set batch_dims_ to 0 if the attribute does not exist.
    if (context->HasAttr("batch_dims")) {
      OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
    } else {
      batch_dims_ = 0;
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    auto input = context->Input(0);
    auto input_shape = context->InputShape(0);

    xla::XlaOp gather;
    OP_REQUIRES_OK(context,
                   XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
                                                batch_dims_, &gather));
    context->SetOutput(0, gather);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);

  // The number of batch dimensions, as passed in the batch_dims attribute.
  // It must be less than or equal to rank(indices).
  int32 batch_dims_ = 0;
};

REGISTER_XLA_OP(Name("Gather"), GatherOp);
REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstantInput("axis"), GatherOp);

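// Lowers GatherNd: the minor dimension of `indices` holds coordinates into
// the leading dimensions of `params`, so this is XlaGather with axis 0 and
// indices_are_nd set.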
class GatherNdOp : public XlaOpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    DataType params_type = context->input_type(0);
    DataType indices_type = context->input_type(1);

    TensorShape params_shape = context->InputShape(0);
    TensorShape indices_shape = context->InputShape(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape),
                errors::InvalidArgument("params must be at least a vector"));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
                errors::InvalidArgument("indices must be at least a vector"));
    const int64 num_index_dims =
        indices_shape.dim_size(indices_shape.dims() - 1);
    OP_REQUIRES(
        context, num_index_dims <= params_shape.dims(),
        errors::InvalidArgument(
            "index innermost dimension length must be <= params rank; saw: ",
            num_index_dims, " vs. ", params_shape.dims()));

    xla::XlaBuilder* builder = context->builder();
    auto params = context->Input(0);
    auto indices = context->Input(1);
    xla::XlaOp gather;
    OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
                                      indices_shape, /*axis=*/0,
                                      /*indices_are_nd=*/true, params_type,
                                      indices_type, builder, &gather));
    context->SetOutput(0, gather);
  }
};

REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp);

}  // namespace tensorflow