/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {

// Helper shape function for operators that return an output with the same rank
// as their first input.
Status UnchangedRank(shape_inference::InferenceContext* c) {
  if (c->RankKnown(c->input(0))) {
    c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
  } else {
    c->set_output(0, c->input(0));
  }
  return Status::OK();
}

REGISTER_OP("XlaBroadcastHelper")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("broadcast_dims: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Output("lhs_output: T")
    .Output("rhs_output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Helper operator for performing XLA-style broadcasts

Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
for binary operators.
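
For example, if `lhs` has shape [2, 3, 4] and `rhs` has shape [3], passing
broadcast_dims = [1] maps the single `rhs` dimension to output dimension 1, so
`rhs_output` is reshaped to [1, 3, 1] while `lhs_output` is unchanged.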

lhs: the LHS input tensor
rhs: the RHS input tensor
broadcast_dims: an XLA-style broadcast dimension specification
lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");

REGISTER_OP("XlaSelfAdjointEig")
    .Input("a: T")
    .Attr("lower: bool")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Output("w: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).

Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
tensor such that tensor[...,:,:] * v[..., :,i] = w[..., i] * v[...,:,i], for
i=0...N-1.

a: the input tensor.

lower: a boolean that specifies whether the calculation is done with the lower
triangular part or the upper triangular part.

max_iter: maximum number of sweep updates, i.e., passes over the whole lower or
upper triangular part selected by the `lower` parameter. Heuristically, it has
been argued that approximately log(N) sweeps are needed in practice (Ref: Golub
& van Loan, "Matrix Computations").

epsilon: the tolerance ratio.

w: The eigenvalues in ascending order, each repeated according to its
multiplicity.
v: The column v[..., :, i] is the normalized eigenvector corresponding to the
eigenvalue w[..., i].
)doc");

REGISTER_OP("XlaSvd")
    .Input("a: T")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Attr("precision_config: string")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the singular value decomposition of a batch of matrices
(Note: Only real inputs are supported).

Computes the singular values and the left and right singular vectors of the
innermost M-by-N matrices in tensor such that
tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).

a: the input tensor.

max_iter: maximum number of sweep updates. Heuristically, it has been argued
that approximately log(min(M, N)) sweeps are needed in practice (Ref: Golub &
van Loan, "Matrix Computations").

epsilon: the tolerance ratio.

precision_config: a serialized xla::PrecisionConfig proto.

s: Singular values. The values are sorted in descending order of magnitude, so
s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
u: Left singular vectors.
v: Right singular vectors.
)doc");

REGISTER_OP("XlaConv")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("lhs_dilation: Tindices")
    .Input("rhs_dilation: Tindices")
    .Input("feature_group_count: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.

lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

REGISTER_OP("XlaDot")
    .Input("lhs: T")
    .Input("rhs: T")
    .Attr("T: numbertype")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
      shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
      if (!c->FullyDefined(lhs_shape_handle) ||
          !c->FullyDefined(rhs_shape_handle)) {
        return shape_inference::UnknownShape(c);
      }

      string dimension_numbers_string;
      TF_RETURN_IF_ERROR(
          c->GetAttr("dimension_numbers", &dimension_numbers_string));

      xla::DotDimensionNumbers dimension_numbers;
      dimension_numbers.ParseFromString(dimension_numbers_string);

      // Check that the numbers of contracting dimensions match.
      if (dimension_numbers.lhs_contracting_dimensions_size() !=
          dimension_numbers.rhs_contracting_dimensions_size())
        return errors::InvalidArgument(
            "Must specify the same number of contracting dimensions for lhs "
            "and rhs. Got: ",
            dimension_numbers.lhs_contracting_dimensions_size(), " and ",
            dimension_numbers.rhs_contracting_dimensions_size());

      // Check that the contracting dimension sizes match.
      for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
           ++i) {
        const int64 lhs_contracting_dimension =
            dimension_numbers.lhs_contracting_dimensions(i);
        const int64 rhs_contracting_dimension =
            dimension_numbers.rhs_contracting_dimensions(i);
        shape_inference::DimensionOrConstant
            lhs_contracting_dimension_or_constant(
                c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
        shape_inference::DimensionOrConstant
            rhs_contracting_dimension_or_constant(
                c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
        const int64 lhs_contracting_dimension_size =
            c->Value(lhs_contracting_dimension_or_constant);
        const int64 rhs_contracting_dimension_size =
            c->Value(rhs_contracting_dimension_or_constant);
        if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
          return errors::InvalidArgument(
              "Contracting dimension sizes do not match. Got: ",
              lhs_contracting_dimension_size, " and ",
              rhs_contracting_dimension_size);
        }
      }

      // Check that the numbers of batch dimensions match.
      if (dimension_numbers.lhs_batch_dimensions_size() !=
          dimension_numbers.rhs_batch_dimensions_size())
        return errors::InvalidArgument(
            "Must specify the same number of batch dimensions for lhs "
            "and rhs. Got: ",
            dimension_numbers.lhs_batch_dimensions_size(), " and ",
            dimension_numbers.rhs_batch_dimensions_size());

      // Check that the batch dimension sizes match.
      for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size();
           ++i) {
        const int64 lhs_batch_dimension =
            dimension_numbers.lhs_batch_dimensions(i);
        const int64 rhs_batch_dimension =
            dimension_numbers.rhs_batch_dimensions(i);
        shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
            c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
        shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
            c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
        const int64 lhs_batch_dimension_size =
            c->Value(lhs_batch_dimension_or_constant);
        const int64 rhs_batch_dimension_size =
            c->Value(rhs_batch_dimension_or_constant);
        if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
          return errors::InvalidArgument(
              "Batch dimension sizes do not match. Got: ",
              lhs_batch_dimension_size, " and ", rhs_batch_dimension_size);
        }
      }

      // The rank of the result is the number of batch dimensions plus the
      // non-contracted, non-batch dimensions of lhs and rhs; a scalar input
      // contributes 0 to the rank of the result. Generate the result
      // dimensions in order: the batch dimensions, then the lhs dimensions
      // excluding the contracted and batch dimensions, then the rhs
      // dimensions excluding the contracted and batch dimensions.
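      // For example, with lhs of shape [B, M, K] and rhs of shape [B, K, N],
      // batch dimensions {0} on both sides, and contracting dimensions {2}
      // for lhs and {1} for rhs, the inferred output shape is [B, M, N].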
      std::vector<shape_inference::DimensionHandle> output_dims;
      for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
        output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
      }
      const int32 lhs_rank = c->Rank(lhs_shape_handle);
      for (int64 i = 0; i < lhs_rank; ++i) {
        if (absl::c_linear_search(
                dimension_numbers.lhs_contracting_dimensions(), i) ||
            absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
                                  i)) {
          continue;
        }
        output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
      }

      const int32 rhs_rank = c->Rank(rhs_shape_handle);
      for (int64 i = 0; i < rhs_rank; ++i) {
        if (absl::c_linear_search(
                dimension_numbers.rhs_contracting_dimensions(), i) ||
            absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
                                  i)) {
          continue;
        }
        output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
      }

      c->set_output(0, c->MakeShape(output_dims));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.

lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

REGISTER_OP("XlaSetBound")
    .Input("input: int32")
    .Input("bound: int32")
    .Output("output: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Sets a bound for the given input value as a hint to the XLA
compiler and returns the same value.
)doc");

REGISTER_OP("XlaSetDynamicDimensionSize")
    .Input("input: T")
    .Input("dim_index: int32")
    .Input("size: int32")
    .Output("output: T")
    .Attr("T: type")
    // Use unknown shape to prevent constant folding.
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Makes a static dimension into an XLA bounded dynamic dimension.
The current static dimension size will become the bound and the second
operand becomes the dynamic size of the dimension.)doc");

REGISTER_OP("XlaDynamicSlice")
    .Input("input: T")
    .Input("start_indices: Tindices")
    .Input("size_indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Wraps the XLA DynamicSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
.

DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices, which specify the end point of exclusive slice intervals in each
dimension -- [start, start + size). The shape of start_indices must have rank 1,
with dimension size equal to the rank of operand.
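
For example, for an input of shape [4, 6], start_indices = [1, 2] and
size_indices = [2, 3] produce the slice input[1:3, 2:5], which has shape [2, 3].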

input: A `Tensor` of type T.

start_indices: Rank 1 tensor of N integers containing the starting indices of
the slice for each dimension. Each value must be greater than or equal to zero.

size_indices: Rank 1 tensor of N integers containing the slice size for each
dimension. Each value must be strictly greater than zero, and start + size
must be less than or equal to the size of the dimension to avoid
implementation-defined behavior.
)doc");

REGISTER_OP("XlaDynamicUpdateSlice")
    .Input("input: T")
    .Input("update: T")
    .Input("indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.

XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. The shape
of `indices` must have rank 1, with dimension size equal to the rank of `input`.
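
For example, updating an `input` of shape [4, 6] with an `update` of shape
[2, 3] at `indices` = [1, 2] overwrites input[1:3, 2:5] with the values of
`update`.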

Handling of out-of-bounds slice indices is implementation-defined.

input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank of
`input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");

// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
    .Input("cond: Tcond")
    .Input("inputs: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).

cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
else_branch(inputs). The input shapes of the then_branch and
else_branch must match.
then_branch: A function that takes 'inputs' and returns a list of tensors
whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of tensors
whose types are the same as what then_branch returns.
)doc");

REGISTER_OP("XlaPad")
    .Input("input: T")
    .Input("padding_value: T")
    .Input("padding_low: Tindices")
    .Input("padding_high: Tindices")
    .Input("padding_interior: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle input_shape_handle = c->input(0);
      if (!c->FullyDefined(input_shape_handle)) {
        return UnchangedRank(c);
      }
      const int32 op_rank = c->Rank(input_shape_handle);

      shape_inference::ShapeHandle padding_shape_handle = c->input(1);
      if (!c->RankKnown(padding_shape_handle) ||
          c->Rank(padding_shape_handle) != 0) {
        return errors::InvalidArgument(
            "padding_value input must be scalar, found rank ",
            c->Rank(padding_shape_handle));
      }
      const Tensor* padding_low_tensor = c->input_tensor(2);
      const Tensor* padding_high_tensor = c->input_tensor(3);
      const Tensor* padding_interior_tensor = c->input_tensor(4);
      if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
          padding_interior_tensor == nullptr) {
        return UnchangedRank(c);
      }

      if (padding_low_tensor->shape().dims() != 1 ||
          padding_low_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_low must be a 1D tensor of size ", op_rank);
      }
      if (padding_high_tensor->shape().dims() != 1 ||
          padding_high_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_high must be a 1D tensor of size ", op_rank);
      }
      if (padding_interior_tensor->shape().dims() != 1 ||
          padding_interior_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_interior must be a 1D tensor of size ", op_rank);
      }
      std::vector<shape_inference::DimensionHandle> output_dims;
      output_dims.reserve(op_rank);
      for (int64 i = 0; i < op_rank; ++i) {
        int64 low, high, interior;
        TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_high_tensor, i, &high));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
        if (interior < 0) {
          return errors::InvalidArgument(
              "padding_interior must contain only non-negative values, found ",
              interior);
        }

        shape_inference::DimensionHandle orig_size_handle =
            c->Dim(input_shape_handle, i);
        if (c->ValueKnown(orig_size_handle)) {
          auto orig_dim = c->Value(orig_size_handle);
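          // The padded size is orig_dim + low + high, plus `interior` elements
          // between each pair of original elements; e.g. orig_dim = 3, low = 1,
          // high = 2, interior = 1 gives 3 + 1 + 2 + 1 * (3 - 1) = 8.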
          int64 new_dim = orig_dim + low + high;
          if (orig_dim > 0) {
            new_dim += interior * (orig_dim - 1);
          }
          if (new_dim < 0) {
            return errors::InvalidArgument(
                "resulting padded dimension has negative size ", new_dim);
          }
          output_dims.emplace_back(c->MakeDim(new_dim));
        } else {
          output_dims.emplace_back(c->UnknownDim());
        }
      }

      c->set_output(0, c->MakeShape(output_dims));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Pad operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#pad
.

input: A `Tensor` of type T.
padding_value: A scalar `Tensor` of type T.
padding_low: the padding to apply at the start of each input dimension. Must
be a compile-time constant 1D tensor of length equal to rank of input.
padding_high: the padding to apply at the end of each input dimension. Must
be a compile-time constant 1D tensor of length equal to rank of input.
padding_interior: the padding to apply between each input element. Must
be a compile-time constant 1D tensor of length equal to rank of input,
containing only non-negative values.
output: A `Tensor` of type T.
)doc");

REGISTER_OP("XlaRecv")
    .Output("tensor: dtype")
    .Attr("dtype: type")
    .Attr("tensor_name: string")
    .Attr("shape: shape")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
https://www.tensorflow.org/performance/xla/operation_semantics#recv .

tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");

REGISTER_OP("XlaReduce")
    .Input("input: T")
    .Input("init_value: T")
    .Attr("T: numbertype")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64> dims_set(dimensions_to_reduce.begin(),
                                 dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64 dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaReduce");
        }
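        // The reduced dimensions are dropped from the output shape, e.g. a
        // rank-4 input reduced over dimensions {1, 3} yields a rank-2 output.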
        c->set_output(
            0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
      } else {
        c->set_output(0, c->input(0));
      }
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Reduce operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reduce .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaVariadicReduce")
    .Input("input: N * T")
    .Input("init_value: N * T")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: N * T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int n;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
          c->MergeInput(i, c->input(j));
        }
      }
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64> dims_set(dimensions_to_reduce.begin(),
                                 dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64 dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaVariadicReduce");
        }
        for (int i = 0; i < n; i++) {
          c->set_output(
              i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
        }
      } else {
        for (int i = 0; i < n; i++) {
          c->set_output(i, c->input(i));
        }
      }
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the variadic XLA Reduce operator.

Semantics are documented at
https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.

input: the input tensor(s)
init_value: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaReduceWindow")
    .Input("input: T")
    .Input("init_value: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("base_dilations: Tindices")
    .Input("window_dilations: Tindices")
    .Input("padding: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("computation: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
)doc");

REGISTER_OP("XlaSelectAndScatter")
    .Input("operand: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("source: T")
    .Input("init_value: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("select: func")
    .Attr("scatter: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA SelectAndScatter operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
.

operand: the input tensor
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
source: a tensor of values to scatter
init_value: a scalar representing the initial value for the output tensor
select: a selection function to apply
scatter: a scatter function to apply
)doc");

REGISTER_OP("XlaSend")
    .Input("tensor: T")
    .Attr("T: type")
    .Attr("tensor_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
https://www.tensorflow.org/performance/xla/operation_semantics#send .

tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");

REGISTER_OP("XlaSort")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorting in ascending order is supported.

input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");

REGISTER_OP("XlaKeyValueSort")
    .Input("keys: K")
    .Input("values: V")
    .Output("sorted_keys: K")
    .Output("sorted_values: V")
    .Attr("K: realnumbertype")
    .Attr("V: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      c->set_output(1, c->input(1));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorting in ascending order is supported.

keys: A `Tensor` of type K.
values: A `Tensor` of type V.
sorted_keys: A `Tensor` of type K.
sorted_values: A `Tensor` of type V.
)doc");

REGISTER_OP("XlaVariadicSort")
    .Input("inputs: T")
    .Input("dimension: int32")
    .Output("outputs: T")
    .Attr("T: list(type) >= 1")
    .Attr("comparator: func")
    .Attr("is_stable: bool")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<shape_inference::ShapeHandle> input_shapes;
      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
      TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts one or more tensors, with support for custom comparator, dimension, and
is_stable attributes.

inputs: A list of `Tensor` of identical shape but possibly different types.
dimension: The dimension along which to sort. Must be a compile-time constant.
is_stable: Whether to use stable sort.
comparator: A comparator function that takes 2*N scalars and returns a
boolean, where N is the number of sort inputs. To sort in ascending order,
the comparator should perform a less-than comparison.
outputs: A list of `Tensor` with the same shapes and types as `inputs`.
)doc");

// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor. If the tensor is
a non-boolean scalar, it is converted to a boolean according to the
following rule: if the scalar is a numerical value, non-zero means True
and zero means False; if the scalar is a string, non-empty means True and
empty means False. If the tensor is not a scalar, non-emptiness means True
and emptiness means False.
body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");

REGISTER_OP("XlaDequantize")
    .Input("input: uint32")
    .Output("output: bfloat16")
    .Attr("min_range: float")
    .Attr("max_range: float")
    .Attr("mode: string")
    .Attr("transpose_output: bool")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Takes the packed uint32 input and unpacks the input into uint8 values to do
dequantization on the device.

input: Input tensor whose type is uint32 and whose shape is [d0, ..., dn].
output: Output tensor whose type is bfloat16. If transpose_output is true, the
output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output is false, the
output shape is [d0, ..., dn * 4].
min_range: The minimum scalar value possibly produced for the input.
max_range: The maximum scalar value possibly produced for the input.
mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
transpose_output: Boolean to determine if output is transposed. transpose_output
is faster when input is large and the rank of the input is higher than 1.
)doc");

REGISTER_OP("XlaEinsum")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("equation: string")
    .Attr("T: {complex64, bfloat16, float}")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      string equation;
      TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
      // XlaEinsum supports only two-input einsum equations.
      if (!absl::StrContains(equation, ",")) {
        return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
                                       equation);
      }
      // Use EinsumShape for the rest of the inference now that we know we must
      // have a two-input einsum.
      return shape_inference::EinsumShape(context);
    })
    .Doc(R"doc(
An op which supports a basic einsum operation with 2 inputs and 1 output.

This op has better TPU performance since it doesn't have explicit reshape and
transpose operations as tf.einsum does.
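
For example, the equation "ab,bc->ac" computes the matrix product of `a` and
`b`.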
)doc");

REGISTER_OP("XlaSpmdFullToShardShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      auto input_handle = c->input(0);
      if (!c->RankKnown(input_handle)) {
        return shape_inference::UnknownShape(c);
      }
      string sharding_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
      xla::OpSharding sharding;
      sharding.ParseFromString(sharding_attr);
      if (sharding.type() != xla::OpSharding::OTHER) {
        return shape_inference::UnchangedShape(c);
      }
      std::vector<shape_inference::DimensionHandle> dims;
      for (int64 i = 0; i < c->Rank(input_handle); ++i) {
        auto dim = c->Value(c->Dim(input_handle, i));
        int64 partitions_i = sharding.tile_assignment_dimensions(i);
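        // Each known sharded dimension becomes ceil(dim / partitions), e.g. a
        // dimension of size 10 split 4 ways yields a shard size of 3.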
        if (dim != shape_inference::InferenceContext::kUnknownDim &&
            partitions_i != 1) {
          dim = (dim + partitions_i - 1) / partitions_i;
        }
        dims.push_back(c->MakeDim(dim));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from automatic partitioning to
manual partitioning. It annotates the input (full-shape, to be automatically
partitioned) with the same sharding used by manual partitioning, and outputs a
shard-shaped tensor to be consumed by later manually-partitioned ops. If the
shape is not evenly partitionable, the padding region will be masked with 0s.
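
For example, a full-shape input of shape [10, 8] partitioned 4 ways along
dimension 0 produces a shard-shaped output of shape [3, 8], with the last shard
padded.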
)doc");

REGISTER_OP("XlaSpmdShardToFullShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .Attr("full_shape: shape")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from manual partitioning to
automatic partitioning. It converts the shard-shaped, manually partitioned input
into a full-shaped tensor to be partitioned automatically with the same sharding
used by manual partitioning.
)doc");

REGISTER_OP("XlaSharding")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("sharding: string = ''")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
An op which shards the input based on the given sharding attribute.
)doc");

REGISTER_OP("XlaReplicaId")
    .Output("id: int32")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      context->set_output(0, context->MakeShape({}));
      return Status::OK();
    })
    .Doc("Replica ID.");

REGISTER_OP("XlaGather")
    .Input("operand: T")
    .Input("start_indices: Tindices")
    .Input("slice_sizes: Tindices")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Wraps the XLA Gather operator documented at
https://www.tensorflow.org/xla/operation_semantics#gather

operand: The array we're gathering from.
start_indices: Array containing the starting indices of the slices we gather.
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

REGISTER_OP("XlaScatter")
    .Input("operand: T")
    .Input("scatter_indices: Tindices")
    .Input("updates: T")
    .Attr("update_computation: func")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Scatter operator documented at
https://www.tensorflow.org/xla/operation_semantics#scatter.

operand: Array to be scattered into.
scatter_indices: Array containing the starting indices of the slices that must
be scattered to.
updates: Array containing the values that must be used for scattering.
update_computation: Computation to be used for combining the existing values in
the input array and the updates during scatter.
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

}  // namespace
}  // namespace tensorflow