1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/algorithm/container.h"
17 #include "absl/strings/str_cat.h"
18 #include "absl/strings/str_split.h"
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/framework/shape_inference.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 
24 namespace tensorflow {
25 namespace {
26 
27 // Helper shape function for operators that return an output with the same rank
28 // as their first input.
UnchangedRank(shape_inference::InferenceContext * c)29 Status UnchangedRank(shape_inference::InferenceContext* c) {
30   if (c->RankKnown(c->input(0))) {
31     c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
32   } else {
33     c->set_output(0, c->input(0));
34   }
35   return Status::OK();
36 }
37 
// XlaBroadcastHelper: takes `lhs`, `rhs` and an XLA-style `broadcast_dims`
// specification and returns both operands broadcast to a common rank. Output
// shapes are not inferred here (shape_inference::UnknownShape).
REGISTER_OP("XlaBroadcastHelper")
    .Input("lhs: T")
    .Input("rhs: T")
    .Input("broadcast_dims: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Output("lhs_output: T")
    .Output("rhs_output: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Helper operator for performing XLA-style broadcasts

Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
for binary operators.

lhs: the LHS input tensor
rhs: the RHS input tensor
broadcast_dims: an XLA-style broadcast dimension specification
lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");
60 
61 REGISTER_OP("XlaSelfAdjointEig")
62     .Input("a: T")
63     .Attr("lower: bool")
64     .Attr("max_iter: int")
65     .Attr("epsilon: float")
66     .Output("w: T")
67     .Output("v: T")
68     .SetShapeFn(shape_inference::UnknownShape)
69     .Attr("T: numbertype")
70     .Doc(R"doc(
71 Computes the eigen decomposition of a batch of self-adjoint matrices
72 (Note: Only real inputs are supported).
73 
74 Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
75 tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for
76 i=0...N-1.
77 
78 a: the input tensor.
79 
80 lower: a boolean specifies whether the calculation is done with the lower
81   triangular part or the upper triangular part.
82 
83 max_iter: maximum number of sweep update, i.e., the whole lower triangular
84   part or upper triangular part based on parameter lower. Heuristically, it has
85   been argued that approximatly logN sweeps are needed in practice (Ref: Golub &
86   van Loan "Matrix Computation").
87 
88 epsilon: the tolerance ratio.
89 
90 w: The eigenvalues in ascending order, each repeated according to its
91   multiplicity.
92 v: The column v[..., :, i] is the normalized eigenvector corresponding to the
93   eigenvalue w[..., i].
94 )doc");
95 
96 REGISTER_OP("XlaSvd")
97     .Input("a: T")
98     .Attr("max_iter: int")
99     .Attr("epsilon: float")
100     .Attr("precision_config: string")
101     .Output("s: T")
102     .Output("u: T")
103     .Output("v: T")
104     .SetShapeFn(shape_inference::UnknownShape)
105     .Attr("T: numbertype")
106     .Doc(R"doc(
107 Computes the eigen decomposition of a batch of self-adjoint matrices
108 (Note: Only real inputs are supported).
109 
110 Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
111 tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
112 
113 a: the input tensor.
114 
115 max_iter: maximum number of sweep update, i.e., the whole lower triangular
116   part or upper triangular part based on parameter lower. Heuristically, it has
117   been argued that approximatly log(min (M, N)) sweeps are needed in practice
118   (Ref: Golub & van Loan "Matrix Computation").
119 
120 epsilon: the tolerance ratio.
121 
122 precision_config: a serialized xla::PrecisionConfig proto.
123 
124 s: Singular values. The values are sorted in reverse order of magnitude, so
125   s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
126 u: Left singular vectors.
127 v: Right singular vectors.
128 )doc");
129 
130 REGISTER_OP("XlaConv")
131     .Input("lhs: T")
132     .Input("rhs: T")
133     .Input("window_strides: Tindices")
134     .Input("padding: Tindices")
135     .Input("lhs_dilation: Tindices")
136     .Input("rhs_dilation: Tindices")
137     .Input("feature_group_count: Tindices")
138     .Attr("T: numbertype")
139     .Attr("Tindices: {int32, int64}")
140     .Attr("dimension_numbers: string")
141     .Attr("precision_config: string")
142     .Output("output: T")
143     .SetShapeFn(UnchangedRank)
144     .Doc(R"doc(
145 Wraps the XLA ConvGeneralDilated operator, documented at
146  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
147 .
148 
149 lhs: the input tensor
150 rhs: the kernel tensor
151 window_strides: the inter-window strides
152 padding: the padding to apply at the start and end of each input dimensions
153 lhs_dilation: dilation to apply between input elements
154 rhs_dilation: dilation to apply between kernel elements
155 feature_group_count: number of feature groups for grouped convolution.
156 dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
157 precision_config: a serialized xla::PrecisionConfig proto.
158 )doc");
159 
160 REGISTER_OP("XlaDot")
161     .Input("lhs: T")
162     .Input("rhs: T")
163     .Attr("T: numbertype")
164     .Attr("dimension_numbers: string")
165     .Attr("precision_config: string")
166     .Output("output: T")
167     .SetShapeFn(shape_inference::UnknownShape)
168     .Doc(R"doc(
169 Wraps the XLA ConvGeneralDilated operator, documented at
170  https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
171 .
172 
173 lhs: the LHS tensor
174 rhs: the RHS tensor
175 dimension_numbers: a serialized xla::DotDimensionNumbers proto.
176 precision_config: a serialized xla::PrecisionConfig proto.
177 )doc");
178 
179 REGISTER_OP("XlaDynamicSlice")
180     .Input("input: T")
181     .Input("start_indices: Tindices")
182     .Input("size_indices: Tindices")
183     .Output("output: T")
184     .Attr("T: type")
185     .Attr("Tindices: {int32, int64}")
186     .SetShapeFn(shape_inference::UnknownShape)
187     .Doc(R"doc(
188 Wraps the XLA DynamicSlice operator, documented at
189  https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
190 .
191 
192 DynamicSlice extracts a sub-array from the input array at dynamic
193 start_indices. The size of the slice in each dimension is passed in
194 size_indices, which specify the end point of exclusive slice intervals in each
195 dimension -- [start, start + size). The shape of start_indices must have rank 1,
196 with dimension size equal to the rank of operand.
197 
198 input: A `Tensor` of type T.
199 
200 start_indices: Rank 1 tensor of N integers containing the starting indices of
201   the slice for each dimension. Value must be greater than or equal to zero.
202 
203 start_indices: List of N integers containing the slice size for each
204   dimension. Each value must be strictly greater than zero, and start + size
205   must be less than or equal to the size of the dimension to avoid
206   implementation defined behavior.
207 )doc");
208 
// XlaDynamicUpdateSlice: wraps XLA DynamicUpdateSlice. The shape function is
// shape_inference::UnchangedShape, so output 0 takes the shape of `input`.
REGISTER_OP("XlaDynamicUpdateSlice")
    .Input("input: T")
    .Input("update: T")
    .Input("indices: Tindices")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA DynamicUpdateSlice operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
.

XlaDynamicUpdateSlice generates a result which is the value of the `input`
operand, with a slice update overwritten at `indices`. The shape of `update`
determines the shape of the sub-array of the result which is updated. The shape
of indices must be rank == 1, with dimension size equal to the rank of `input`.

Handling of out-of-bounds slice indices is implementation-defined.

input: A `Tensor` of type T.
indices: A vector of indices into `input`. Must have length equal to the rank of
  `input`.
update: A `Tensor` of type T. Same rank as `input`.
output: A `Tensor` of type T.
)doc");
235 
236 // TODO(b/37549631) setting the If Op to always be stateful is too
237 // conservative.
238 REGISTER_OP("XlaIf")
239     .Input("cond: Tcond")
240     .Input("inputs: Tin")
241     .Output("output: Tout")
242     .Attr("Tcond: type")
243     .Attr("then_branch: func")
244     .Attr("else_branch: func")
245     .Attr("Tin: list(type) >= 0")
246     .Attr("Tout: list(type) >= 0")
247     .SetIsStateful()
248     .SetShapeFn(shape_inference::UnknownShape)
249     .Doc(R"doc(
250 output = cond ? then_branch(inputs) : else_branch(inputs).
251 
252 cond: A boolean scalar.
253 inputs: A list of input tensors.
254 output: A list of tensors returned by either then_branch(inputs) or
255         else_branch(inputs). The input shapes of the then_branch and
256         else_branch must match.
257 then_branch: A function takes 'inputs' and returns a list of tensors,
258              whose types are the same as what else_branch returns.
259 else_branch: A function takes 'inputs' and returns a list of tensors.
260              whose types are the same as what then_branch returns.
261 )doc");
262 
263 REGISTER_OP("XlaPad")
264     .Input("input: T")
265     .Input("padding_value: T")
266     .Input("padding_low: Tindices")
267     .Input("padding_high: Tindices")
268     .Input("padding_interior: Tindices")
269     .Output("output: T")
270     .Attr("T: type")
271     .Attr("Tindices: {int32, int64}")
272     .SetShapeFn(UnchangedRank)
273     .Doc(R"doc(
274 Wraps the XLA Pad operator, documented at
275  https://www.tensorflow.org/performance/xla/operation_semantics#pad
276 .
277 
278 input: A `Tensor` of type T.
279 padding_value: A scalar `Tensor` of type T.
280 padding_low: the padding to apply at the start of each input dimensions
281 padding_high: the padding to apply at the end of each input dimension.
282 padding_interior: the padding to apply between each input element.
283 output: A `Tensor` of type T.
284 )doc");
285 
286 REGISTER_OP("XlaRecv")
287     .Output("tensor: dtype")
288     .Attr("dtype: type")
289     .Attr("tensor_name: string")
290     .Attr("shape: shape")
291     .SetIsStateful()
__anon3500048a0202(shape_inference::InferenceContext* c) 292     .SetShapeFn([](shape_inference::InferenceContext* c) {
293       TensorShape shape_attr;
294       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
295       shape_inference::ShapeHandle s;
296       TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
297       c->set_output(0, s);
298       return Status::OK();
299     })
300     .Doc(R"doc(
301 Receives the named tensor from another XLA computation. Wraps the XLA Recv
302 operator documented at
303  https://www.tensorflow.org/performance/xla/operation_semantics#recv .
304 
305 tensor: The tensor to receive.
306 dtype: The type of the tensor.
307 tensor_name: A string key that identifies the channel.
308 shape: The shape of the tensor.
309 )doc");
310 
311 REGISTER_OP("XlaReduce")
312     .Input("input: T")
313     .Input("init_value: T")
314     .Attr("T: numbertype")
315     .Attr("dimensions_to_reduce: list(int)")
316     .Attr("reducer: func")
317     .Output("output: T")
__anon3500048a0302(shape_inference::InferenceContext* c) 318     .SetShapeFn([](shape_inference::InferenceContext* c) {
319       if (c->RankKnown(c->input(0))) {
320         int rank = c->Rank(c->input(0));
321         std::vector<int64> dimensions_to_reduce;
322         TF_RETURN_IF_ERROR(
323             c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
324         std::set<int64> dims_set(dimensions_to_reduce.begin(),
325                                  dimensions_to_reduce.end());
326         auto dim_in_range = [rank](int64 dim) {
327           return dim >= 0 && dim < rank;
328         };
329         if (rank < dimensions_to_reduce.size() ||
330             dims_set.size() != dimensions_to_reduce.size() ||
331             !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
332           return errors::InvalidArgument(
333               "Invalid dimensions_to_reduce argument to XlaReduce");
334         }
335         c->set_output(
336             0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
337       } else {
338         c->set_output(0, c->input(0));
339       }
340       return Status::OK();
341     })
342     .Doc(R"doc(
343 Wraps the XLA Reduce operator, documented at
344  https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
345 
346 input: the input tensor
347 init_value: a scalar representing the initial value for the reduction
348 reducer: a reducer function to apply
349 dimensions_to_reduce: dimension numbers over which to reduce
350 )doc");
351 
352 REGISTER_OP("XlaReduceWindow")
353     .Input("input: T")
354     .Input("init_value: T")
355     .Input("window_dimensions: Tindices")
356     .Input("window_strides: Tindices")
357     .Input("base_dilations: Tindices")
358     .Input("window_dilations: Tindices")
359     .Input("padding: Tindices")
360     .Attr("T: numbertype")
361     .Attr("Tindices: {int32, int64}")
362     .Attr("computation: func")
363     .Output("output: T")
364     .SetShapeFn(UnchangedRank)
365     .Doc(R"doc(
366 Wraps the XLA ReduceWindow operator, documented at
367  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
368 
369 input: the input tensor
370 init_value: a scalar representing the initial value for the reduction
371 computation: a reducer function to apply
372 window_dimensions: the shape of the window
373 window_strides: the inter-window strides
374 padding: the padding to apply at the start and end of each input dimensions
375 )doc");
376 
377 REGISTER_OP("XlaSelectAndScatter")
378     .Input("operand: T")
379     .Input("window_dimensions: Tindices")
380     .Input("window_strides: Tindices")
381     .Input("padding: Tindices")
382     .Input("source: T")
383     .Input("init_value: T")
384     .Attr("T: numbertype")
385     .Attr("Tindices: {int32, int64}")
386     .Attr("select: func")
387     .Attr("scatter: func")
388     .Output("output: T")
389     .SetShapeFn(UnchangedRank)
390     .Doc(R"doc(
391 Wraps the XLA SelectAndScatter operator, documented at
392  https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
393 .
394 
395 operand: the input tensor
396 window_dimensions: the shape of the window
397 window_strides: the inter-window strides
398 padding: the padding to apply at the start and end of each input dimensions
399 source: a tensor of values to scatter
400 init_value: a scalar representing the initial value for the output tensor
401 select: a selection function to apply
402 scatter: a scatter function to apply
403 )doc");
404 
// XlaSend: sends `tensor` over the channel named by `tensor_name`. The op has
// no outputs. It is registered stateful, presumably so graph optimizations do
// not treat the send as a side-effect-free computation — confirm if changing.
REGISTER_OP("XlaSend")
    .Input("tensor: T")
    .Attr("T: type")
    .Attr("tensor_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Sends the named tensor to another XLA computation. Wraps the XLA Send operator
documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#send .

tensor: The tensor to send.
tensor_name: A string key that identifies the channel.
)doc");
419 
// XlaSort: wraps XLA Sort on a single tensor. The shape function is
// shape_inference::UnchangedShape, so output 0 takes the shape of `input`.
REGISTER_OP("XlaSort")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Wraps the XLA Sort operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#sort
.

Sorts a tensor. Currently only sorts in ascending order are supported.

input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");
435 
436 REGISTER_OP("XlaKeyValueSort")
437     .Input("keys: K")
438     .Input("values: V")
439     .Output("sorted_keys: K")
440     .Output("sorted_values: V")
441     .Attr("K: realnumbertype")
442     .Attr("V: type")
__anon3500048a0502(shape_inference::InferenceContext* c) 443     .SetShapeFn([](shape_inference::InferenceContext* c) {
444       c->set_output(0, c->input(0));
445       c->set_output(1, c->input(1));
446       return Status::OK();
447     })
448     .Doc(R"doc(
449 Wraps the XLA Sort operator, documented at
450  https://www.tensorflow.org/performance/xla/operation_semantics#sort
451 .
452 
453 Sorts a tensor. Currently only sorts in ascending order are supported.
454 
455 keys: A `Tensor` of type K.
456 values: A `Tensor` of type V.
457 sorted_keys: A `Tensor` of type K.
458 sorted_values: A `Tensor` of type V.
459 )doc");
460 
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
// XlaWhile: functional while loop. `input` and `output` share the type list
// `T`; output shapes are not inferred here (shape_inference::UnknownShape).
REGISTER_OP("XlaWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function takes 'input' and returns a tensor.  If the tensor is
      a scalar of non-boolean, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.
body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified by T.
)doc");
486 
487 REGISTER_OP("XlaDequantize")
488     .Input("input: uint32")
489     .Output("output: bfloat16")
490     .Attr("min_range: float")
491     .Attr("max_range: float")
492     .Attr("mode: string")
493     .Attr("transpose_output: bool")
494     .SetIsStateful()
495     .SetShapeFn(shape_inference::UnknownShape)
496     .Doc(R"doc(
497 Takes the packed uint32 input and unpacks the input to uint8 to do
498 Dequantization on deivce.
499 
500 input: Input tensors whose types is uint32, shape is [d0, ..., dn].
501 output: Output tensors whose types is bloat16. If transpose_output is true,
502      output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
503      is false, output shape is [d0,..., dn * 4].
504 min_range: The minimum scalar value possibly produced for the input.
505 max_range: The maximum scalar value possibly produced for the input.
506 mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
507 transpose_output: Boolean to determine if output is transposed. transpose_output
508      is faster when input is large and rank of input is higher than 1.
509 )doc");
510 
511 REGISTER_OP("XlaEinsum")
512     .Input("a: T")
513     .Input("b: T")
514     .Output("product: T")
515     .Attr("equation: string")
516     .Attr("T: {bfloat16, float}")
__anon3500048a0602(shape_inference::InferenceContext* context) 517     .SetShapeFn([](shape_inference::InferenceContext* context) {
518       shape_inference::ShapeHandle input_a = context->input(0);
519       shape_inference::ShapeHandle input_b = context->input(1);
520 
521       int64 rank_a, rank_b;
522       if (context->RankKnown(input_a)) {
523         rank_a = context->Rank(input_a);
524       } else {
525         return errors::InvalidArgument("input 0's rank is unknown.");
526       }
527       if (context->RankKnown(input_b)) {
528         rank_b = context->Rank(input_b);
529       } else {
530         return errors::InvalidArgument("input 1's rank is unknown.");
531       }
532       string equation;
533       TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
534 
535       std::map<char, shape_inference::DimensionHandle> left_map;
536       std::map<char, shape_inference::DimensionHandle> right_map;
537       std::vector<shape_inference::DimensionHandle> dims;
538 
539       std::vector<string> equation_split = absl::StrSplit(equation, "->");
540 
541       if (equation_split.size() != 2) {
542         return errors::InvalidArgument("Expected one \"->\" in equation. Got: ",
543                                        equation);
544       }
545 
546       std::vector<string> lhs_rhs_split =
547           absl::StrSplit(equation_split[0], ',');
548       if (lhs_rhs_split.size() != 2) {
549         return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
550                                        equation);
551       }
552 
553       if (rank_a != lhs_rhs_split[0].size()) {
554         return errors::InvalidArgument(absl::StrCat(
555             "Expected equation[0] with size: ", rank_a, " Got '",
556             lhs_rhs_split[0], "'", " with size: ", lhs_rhs_split[0].size()));
557       }
558 
559       if (rank_b != lhs_rhs_split[1].size()) {
560         return errors::InvalidArgument(absl::StrCat(
561             "Expected equation[1] with size: ", rank_b, " Got '",
562             lhs_rhs_split[1], "'", " with size: ", lhs_rhs_split[1].size()));
563       }
564 
565       for (const char& c : lhs_rhs_split[0]) {
566         left_map[c] = context->Dim(input_a, left_map.size());
567       }
568       for (const char& c : lhs_rhs_split[1]) {
569         right_map[c] = context->Dim(input_b, right_map.size());
570       }
571 
572       for (const char& c : equation_split[1]) {
573         if (left_map.count(c)) {
574           dims.push_back(left_map[c]);
575         } else if (right_map.count(c)) {
576           dims.push_back(right_map[c]);
577         } else {
578           return errors::InvalidArgument("Invalid equation: ", equation);
579         }
580       }
581 
582       context->set_output(0, context->MakeShape(dims));
583       return Status::OK();
584     })
585     .Doc(R"doc(
586 An op which supports basic einsum op with 2 inputs and 1 output.
587 
588 This op has better TPU performnce since it doesn't have explicitly reshape and
589 transpose operations as tf.einsum does.
590 )doc");
591 
592 }  // namespace
593 }  // namespace tensorflow
594