/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {

// Helper shape function for operators that return an output with the same rank
// as their first input.
Status UnchangedRank(shape_inference::InferenceContext* c) {
  if (c->RankKnown(c->input(0))) {
    c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
  } else {
    c->set_output(0, c->input(0));
  }
  return Status::OK();
}

REGISTER_OP("XlaBroadcastHelper")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("broadcast_dims: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Output("lhs_output: T")
    .Output("rhs_output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Helper operator for performing XLA-style broadcasts.

Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
for binary operators.
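
For example, if `lhs` has shape [2, 3] and `rhs` has shape [3], then
`broadcast_dims = [1]` (an illustrative value) reshapes `rhs` to shape [1, 3]
so that it can broadcast against `lhs`.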

lhs: the LHS input tensor
rhs: the RHS input tensor
broadcast_dims: an XLA-style broadcast dimension specification
lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");

REGISTER_OP("XlaSelfAdjointEig")
    .Input("a: T")
    .Attr("lower: bool")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Output("w: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).

Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
tensor such that tensor[...,:,:] * v[..., :,i] = w[..., i] * v[...,:,i], for
i=0...N-1.

a: the input tensor.

lower: a boolean that specifies whether the calculation is done with the lower
  triangular part or the upper triangular part.

max_iter: maximum number of sweep updates, i.e., passes over the whole lower or
  upper triangular part, as selected by the `lower` parameter. Heuristically,
  approximately log(N) sweeps are needed in practice (Ref: Golub & van Loan,
  "Matrix Computations").

epsilon: the tolerance ratio.

w: The eigenvalues in ascending order, each repeated according to its
  multiplicity.
v: The column v[..., :, i] is the normalized eigenvector corresponding to the
  eigenvalue w[..., i].
)doc");
99 
100 REGISTER_OP("XlaSvd")
101     .Input("a: T")
102     .Attr("max_iter: int")
103     .Attr("epsilon: float")
104     .Attr("precision_config: string")
105     .Output("s: T")
106     .Output("u: T")
107     .Output("v: T")
108     .SetShapeFn(shape_inference::UnknownShape)
109     .Attr("T: numbertype")
110     .Doc(R"doc(
111 Computes the eigen decomposition of a batch of self-adjoint matrices
112 (Note: Only real inputs are supported).
113 
114 Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
115 tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
116 
117 a: the input tensor.
118 
119 max_iter: maximum number of sweep update, i.e., the whole lower triangular
120   part or upper triangular part based on parameter lower. Heuristically, it has
121   been argued that approximately log(min (M, N)) sweeps are needed in practice
122   (Ref: Golub & van Loan "Matrix Computation").
123 
124 epsilon: the tolerance ratio.
125 
126 precision_config: a serialized xla::PrecisionConfig proto.
127 
128 s: Singular values. The values are sorted in reverse order of magnitude, so
129   s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
130 u: Left singular vectors.
131 v: Right singular vectors.
132 )doc");
133 
REGISTER_OP("XlaConv")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("lhs_dilation: Tindices")
    .Input("rhs_dilation: Tindices")
    .Input("feature_group_count: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.

lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

REGISTER_OP("XlaDot")
    .Input("lhs: T")
    .Input("rhs: T")
    .Attr("T: numbertype")
    .Attr("dimension_numbers: string")
    .Attr("precision_config: string")
    .Output("output: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
      shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
      if (!c->FullyDefined(lhs_shape_handle) ||
          !c->FullyDefined(rhs_shape_handle)) {
        return shape_inference::UnknownShape(c);
      }

      string dimension_numbers_string;
      TF_RETURN_IF_ERROR(
          c->GetAttr("dimension_numbers", &dimension_numbers_string));

      xla::DotDimensionNumbers dimension_numbers;
      dimension_numbers.ParseFromString(dimension_numbers_string);

      // Check that the number of contracting dimensions matches.
      if (dimension_numbers.lhs_contracting_dimensions_size() !=
          dimension_numbers.rhs_contracting_dimensions_size())
        return errors::InvalidArgument(
            "Must specify the same number of contracting dimensions for lhs "
            "and rhs. Got: ",
            dimension_numbers.lhs_contracting_dimensions_size(), " and ",
            dimension_numbers.rhs_contracting_dimensions_size());

      // Check that contracting dimension sizes match.
      for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
           ++i) {
        const int64 lhs_contracting_dimension =
            dimension_numbers.lhs_contracting_dimensions(i);
        const int64 rhs_contracting_dimension =
            dimension_numbers.rhs_contracting_dimensions(i);
        shape_inference::DimensionOrConstant
            lhs_contracting_dimension_or_constant(
                c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
        shape_inference::DimensionOrConstant
            rhs_contracting_dimension_or_constant(
                c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
        const int64 lhs_contracting_dimension_size =
            c->Value(lhs_contracting_dimension_or_constant);
        const int64 rhs_contracting_dimension_size =
            c->Value(rhs_contracting_dimension_or_constant);
        if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
          return errors::InvalidArgument(
              "Contracting dimension sizes do not match. Got: ",
              lhs_contracting_dimension_size, " and ",
              rhs_contracting_dimension_size);
        }
      }

      // Check that the number of batch dimensions matches.
      if (dimension_numbers.lhs_batch_dimensions_size() !=
          dimension_numbers.rhs_batch_dimensions_size())
        return errors::InvalidArgument(
            "Must specify the same number of batch dimensions for lhs "
            "and rhs. Got: ",
            dimension_numbers.lhs_batch_dimensions_size(), " and ",
            dimension_numbers.rhs_batch_dimensions_size());

      // Check that batch dimension sizes match.
      for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size();
           ++i) {
        const int64 lhs_batch_dimension =
            dimension_numbers.lhs_batch_dimensions(i);
        const int64 rhs_batch_dimension =
            dimension_numbers.rhs_batch_dimensions(i);
        shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
            c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
        shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
            c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
        const int64 lhs_batch_dimension_size =
            c->Value(lhs_batch_dimension_or_constant);
        const int64 rhs_batch_dimension_size =
            c->Value(rhs_batch_dimension_or_constant);
        if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
          return errors::InvalidArgument(
              "Batch dimension sizes do not match. Got: ",
              lhs_batch_dimension_size, " and ", rhs_batch_dimension_size);
        }
      }

      // Each contracting or batch dimension is removed from the operand on
      // which it appears; a scalar input contributes 0 dimensions to the
      // result. Generate the result dimensions in order: the batch dimensions,
      // then the remaining lhs dimensions, then the remaining rhs dimensions.
      std::vector<shape_inference::DimensionHandle> output_dims;
      for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
        output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
      }
      const int32 lhs_rank = c->Rank(lhs_shape_handle);
      for (int64 i = 0; i < lhs_rank; ++i) {
        if (absl::c_linear_search(
                dimension_numbers.lhs_contracting_dimensions(), i) ||
            absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
                                  i)) {
          continue;
        }
        output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
      }

      const int32 rhs_rank = c->Rank(rhs_shape_handle);
      for (int64 i = 0; i < rhs_rank; ++i) {
        if (absl::c_linear_search(
                dimension_numbers.rhs_contracting_dimensions(), i) ||
            absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
                                  i)) {
          continue;
        }
        output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
      }

      c->set_output(0, c->MakeShape(output_dims));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.

lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");

REGISTER_OP("XlaSetBound")
    .Input("input: int32")
    .Input("bound: int32")
    .Output("output: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Set a bound for the given input value as a hint to the XLA
        compiler; returns the same value.
)doc");

REGISTER_OP("XlaSetDynamicDimensionSize")
    .Input("input: T")
    .Input("dim_index: int32")
    .Input("size: int32")
    .Output("output: T")
    .Attr("T: type")
    // Use unknown shape to prevent constant folding.
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(
        R"doc(Make a static dimension into an XLA bounded dynamic dimension.
        The current static dimension size will become the bound, and the second
        operand becomes the dynamic size of the dimension.)doc");

REGISTER_OP("XlaDynamicSlice")
    .Input("input: T")
    .Input("start_indices: Tindices")
    .Input("size_indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Wraps the XLA DynamicSlice operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
.

DynamicSlice extracts a sub-array from the input array at dynamic
start_indices. The size of the slice in each dimension is passed in
size_indices; the slice in each dimension covers the half-open interval
[start, start + size). start_indices must have rank 1, with dimension size
equal to the rank of operand.
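
For example, if `input` is [[1, 2, 3], [4, 5, 6]], `start_indices` is [0, 1]
and `size_indices` is [2, 2], the output is [[2, 3], [5, 6]].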

input: A `Tensor` of type T.

start_indices: Rank 1 tensor of N integers containing the starting indices of
  the slice for each dimension. Each value must be greater than or equal to
  zero.

size_indices: List of N integers containing the slice size for each
  dimension. Each value must be strictly greater than zero, and start + size
  must be less than or equal to the size of the dimension to avoid
  implementation-defined behavior.
)doc");

REGISTER_OP("XlaDynamicUpdateSlice")
    .Input("input: T")
    .Input("update: T")
    .Input("indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.

XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. `indices`
must be a rank-1 tensor with dimension size equal to the rank of `input`.

Handling of out-of-bounds slice indices is implementation-defined.
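
For example, if `input` is [[1, 2, 3], [4, 5, 6]], `update` is [[9]] and
`indices` is [1, 2], the output is [[1, 2, 3], [4, 5, 9]].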

input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank of
  `input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");

// TODO(b/37549631) setting the If Op to always be stateful is too
// conservative.
REGISTER_OP("XlaIf")
    .Input("cond: Tcond")
    .Input("inputs: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = cond ? then_branch(inputs) : else_branch(inputs).

cond: A boolean scalar.
inputs: A list of input tensors.
output: A list of tensors returned by either then_branch(inputs) or
        else_branch(inputs). The input shapes of the then_branch and
        else_branch must match.
then_branch: A function that takes 'inputs' and returns a list of tensors,
             whose types are the same as what else_branch returns.
else_branch: A function that takes 'inputs' and returns a list of tensors,
             whose types are the same as what then_branch returns.
)doc");

REGISTER_OP("XlaPad")
    .Input("input: T")
    .Input("padding_value: T")
    .Input("padding_low: Tindices")
    .Input("padding_high: Tindices")
    .Input("padding_interior: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle input_shape_handle = c->input(0);
      if (!c->FullyDefined(input_shape_handle)) {
        return UnchangedRank(c);
      }
      const int32 op_rank = c->Rank(input_shape_handle);

      shape_inference::ShapeHandle padding_shape_handle = c->input(1);
      if (!c->RankKnown(padding_shape_handle) ||
          c->Rank(padding_shape_handle) != 0) {
        return errors::InvalidArgument(
            "padding_value input must be scalar, found rank ",
            c->Rank(padding_shape_handle));
      }
      const Tensor* padding_low_tensor = c->input_tensor(2);
      const Tensor* padding_high_tensor = c->input_tensor(3);
      const Tensor* padding_interior_tensor = c->input_tensor(4);
      if (padding_low_tensor == nullptr || padding_high_tensor == nullptr ||
          padding_interior_tensor == nullptr) {
        return UnchangedRank(c);
      }

      if (padding_low_tensor->shape().dims() != 1 ||
          padding_low_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_low must be a 1D tensor of size ", op_rank);
      }
      if (padding_high_tensor->shape().dims() != 1 ||
          padding_high_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_high must be a 1D tensor of size ", op_rank);
      }
      if (padding_interior_tensor->shape().dims() != 1 ||
          padding_interior_tensor->shape().dim_size(0) != op_rank) {
        return errors::InvalidArgument(
            "padding_interior must be a 1D tensor of size ", op_rank);
      }
      std::vector<shape_inference::DimensionHandle> output_dims;
      output_dims.reserve(op_rank);
      for (int64 i = 0; i < op_rank; ++i) {
        int64 low, high, interior;
        TF_RETURN_IF_ERROR(c->GetScalarFromTensor(padding_low_tensor, i, &low));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_high_tensor, i, &high));
        TF_RETURN_IF_ERROR(
            c->GetScalarFromTensor(padding_interior_tensor, i, &interior));
        if (interior < 0) {
          return errors::InvalidArgument(
              "padding_interior must contain only non-negative values, found ",
              interior);
        }

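        // The padded size of dimension i is the original size plus `low` and
        // `high` edge padding, plus `interior` padding elements between each
        // pair of adjacent original elements.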
        shape_inference::DimensionHandle orig_size_handle =
            c->Dim(input_shape_handle, i);
        if (c->ValueKnown(orig_size_handle)) {
          auto orig_dim = c->Value(orig_size_handle);
          int64 new_dim = orig_dim + low + high;
          if (orig_dim > 0) {
            new_dim += interior * (orig_dim - 1);
          }
          if (new_dim < 0) {
            return errors::InvalidArgument(
                "resulting padded dimension has negative size ", new_dim);
          }
          output_dims.emplace_back(c->MakeDim(new_dim));
        } else {
          output_dims.emplace_back(c->UnknownDim());
        }
      }

      c->set_output(0, c->MakeShape(output_dims));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Pad operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#pad
.

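For example, with `input = [1, 2, 3]`, `padding_value = 0`, `padding_low = [1]`,
`padding_high = [2]` and `padding_interior = [1]`, the output is
[0, 1, 0, 2, 0, 3, 0, 0].
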
input: A `Tensor` of type T.
padding_value: A scalar `Tensor` of type T.
padding_low: the padding to apply at the start of each input dimension. Must
  be a compile-time constant 1D tensor of length equal to rank of input.
padding_high: the padding to apply at the end of each input dimension. Must
  be a compile-time constant 1D tensor of length equal to rank of input.
padding_interior: the padding to apply between each input element. Must
  be a compile-time constant 1D tensor of length equal to rank of input,
  containing only non-negative values.
output: A `Tensor` of type T.
)doc");

REGISTER_OP("XlaRecv")
    .Output("tensor: dtype")
    .Attr("dtype: type")
    .Attr("tensor_name: string")
    .Attr("shape: shape")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
Receives the named tensor from another XLA computation. Wraps the XLA Recv
operator documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#recv .

tensor: The tensor to receive.
dtype: The type of the tensor.
tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");

REGISTER_OP("XlaReduce")
    .Input("input: T")
    .Input("init_value: T")
    .Attr("T: numbertype")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64> dims_set(dimensions_to_reduce.begin(),
                                 dimensions_to_reduce.end());
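        // The reduction dimensions must be unique and each must lie in
        // [0, rank); each reduced dimension removes one dimension from the
        // output shape.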
        auto dim_in_range = [rank](int64 dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaReduce");
        }
        c->set_output(
            0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
      } else {
        c->set_output(0, c->input(0));
      }
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Reduce operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#reduce .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaVariadicReduce")
    .Input("input: N * T")
    .Input("init_value: N * T")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .Attr("dimensions_to_reduce: list(int)")
    .Attr("reducer: func")
    .Output("output: N * T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int n;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
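      // All N inputs must have compatible shapes, so merge every input shape
      // with every other; the merged shape is then used for the outputs.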
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
          c->MergeInput(i, c->input(j));
        }
      }
      if (c->RankKnown(c->input(0))) {
        int rank = c->Rank(c->input(0));
        std::vector<int64> dimensions_to_reduce;
        TF_RETURN_IF_ERROR(
            c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
        std::set<int64> dims_set(dimensions_to_reduce.begin(),
                                 dimensions_to_reduce.end());
        auto dim_in_range = [rank](int64 dim) {
          return dim >= 0 && dim < rank;
        };
        const int dimensions_to_reduce_size = dimensions_to_reduce.size();
        if (rank < dimensions_to_reduce_size ||
            dims_set.size() != dimensions_to_reduce.size() ||
            !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
          return errors::InvalidArgument(
              "Invalid dimensions_to_reduce argument to XlaVariadicReduce");
        }
        for (int i = 0; i < n; i++) {
          c->set_output(
              i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
        }
      } else {
        for (int i = 0; i < n; i++) {
          c->set_output(i, c->input(i));
        }
      }
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the variadic XLA Reduce operator.

Semantics are documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.

input: the input tensor(s)
init_value: scalar initial value(s) for the reduction
reducer: a reducer function to apply
dimensions_to_reduce: dimension numbers over which to reduce
)doc");

REGISTER_OP("XlaReduceWindow")
    .Input("input: T")
    .Input("init_value: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("base_dilations: Tindices")
    .Input("window_dilations: Tindices")
    .Input("padding: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("computation: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

input: the input tensor
init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
)doc");

REGISTER_OP("XlaSelectAndScatter")
    .Input("operand: T")
    .Input("window_dimensions: Tindices")
    .Input("window_strides: Tindices")
    .Input("padding: Tindices")
    .Input("source: T")
    .Input("init_value: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("select: func")
    .Attr("scatter: func")
    .Output("output: T")
    .SetShapeFn(UnchangedRank)
    .Doc(R"doc(
Wraps the XLA SelectAndScatter operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
.

operand: the input tensor
window_dimensions: the shape of the window
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
source: a tensor of values to scatter
init_value: a scalar representing the initial value for the output tensor
select: a selection function to apply
scatter: a scatter function to apply
)doc");

REGISTER_OP("XlaSend")
    .Input("tensor: T")
    .Attr("T: type")
    .Attr("tensor_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#send .

tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");

REGISTER_OP("XlaSort")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorts in ascending order are supported.

input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");

REGISTER_OP("XlaKeyValueSort")
    .Input("keys: K")
    .Input("values: V")
    .Output("sorted_keys: K")
    .Output("sorted_values: V")
    .Attr("K: realnumbertype")
    .Attr("V: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      c->set_output(1, c->input(1));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorts in ascending order are supported.

keys: A `Tensor` of type K.
values: A `Tensor` of type V.
sorted_keys: A `Tensor` of type K.
sorted_values: A `Tensor` of type V.
)doc");

REGISTER_OP("XlaVariadicSort")
    .Input("inputs: T")
    .Input("dimension: int32")
    .Output("outputs: T")
    .Attr("T: list(type) >= 1")
    .Attr("comparator: func")
    .Attr("is_stable: bool")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<shape_inference::ShapeHandle> input_shapes;
      TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes));
      TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes));
      return Status::OK();
    })
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts one or more tensors, with support for custom comparator, dimension, and
is_stable attributes.

inputs: A list of `Tensor` of identical shape but possibly different types.
dimension: The dimension along which to sort. Must be a compile-time constant.
is_stable: Whether to use stable sort.
comparator: A comparator function that takes 2*N scalars and returns a
  boolean. N is the number of sort inputs. If you want to sort in ascending
  order then the comparator should perform a less-than comparison.
outputs: A list of `Tensor` of the same shapes and types as `inputs`.
)doc");

// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function that takes 'input' and returns a tensor.  If the tensor is
      a non-boolean scalar, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and emptiness
      means False.
body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified by T.
)doc");

REGISTER_OP("XlaDequantize")
    .Input("input: uint32")
    .Output("output: bfloat16")
    .Attr("min_range: float")
    .Attr("max_range: float")
    .Attr("mode: string")
    .Attr("transpose_output: bool")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Takes the packed uint32 input and unpacks the input into uint8 values to
perform dequantization on device.

input: Input tensor whose type is uint32, with shape [d0, ..., dn].
output: Output tensor whose type is bfloat16. If transpose_output is true,
     output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
     is false, output shape is [d0, ..., dn * 4].
min_range: The minimum scalar value possibly produced for the input.
max_range: The maximum scalar value possibly produced for the input.
mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
transpose_output: Boolean to determine if output is transposed. transpose_output
     is faster when input is large and the rank of the input is higher than 1.
)doc");

REGISTER_OP("XlaEinsum")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("equation: string")
    .Attr("T: {complex64, bfloat16, float}")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      string equation;
      TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
      // XlaEinsum supports only two-input einsum equations.
      if (!absl::StrContains(equation, ",")) {
        return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
                                       equation);
      }
      // Use EinsumShape for the rest of the inference now that we know we must
      // have a two-input einsum.
      return shape_inference::EinsumShape(context);
    })
    .Doc(R"doc(
An op which supports a basic einsum operation with 2 inputs and 1 output.

This op has better TPU performance since it doesn't have the explicit reshape
and transpose operations that tf.einsum does.
)doc");

REGISTER_OP("XlaSpmdFullToShardShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      auto input_handle = c->input(0);
      if (!c->RankKnown(input_handle)) {
        return shape_inference::UnknownShape(c);
      }
      string sharding_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
      xla::OpSharding sharding;
      sharding.ParseFromString(sharding_attr);
      if (sharding.type() != xla::OpSharding::OTHER) {
        return shape_inference::UnchangedShape(c);
      }
      std::vector<shape_inference::DimensionHandle> dims;
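      // For each dimension, the shard size is the full size divided by the
      // number of partitions along that dimension, rounded up (the last shard
      // may be padded).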
      for (int64 i = 0; i < c->Rank(input_handle); ++i) {
        auto dim = c->Value(c->Dim(input_handle, i));
        int64 partitions_i = sharding.tile_assignment_dimensions(i);
        if (dim != shape_inference::InferenceContext::kUnknownDim &&
            partitions_i != 1) {
          dim = (dim + partitions_i - 1) / partitions_i;
        }
        dims.push_back(c->MakeDim(dim));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from automatic partitioning to
manual partitioning. It annotates the input (full-shape, to be automatically
partitioned) with the same sharding used by manual partitioning, and outputs a
shard-shaped tensor to be consumed by later manually-partitioned ops. If the
shape is not evenly partitionable, the padding region will be masked with 0s.
)doc");

REGISTER_OP("XlaSpmdShardToFullShape")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("manual_sharding: string")
    .Attr("full_shape: shape")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      TensorShape shape_attr;
      TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
An op used by XLA SPMD partitioner to switch from manual partitioning to
automatic partitioning. It converts the shard-shaped, manually partitioned input
into a full-shaped tensor to be partitioned automatically with the same sharding
used by manual partitioning.
)doc");

REGISTER_OP("XlaSharding")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("sharding: string = ''")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
An op which shards the input based on the given sharding attribute.
)doc");

REGISTER_OP("XlaReplicaId")
    .Output("id: int32")
    .SetShapeFn([](shape_inference::InferenceContext* context) {
      context->set_output(0, context->MakeShape({}));
      return Status::OK();
    })
    .Doc("Replica ID.");

REGISTER_OP("XlaGather")
    .Input("operand: T")
    .Input("start_indices: Tindices")
    .Input("slice_sizes: Tindices")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Wraps the XLA Gather operator documented at
  https://www.tensorflow.org/xla/operation_semantics#gather

operand: The array we're gathering from.
start_indices: Array containing the starting indices of the slices we gather.
dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

REGISTER_OP("XlaScatter")
    .Input("operand: T")
    .Input("scatter_indices: Tindices")
    .Input("updates: T")
    .Attr("update_computation: func")
    .Attr("dimension_numbers: string")
    .Attr("indices_are_sorted: bool")
    .Attr("T: {numbertype, bool}")
    .Attr("Tindices: {int32, int64}")
    .Output("output: T")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Scatter operator documented at
  https://www.tensorflow.org/xla/operation_semantics#scatter.

operand: Array to be scattered into.
scatter_indices: Array containing the starting indices of the slices that must
  be scattered to.
updates: Array containing the values that must be used for scattering.
update_computation: Computation to be used for combining the existing values in
  the input array and the updates during scatter.
dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
indices_are_sorted: Boolean indicating if the indices are sorted.
)doc");

}  // namespace
}  // namespace tensorflow