/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific base classes for Unary and Binary Ops.

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"

#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
  const TensorShape lhs_shape = ctx->InputShape(0);
  const TensorShape rhs_shape = ctx->InputShape(1);

  // By TensorFlow conventions the inputs may not have the same
  // shapes, in which case they are automatically broadcast, if
  // possible, before the operation is applied. Use the standard
  // TensorFlow helper to compute valid broadcast shapes, but rely
  // below on XLA to perform the broadcast itself, since XLA's set of
  // valid broadcast shapes is a superset of TensorFlow's.
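  // For example, shapes [2,1,5] and [3,1] broadcast to [2,3,5], while
  // shapes [2,3] and [4] are incompatible.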
  BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
              /*fewer_dims_optimization=*/false);
  if (!bcast.IsValid()) {
    ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
                                           lhs_shape.DebugString(), " vs. ",
                                           rhs_shape.DebugString()));
    return;
  }
  TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());

  // Fetch the expressions containing the input tensors.
  auto lhs_handle = ctx->Input(0);
  auto rhs_handle = ctx->Input(1);

  // If the ranks of the inputs don't match, TensorFlow automatically
  // reshapes the smaller one by padding with dimensions of size 1 as
  // a prefix. In other words, to pad a 5-vector to a 3-dimensional
  // tensor it is reshaped to have shape [1,1,5]. XLA's automatic
  // broadcast code is able to broadcast from lower to higher rank,
  // but doesn't assume you want to pad as a prefix of the dimensions,
  // and instead needs to be told which dimensions of the higher-rank
  // tensor to match to the lower-rank tensor. In this example the
  // dimensions to match would be [2]. If we were matching a matrix
  // against a 4-D tensor the dimensions to match would be [2,3], and
  // so on. extend_dimension encodes the general case.
  std::vector<int64> extend_dimension;
  int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims());
  int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims());
  if (min_rank != max_rank) {
    for (int i = 0; i < min_rank; ++i) {
      // Match the lower rank tensor along the larger-numbered
      // dimensions of the higher rank tensor.
      extend_dimension.push_back(max_rank - min_rank + i);
    }
  }
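
  // Subclasses typically forward extend_dimension as the
  // broadcast_dimensions argument of the corresponding XLA binary op,
  // e.g. xla::Add(lhs, rhs, /*broadcast_dimensions=*/extend_dimension).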

  // Call virtual method to emit the computation.
  xla::XlaOp output =
      Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle,
                  rhs_shape.dim_sizes(), bcast, extend_dimension);

  // The TensorFlow helper computed the post-broadcast shape in
  // bcast_shape above; we rely on the subclass's Computation to
  // implement the same broadcast semantics.
  ctx->SetOutput(0, output);
}

/* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
    xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) {
  auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape());
  if (!lhs_output.ok()) {
    xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
    return {error, error};
  }
  auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape());
  if (!rhs_output.ok()) {
    xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
    return {error, error};
  }
  return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
}
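
// A minimal sketch of how a subclass plugs into XlaBinaryOp, assuming
// the Computation signature declared in cwise_ops.h (the concrete
// kernels are defined elsewhere, e.g. in binary_ops.cc):
//
//   class AddOp : public XlaBinaryOp {
//    public:
//     explicit AddOp(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {}
//     xla::XlaOp Computation(
//         XlaOpKernelContext* ctx, const xla::XlaOp& lhs,
//         const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs,
//         const absl::Span<const int64>& rhs_shape,
//         const BCast& broadcast_helper,
//         const std::vector<int64>& extend_dimensions) override {
//       return xla::Add(lhs, rhs, extend_dimensions);
//     }
//   };
//   REGISTER_XLA_OP(Name("Add"), AddOp);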

}  // namespace tensorflow