1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/algorithm/container.h"
17 #include "absl/strings/str_cat.h"
18 #include "absl/strings/str_split.h"
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/framework/shape_inference.h"
22 #include "tensorflow/core/lib/core/errors.h"
23
24 namespace tensorflow {
25 namespace {
26
27 // Helper shape function for operators that return an output with the same rank
28 // as their first input.
UnchangedRank(shape_inference::InferenceContext * c)29 Status UnchangedRank(shape_inference::InferenceContext* c) {
30 if (c->RankKnown(c->input(0))) {
31 c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
32 } else {
33 c->set_output(0, c->input(0));
34 }
35 return Status::OK();
36 }
37
38 REGISTER_OP("XlaBroadcastHelper")
39 .Input("lhs: T")
40 .Input("rhs: T")
41 .Input("broadcast_dims: Tindices")
42 .Attr("T: numbertype")
43 .Attr("Tindices: {int32, int64}")
44 .Output("lhs_output: T")
45 .Output("rhs_output: T")
46 .SetShapeFn(shape_inference::UnknownShape)
47 .Doc(R"doc(
48 Helper operator for performing XLA-style broadcasts
49
50 Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
51 whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
52 for binary operators.
53
54 lhs: the LHS input tensor
55 rhs: the RHS input tensor
56 broadcast_dims: an XLA-style broadcast dimension specification
57 lhs_output: the broadcasted LHS tensor
58 rhs_output: the broadcasted RHS tensor
59 )doc");
60
61 REGISTER_OP("XlaSelfAdjointEig")
62 .Input("a: T")
63 .Attr("lower: bool")
64 .Attr("max_iter: int")
65 .Attr("epsilon: float")
66 .Output("w: T")
67 .Output("v: T")
68 .SetShapeFn(shape_inference::UnknownShape)
69 .Attr("T: numbertype")
70 .Doc(R"doc(
71 Computes the eigen decomposition of a batch of self-adjoint matrices
72 (Note: Only real inputs are supported).
73
74 Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
75 tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for
76 i=0...N-1.
77
78 a: the input tensor.
79
80 lower: a boolean specifies whether the calculation is done with the lower
81 triangular part or the upper triangular part.
82
83 max_iter: maximum number of sweep update, i.e., the whole lower triangular
84 part or upper triangular part based on parameter lower. Heuristically, it has
85 been argued that approximatly logN sweeps are needed in practice (Ref: Golub &
86 van Loan "Matrix Computation").
87
88 epsilon: the tolerance ratio.
89
90 w: The eigenvalues in ascending order, each repeated according to its
91 multiplicity.
92 v: The column v[..., :, i] is the normalized eigenvector corresponding to the
93 eigenvalue w[..., i].
94 )doc");
95
96 REGISTER_OP("XlaSvd")
97 .Input("a: T")
98 .Attr("max_iter: int")
99 .Attr("epsilon: float")
100 .Attr("precision_config: string")
101 .Output("s: T")
102 .Output("u: T")
103 .Output("v: T")
104 .SetShapeFn(shape_inference::UnknownShape)
105 .Attr("T: numbertype")
106 .Doc(R"doc(
107 Computes the eigen decomposition of a batch of self-adjoint matrices
108 (Note: Only real inputs are supported).
109
110 Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
111 tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
112
113 a: the input tensor.
114
115 max_iter: maximum number of sweep update, i.e., the whole lower triangular
116 part or upper triangular part based on parameter lower. Heuristically, it has
117 been argued that approximatly log(min (M, N)) sweeps are needed in practice
118 (Ref: Golub & van Loan "Matrix Computation").
119
120 epsilon: the tolerance ratio.
121
122 precision_config: a serialized xla::PrecisionConfig proto.
123
124 s: Singular values. The values are sorted in reverse order of magnitude, so
125 s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
126 u: Left singular vectors.
127 v: Right singular vectors.
128 )doc");
129
130 REGISTER_OP("XlaConv")
131 .Input("lhs: T")
132 .Input("rhs: T")
133 .Input("window_strides: Tindices")
134 .Input("padding: Tindices")
135 .Input("lhs_dilation: Tindices")
136 .Input("rhs_dilation: Tindices")
137 .Input("feature_group_count: Tindices")
138 .Attr("T: numbertype")
139 .Attr("Tindices: {int32, int64}")
140 .Attr("dimension_numbers: string")
141 .Attr("precision_config: string")
142 .Output("output: T")
143 .SetShapeFn(UnchangedRank)
144 .Doc(R"doc(
145 Wraps the XLA ConvGeneralDilated operator, documented at
146 https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
147 .
148
149 lhs: the input tensor
150 rhs: the kernel tensor
151 window_strides: the inter-window strides
152 padding: the padding to apply at the start and end of each input dimensions
153 lhs_dilation: dilation to apply between input elements
154 rhs_dilation: dilation to apply between kernel elements
155 feature_group_count: number of feature groups for grouped convolution.
156 dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
157 precision_config: a serialized xla::PrecisionConfig proto.
158 )doc");
159
160 REGISTER_OP("XlaDot")
161 .Input("lhs: T")
162 .Input("rhs: T")
163 .Attr("T: numbertype")
164 .Attr("dimension_numbers: string")
165 .Attr("precision_config: string")
166 .Output("output: T")
167 .SetShapeFn(shape_inference::UnknownShape)
168 .Doc(R"doc(
169 Wraps the XLA ConvGeneralDilated operator, documented at
170 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
171 .
172
173 lhs: the LHS tensor
174 rhs: the RHS tensor
175 dimension_numbers: a serialized xla::DotDimensionNumbers proto.
176 precision_config: a serialized xla::PrecisionConfig proto.
177 )doc");
178
179 REGISTER_OP("XlaDynamicSlice")
180 .Input("input: T")
181 .Input("start_indices: Tindices")
182 .Input("size_indices: Tindices")
183 .Output("output: T")
184 .Attr("T: type")
185 .Attr("Tindices: {int32, int64}")
186 .SetShapeFn(shape_inference::UnknownShape)
187 .Doc(R"doc(
188 Wraps the XLA DynamicSlice operator, documented at
189 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
190 .
191
192 DynamicSlice extracts a sub-array from the input array at dynamic
193 start_indices. The size of the slice in each dimension is passed in
194 size_indices, which specify the end point of exclusive slice intervals in each
195 dimension -- [start, start + size). The shape of start_indices must have rank 1,
196 with dimension size equal to the rank of operand.
197
198 input: A `Tensor` of type T.
199
200 start_indices: Rank 1 tensor of N integers containing the starting indices of
201 the slice for each dimension. Value must be greater than or equal to zero.
202
203 start_indices: List of N integers containing the slice size for each
204 dimension. Each value must be strictly greater than zero, and start + size
205 must be less than or equal to the size of the dimension to avoid
206 implementation defined behavior.
207 )doc");
208
209 REGISTER_OP("XlaDynamicUpdateSlice")
210 .Input("input: T")
211 .Input("update: T")
212 .Input("indices: Tindices")
213 .Output("output: T")
214 .Attr("T: type")
215 .Attr("Tindices: {int32, int64}")
216 .SetShapeFn(shape_inference::UnchangedShape)
217 .Doc(R"doc(
218 Wraps the XLA DynamicUpdateSlice operator, documented at
219 https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
220 .
221
222 XlaDynamicUpdateSlice generates a result which is the value of the `input`
223 operand, with a slice update overwritten at `indices`. The shape of `update`
224 determines the shape of the sub-array of the result which is updated. The shape
225 of indices must be rank == 1, with dimension size equal to the rank of `input`.
226
227 Handling of out-of-bounds slice indices is implementation-defined.
228
229 input: A `Tensor` of type T.
230 indices: A vector of indices into `input`. Must have length equal to the rank of
231 `input`.
232 update: A `Tensor` of type T. Same rank as `input`.
233 output: A `Tensor` of type T.
234 )doc");
235
236 // TODO(b/37549631) setting the If Op to always be stateful is too
237 // conservative.
238 REGISTER_OP("XlaIf")
239 .Input("cond: Tcond")
240 .Input("inputs: Tin")
241 .Output("output: Tout")
242 .Attr("Tcond: type")
243 .Attr("then_branch: func")
244 .Attr("else_branch: func")
245 .Attr("Tin: list(type) >= 0")
246 .Attr("Tout: list(type) >= 0")
247 .SetIsStateful()
248 .SetShapeFn(shape_inference::UnknownShape)
249 .Doc(R"doc(
250 output = cond ? then_branch(inputs) : else_branch(inputs).
251
252 cond: A boolean scalar.
253 inputs: A list of input tensors.
254 output: A list of tensors returned by either then_branch(inputs) or
255 else_branch(inputs). The input shapes of the then_branch and
256 else_branch must match.
257 then_branch: A function takes 'inputs' and returns a list of tensors,
258 whose types are the same as what else_branch returns.
259 else_branch: A function takes 'inputs' and returns a list of tensors.
260 whose types are the same as what then_branch returns.
261 )doc");
262
263 REGISTER_OP("XlaPad")
264 .Input("input: T")
265 .Input("padding_value: T")
266 .Input("padding_low: Tindices")
267 .Input("padding_high: Tindices")
268 .Input("padding_interior: Tindices")
269 .Output("output: T")
270 .Attr("T: type")
271 .Attr("Tindices: {int32, int64}")
272 .SetShapeFn(UnchangedRank)
273 .Doc(R"doc(
274 Wraps the XLA Pad operator, documented at
275 https://www.tensorflow.org/performance/xla/operation_semantics#pad
276 .
277
278 input: A `Tensor` of type T.
279 padding_value: A scalar `Tensor` of type T.
280 padding_low: the padding to apply at the start of each input dimensions
281 padding_high: the padding to apply at the end of each input dimension.
282 padding_interior: the padding to apply between each input element.
283 output: A `Tensor` of type T.
284 )doc");
285
286 REGISTER_OP("XlaRecv")
287 .Output("tensor: dtype")
288 .Attr("dtype: type")
289 .Attr("tensor_name: string")
290 .Attr("shape: shape")
291 .SetIsStateful()
__anon3500048a0202(shape_inference::InferenceContext* c) 292 .SetShapeFn([](shape_inference::InferenceContext* c) {
293 TensorShape shape_attr;
294 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
295 shape_inference::ShapeHandle s;
296 TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
297 c->set_output(0, s);
298 return Status::OK();
299 })
300 .Doc(R"doc(
301 Receives the named tensor from another XLA computation. Wraps the XLA Recv
302 operator documented at
303 https://www.tensorflow.org/performance/xla/operation_semantics#recv .
304
305 tensor: The tensor to receive.
306 dtype: The type of the tensor.
307 tensor_name: A string key that identifies the channel.
308 shape: The shape of the tensor.
309 )doc");
310
311 REGISTER_OP("XlaReduce")
312 .Input("input: T")
313 .Input("init_value: T")
314 .Attr("T: numbertype")
315 .Attr("dimensions_to_reduce: list(int)")
316 .Attr("reducer: func")
317 .Output("output: T")
__anon3500048a0302(shape_inference::InferenceContext* c) 318 .SetShapeFn([](shape_inference::InferenceContext* c) {
319 if (c->RankKnown(c->input(0))) {
320 int rank = c->Rank(c->input(0));
321 std::vector<int64> dimensions_to_reduce;
322 TF_RETURN_IF_ERROR(
323 c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
324 std::set<int64> dims_set(dimensions_to_reduce.begin(),
325 dimensions_to_reduce.end());
326 auto dim_in_range = [rank](int64 dim) {
327 return dim >= 0 && dim < rank;
328 };
329 if (rank < dimensions_to_reduce.size() ||
330 dims_set.size() != dimensions_to_reduce.size() ||
331 !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
332 return errors::InvalidArgument(
333 "Invalid dimensions_to_reduce argument to XlaReduce");
334 }
335 c->set_output(
336 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
337 } else {
338 c->set_output(0, c->input(0));
339 }
340 return Status::OK();
341 })
342 .Doc(R"doc(
343 Wraps the XLA Reduce operator, documented at
344 https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
345
346 input: the input tensor
347 init_value: a scalar representing the initial value for the reduction
348 reducer: a reducer function to apply
349 dimensions_to_reduce: dimension numbers over which to reduce
350 )doc");
351
352 REGISTER_OP("XlaReduceWindow")
353 .Input("input: T")
354 .Input("init_value: T")
355 .Input("window_dimensions: Tindices")
356 .Input("window_strides: Tindices")
357 .Input("base_dilations: Tindices")
358 .Input("window_dilations: Tindices")
359 .Input("padding: Tindices")
360 .Attr("T: numbertype")
361 .Attr("Tindices: {int32, int64}")
362 .Attr("computation: func")
363 .Output("output: T")
364 .SetShapeFn(UnchangedRank)
365 .Doc(R"doc(
366 Wraps the XLA ReduceWindow operator, documented at
367 https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
368
369 input: the input tensor
370 init_value: a scalar representing the initial value for the reduction
371 computation: a reducer function to apply
372 window_dimensions: the shape of the window
373 window_strides: the inter-window strides
374 padding: the padding to apply at the start and end of each input dimensions
375 )doc");
376
377 REGISTER_OP("XlaSelectAndScatter")
378 .Input("operand: T")
379 .Input("window_dimensions: Tindices")
380 .Input("window_strides: Tindices")
381 .Input("padding: Tindices")
382 .Input("source: T")
383 .Input("init_value: T")
384 .Attr("T: numbertype")
385 .Attr("Tindices: {int32, int64}")
386 .Attr("select: func")
387 .Attr("scatter: func")
388 .Output("output: T")
389 .SetShapeFn(UnchangedRank)
390 .Doc(R"doc(
391 Wraps the XLA SelectAndScatter operator, documented at
392 https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
393 .
394
395 operand: the input tensor
396 window_dimensions: the shape of the window
397 window_strides: the inter-window strides
398 padding: the padding to apply at the start and end of each input dimensions
399 source: a tensor of values to scatter
400 init_value: a scalar representing the initial value for the output tensor
401 select: a selection function to apply
402 scatter: a scatter function to apply
403 )doc");
404
405 REGISTER_OP("XlaSend")
406 .Input("tensor: T")
407 .Attr("T: type")
408 .Attr("tensor_name: string")
409 .SetIsStateful()
410 .SetShapeFn(shape_inference::UnknownShape)
411 .Doc(R"doc(
412 Sends the named tensor to another XLA computation. Wraps the XLA Send operator
413 documented at
414 https://www.tensorflow.org/performance/xla/operation_semantics#send .
415
416 tensor: The tensor to send.
417 tensor_name: A string key that identifies the channel.
418 )doc");
419
420 REGISTER_OP("XlaSort")
421 .Input("input: T")
422 .Output("output: T")
423 .Attr("T: type")
424 .SetShapeFn(shape_inference::UnchangedShape)
425 .Doc(R"doc(
426 Wraps the XLA Sort operator, documented at
427 https://www.tensorflow.org/performance/xla/operation_semantics#sort
428 .
429
430 Sorts a tensor. Currently only sorts in ascending order are supported.
431
432 input: A `Tensor` of type T.
433 output: A `Tensor` of type T.
434 )doc");
435
436 REGISTER_OP("XlaKeyValueSort")
437 .Input("keys: K")
438 .Input("values: V")
439 .Output("sorted_keys: K")
440 .Output("sorted_values: V")
441 .Attr("K: realnumbertype")
442 .Attr("V: type")
__anon3500048a0502(shape_inference::InferenceContext* c) 443 .SetShapeFn([](shape_inference::InferenceContext* c) {
444 c->set_output(0, c->input(0));
445 c->set_output(1, c->input(1));
446 return Status::OK();
447 })
448 .Doc(R"doc(
449 Wraps the XLA Sort operator, documented at
450 https://www.tensorflow.org/performance/xla/operation_semantics#sort
451 .
452
453 Sorts a tensor. Currently only sorts in ascending order are supported.
454
455 keys: A `Tensor` of type K.
456 values: A `Tensor` of type V.
457 sorted_keys: A `Tensor` of type K.
458 sorted_values: A `Tensor` of type V.
459 )doc");
460
461 // TODO(b/37549631) setting the While Op to always be stateful is too
462 // conservative.
463 REGISTER_OP("XlaWhile")
464 .Input("input: T")
465 .Output("output: T")
466 .Attr("T: list(type) >= 0")
467 .Attr("cond: func")
468 .Attr("body: func")
469 .SetIsStateful()
470 .SetShapeFn(shape_inference::UnknownShape)
471 .Doc(R"doc(
472 output = input; While (Cond(output)) { output = Body(output) }
473
474 input: A list of input tensors whose types are T.
475 output: A list of output tensors whose types are T.
476 cond: A function takes 'input' and returns a tensor. If the tensor is
477 a scalar of non-boolean, the scalar is converted to a boolean
478 according to the following rule: if the scalar is a numerical
479 value, non-zero means True and zero means False; if the scalar is
480 a string, non-empty means True and empty means False. If the
481 tensor is not a scalar, non-emptiness means True and False
482 otherwise.
483 body: A function that takes a list of tensors and returns another
484 list of tensors. Both lists have the same types as specified by T.
485 )doc");
486
487 REGISTER_OP("XlaDequantize")
488 .Input("input: uint32")
489 .Output("output: bfloat16")
490 .Attr("min_range: float")
491 .Attr("max_range: float")
492 .Attr("mode: string")
493 .Attr("transpose_output: bool")
494 .SetIsStateful()
495 .SetShapeFn(shape_inference::UnknownShape)
496 .Doc(R"doc(
497 Takes the packed uint32 input and unpacks the input to uint8 to do
498 Dequantization on deivce.
499
500 input: Input tensors whose types is uint32, shape is [d0, ..., dn].
501 output: Output tensors whose types is bloat16. If transpose_output is true,
502 output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
503 is false, output shape is [d0,..., dn * 4].
504 min_range: The minimum scalar value possibly produced for the input.
505 max_range: The maximum scalar value possibly produced for the input.
506 mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
507 transpose_output: Boolean to determine if output is transposed. transpose_output
508 is faster when input is large and rank of input is higher than 1.
509 )doc");
510
511 REGISTER_OP("XlaEinsum")
512 .Input("a: T")
513 .Input("b: T")
514 .Output("product: T")
515 .Attr("equation: string")
516 .Attr("T: {bfloat16, float}")
__anon3500048a0602(shape_inference::InferenceContext* context) 517 .SetShapeFn([](shape_inference::InferenceContext* context) {
518 shape_inference::ShapeHandle input_a = context->input(0);
519 shape_inference::ShapeHandle input_b = context->input(1);
520
521 int64 rank_a, rank_b;
522 if (context->RankKnown(input_a)) {
523 rank_a = context->Rank(input_a);
524 } else {
525 return errors::InvalidArgument("input 0's rank is unknown.");
526 }
527 if (context->RankKnown(input_b)) {
528 rank_b = context->Rank(input_b);
529 } else {
530 return errors::InvalidArgument("input 1's rank is unknown.");
531 }
532 string equation;
533 TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation));
534
535 std::map<char, shape_inference::DimensionHandle> left_map;
536 std::map<char, shape_inference::DimensionHandle> right_map;
537 std::vector<shape_inference::DimensionHandle> dims;
538
539 std::vector<string> equation_split = absl::StrSplit(equation, "->");
540
541 if (equation_split.size() != 2) {
542 return errors::InvalidArgument("Expected one \"->\" in equation. Got: ",
543 equation);
544 }
545
546 std::vector<string> lhs_rhs_split =
547 absl::StrSplit(equation_split[0], ',');
548 if (lhs_rhs_split.size() != 2) {
549 return errors::InvalidArgument("Expected one \",\" in equation. Got: ",
550 equation);
551 }
552
553 if (rank_a != lhs_rhs_split[0].size()) {
554 return errors::InvalidArgument(absl::StrCat(
555 "Expected equation[0] with size: ", rank_a, " Got '",
556 lhs_rhs_split[0], "'", " with size: ", lhs_rhs_split[0].size()));
557 }
558
559 if (rank_b != lhs_rhs_split[1].size()) {
560 return errors::InvalidArgument(absl::StrCat(
561 "Expected equation[1] with size: ", rank_b, " Got '",
562 lhs_rhs_split[1], "'", " with size: ", lhs_rhs_split[1].size()));
563 }
564
565 for (const char& c : lhs_rhs_split[0]) {
566 left_map[c] = context->Dim(input_a, left_map.size());
567 }
568 for (const char& c : lhs_rhs_split[1]) {
569 right_map[c] = context->Dim(input_b, right_map.size());
570 }
571
572 for (const char& c : equation_split[1]) {
573 if (left_map.count(c)) {
574 dims.push_back(left_map[c]);
575 } else if (right_map.count(c)) {
576 dims.push_back(right_map[c]);
577 } else {
578 return errors::InvalidArgument("Invalid equation: ", equation);
579 }
580 }
581
582 context->set_output(0, context->MakeShape(dims));
583 return Status::OK();
584 })
585 .Doc(R"doc(
586 An op which supports basic einsum op with 2 inputs and 1 output.
587
588 This op has better TPU performnce since it doesn't have explicitly reshape and
589 transpose operations as tf.einsum does.
590 )doc");
591
592 } // namespace
593 } // namespace tensorflow
594