1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17
18 #include "absl/types/optional.h"
19 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
20 #include "tensorflow/compiler/tf2xla/shape_util.h"
21 #include "tensorflow/compiler/tf2xla/type_util.h"
22 #include "tensorflow/compiler/tf2xla/xla_context.h"
23 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/client/lib/slicing.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/core/framework/kernel_def_builder.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/lib/core/errors.h"
32
33 namespace tensorflow {
34
// Lowers a TensorFlow gather to an XLA Gather HLO.
//
// input/input_shape:     the operand tensor being gathered from.
// indices/indices_shape: the gather indices.
// axis:                  dimension of `input` that scalar indices index into;
//                        must be 0 when `indices_are_nd` (checked below).
// indices_are_nd:        if true, the minor-most dimension of `indices` holds
//                        N-dimensional index vectors (GatherNd semantics);
//                        otherwise every element of `indices` is a scalar
//                        index.
// dtype/index_type:      element types of the operand and the indices.
// gather_output:         receives the resulting XLA op.
Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
                 const xla::XlaOp& indices, const TensorShape& indices_shape,
                 int64 axis, bool indices_are_nd, DataType dtype,
                 DataType index_type, xla::XlaBuilder* builder,
                 xla::XlaOp* gather_output) {
  // There is no deep reason why we need this precondition, but this is the only
  // combination that is used and tested today.
  CHECK(!indices_are_nd || axis == 0);

  // num_index_dims is the number of components in each index in the indices
  // tensor.
  //
  // num_indices is the total number of (n dimensional or scalar) indices in the
  // indices tensor.
  //
  // If the indices are N-dimensional, then the minor dimension of indices
  // should be of size N and correspond to the N indices.
  int64 num_index_dims;
  int64 num_indices = 1;
  if (indices_are_nd) {
    CHECK_GE(indices_shape.dims(), 1);
    num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
    for (int64 i = 0, e = indices_shape.dims() - 1; i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  } else {
    num_index_dims = 1;
    for (int64 i = 0, e = indices_shape.dims(); i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  }

  // Degenerate case: empty indices.  Emit a zero-filled broadcast of the
  // statically-known output shape rather than a gather HLO.
  if (num_indices == 0) {
    // Output shape is:
    //   input dims before `axis` ++ indices dims (sans any index-vector dim)
    //   ++ input dims after the indexed dims.
    TensorShape input_shape_pre_axis{input_shape};
    input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
    TensorShape input_shape_post_axis{input_shape};
    input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);

    TensorShape indices_shape_no_index_vectors{indices_shape};
    if (indices_are_nd) {
      indices_shape_no_index_vectors.RemoveLastDims(1);
    }

    TensorShape out_shape;
    out_shape.AppendShape(input_shape_pre_axis);
    out_shape.AppendShape(indices_shape_no_index_vectors);
    out_shape.AppendShape(input_shape_post_axis);

    *gather_output =
        xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes());
    return Status::OK();
  }

  // With non-empty indices, every indexed dimension must be non-empty or the
  // gather would read out of bounds.
  for (int64 i = 0; i < num_index_dims; ++i) {
    if (input_shape.dim_size(axis + i) == 0) {
      return errors::InvalidArgument("Gather dimension ", axis + i,
                                     " is of size zero in tensor with shape ",
                                     input_shape.DebugString());
    }
  }

  // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a
  // tensor of shape [3,3].
  //
  //   operand = s32[3,3] parameter(0)
  //   indices = s32[2] parameter(1)
  //   gather = s32[3,2] gather(operand, indices),
  //       offset_dims={0},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3, 1}
  //
  //
  // Example of an N-D gather pulling out slices of shape [1,1,2] out of a
  // tensor of shape [3,3,2].
  //
  //   operand = s32[3,3,2] parameter(0)
  //   indices = s32[2,2] parameter(1)
  //   gather = s32[2,2] gather(operand, indices),
  //       offset_dims={1},
  //       collapsed_slice_dims={0,1},
  //       start_index_map={0,1},
  //       index_vector_dim=0,
  //       slice_sizes={1,1,2}

  xla::GatherDimensionNumbers dim_numbers;
  std::vector<int64> slice_sizes;
  slice_sizes.reserve(input_shape.dims());
  for (int64 i = 0; i < input_shape.dims(); i++) {
    int64 window_bound;
    if (axis <= i && i < (axis + num_index_dims)) {
      // Indexed dimensions are sliced to a single element and collapsed away.
      dim_numbers.add_collapsed_slice_dims(i);
      window_bound = 1;
    } else {
      // Non-indexed dimensions are carried through whole.
      window_bound = input_shape.dim_size(i);
    }

    slice_sizes.push_back(window_bound);

    // Map each surviving operand dimension to its position in the output,
    // which interleaves the indices dims at position `axis`.
    if (i < axis) {
      dim_numbers.add_offset_dims(i);
    } else if (i >= (axis + num_index_dims)) {
      int64 indices_rank =
          indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
      dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
    }
  }

  // For N-D indices the index vector lives in the minor-most dimension; for
  // scalar indices, index_vector_dim == rank(indices) tells XLA the indices
  // are implicitly scalar.
  dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
                                                  : indices_shape.dims());
  for (int64 i = axis; i < axis + num_index_dims; i++) {
    dim_numbers.add_start_index_map(i);
  }

  *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
  return Status::OK();
}
154
XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext * context,const xla::XlaOp input,const TensorShape & input_shape,int batch_dims,xla::XlaOp * gather_output)155 Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
156 const xla::XlaOp input,
157 const TensorShape& input_shape,
158 int batch_dims, xla::XlaOp* gather_output) {
159 auto indices = context->Input(1);
160 auto indices_shape = context->InputShape(1);
161
162 absl::optional<int64> axis;
163 if (context->num_inputs() == 3) {
164 const TensorShape axis_shape = context->InputShape(2);
165 if (!TensorShapeUtils::IsScalar(axis_shape)) {
166 return errors::InvalidArgument("axis must be scalar");
167 }
168 DataType axis_type = context->input_type(2);
169 if (axis_type != DT_INT32 && axis_type != DT_INT64) {
170 return errors::InvalidArgument("axis must be int32 or int64");
171 }
172
173 int64 axis_input;
174 TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));
175
176 const auto params_dims = input_shape.dims();
177 if (-params_dims > axis_input || axis_input >= params_dims) {
178 return errors::InvalidArgument("Expected axis in the range [",
179 -params_dims, ", ", params_dims,
180 "), but got ", axis_input);
181 }
182 if (axis_input < 0) {
183 axis_input += params_dims;
184 }
185 axis = axis_input;
186 }
187
188 if (batch_dims != 0) {
189 if (batch_dims < 0) {
190 batch_dims = indices_shape.dims() + batch_dims;
191 }
192
193 axis = axis.value_or(batch_dims);
194
195 if (batch_dims < -indices_shape.dims() ||
196 batch_dims > indices_shape.dims()) {
197 return errors::InvalidArgument(
198 "Expected batch_dims in the range [", -indices_shape.dims(), ", ",
199 indices_shape.dims(), "], but got ", batch_dims);
200 }
201
202 if (batch_dims >= input_shape.dims()) {
203 return errors::InvalidArgument("batch_dims (", batch_dims,
204 ") must be less than rank(input) (",
205 input_shape.dims(), ").");
206 }
207
208 if (*axis < batch_dims) {
209 return errors::InvalidArgument("batch_dims (", batch_dims,
210 ") must be less than or equal to ",
211 "axis (", *axis, ").");
212 }
213 }
214
215 axis = axis.value_or(0);
216 DataType index_type = context->input_type(1);
217 if (index_type != DT_INT32 && index_type != DT_INT64) {
218 return errors::InvalidArgument("indices must be int32 or int64");
219 }
220
221 xla::XlaOp gather;
222 if (batch_dims > 0) {
223 *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
224 } else {
225 // XlaGather() manages degenerate cases, like empty-indices, which are
226 // error conditions and caught above if batch_dims is not 0.
227 TF_RETURN_IF_ERROR(
228 XlaGather(input, input_shape, indices, indices_shape, *axis,
229 /*indices_are_nd=*/false, context->expected_output_dtype(0),
230 index_type, context->builder(), gather_output));
231 }
232 return Status::OK();
233 }
234 class GatherOp : public XlaOpKernel {
235 public:
GatherOp(OpKernelConstruction * context)236 explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
237 // Set batch_dims_ to 0 if the attribute does not exist.
238 if (context->HasAttr("batch_dims")) {
239 OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
240 } else {
241 batch_dims_ = 0;
242 }
243 }
244
Compile(XlaOpKernelContext * context)245 void Compile(XlaOpKernelContext* context) override {
246 auto input = context->Input(0);
247 auto input_shape = context->InputShape(0);
248
249 xla::XlaOp gather;
250 OP_REQUIRES_OK(context,
251 XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
252 batch_dims_, &gather));
253 context->SetOutput(0, gather);
254 }
255
256 private:
257 TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
258
259 // The number of batch dimensions, as passed in the batch_dims attribute.
260 // It must be less than or equal to rank(indices).
261 int32 batch_dims_ = 0;
262 };
263
// Gather has no axis input; GatherV2 takes axis as an extra input that must be
// a compile-time constant.
REGISTER_XLA_OP(Name("Gather"), GatherOp);
REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstantInput("axis"), GatherOp);
266
267 class GatherNdOp : public XlaOpKernel {
268 public:
GatherNdOp(OpKernelConstruction * context)269 explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
270
Compile(XlaOpKernelContext * context)271 void Compile(XlaOpKernelContext* context) override {
272 DataType params_type = context->input_type(0);
273 DataType indices_type = context->input_type(1);
274
275 TensorShape params_shape = context->InputShape(0);
276 TensorShape indices_shape = context->InputShape(1);
277 OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape),
278 errors::InvalidArgument("params must be at least a vector"));
279 OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
280 errors::InvalidArgument("indices must be at least a vector"));
281 const int64 num_index_dims =
282 indices_shape.dim_size(indices_shape.dims() - 1);
283 OP_REQUIRES(
284 context, num_index_dims <= params_shape.dims(),
285 errors::InvalidArgument(
286 "index innermost dimension length must be <= params rank; saw: ",
287 indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",
288 params_shape.dims()));
289
290 xla::XlaBuilder* builder = context->builder();
291 auto params = context->Input(0);
292 auto indices = context->Input(1);
293 xla::XlaOp gather;
294 OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
295 indices_shape, /*axis=*/0,
296 /*indices_are_nd=*/true, params_type,
297 indices_type, builder, &gather));
298 context->SetOutput(0, gather);
299 }
300 };
301
302 REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp);
303
304 } // namespace tensorflow
305