1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/xla_data.pb.h"
21 #include "tensorflow/core/util/bcast.h"
22 #include "tensorflow/core/util/matmul_bcast.h"
23
24 namespace tensorflow {
25 namespace {
26
27 class MatrixTriangularSolveOp : public XlaOpKernel {
28 public:
MatrixTriangularSolveOp(OpKernelConstruction * ctx)29 explicit MatrixTriangularSolveOp(OpKernelConstruction* ctx)
30 : XlaOpKernel(ctx) {
31 OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_));
32 OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_));
33 }
34
Compile(XlaOpKernelContext * ctx)35 void Compile(XlaOpKernelContext* ctx) override {
36 const TensorShape lhs_shape = ctx->InputShape(0);
37 const TensorShape rhs_shape = ctx->InputShape(1);
38
39 // By TensorFlow conventions the inputs may not have the same
40 // shapes, in which case they will be automatically broadcast if
41 // possible before mapping. Use the standard TensorFlow helper to
42 // compute valid broadcast shapes, but rely below on XLA to
43 // automatically perform the broadcast assuming its valid shapes are
44 // a superset of TensorFlow's valid shapes.
45 MatMulBCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
46 if (!bcast.IsValid()) {
47 ctx->SetStatus(errors::InvalidArgument(
48 "Incompatible shapes: ", lhs_shape.DebugString(), " vs. ",
49 rhs_shape.DebugString()));
50 return;
51 }
52
53 auto lhs_size = lhs_shape.dims();
54 OP_REQUIRES(
55 ctx,
56 lhs_shape.dim_size(lhs_size - 1) == lhs_shape.dim_size(lhs_size - 2),
57 errors::InvalidArgument("The coefficient matrix must be square in "
58 "the inner-most two dimensions: ",
59 lhs_shape.DebugString()));
60
61 xla::XlaOp a = ctx->Input(0);
62 xla::XlaOp b = ctx->Input(1);
63 std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);
64 auto result = xla::TriangularSolve(
65 a, b, /*left_side=*/true,
66 /*lower=*/lower_, /*unit_diagonal=*/false,
67 /*transpose_a=*/
68 adjoint_ ? xla::TriangularSolveOptions::ADJOINT
69 : xla::TriangularSolveOptions::NO_TRANSPOSE);
70 ctx->SetOutput(0, result);
71 }
72
73 private:
74 static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
75 xla::XlaOp lhs, const TensorShape& lhs_shape, xla::XlaOp rhs,
76 const TensorShape& rhs_shape, const MatMulBCast& broadcast_helper);
77 bool lower_;
78 bool adjoint_;
79 };
80
81 /* static */ std::pair<xla::XlaOp, xla::XlaOp>
Broadcast(xla::XlaOp lhs,const TensorShape & lhs_shape,xla::XlaOp rhs,const TensorShape & rhs_shape,const MatMulBCast & broadcast_helper)82 MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape,
83 xla::XlaOp rhs, const TensorShape& rhs_shape,
84 const MatMulBCast& broadcast_helper) {
85 // Get the batch shape.
86 int64 m = lhs_shape.dim_size(lhs_shape.dims() - 1);
87 int64 n = rhs_shape.dim_size(rhs_shape.dims() - 1);
88
89 TensorShape lhs_broadcast_shape(broadcast_helper.output_batch_shape());
90 lhs_broadcast_shape.AddDim(m);
91 lhs_broadcast_shape.AddDim(m);
92 auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes());
93 if (!lhs_output.ok()) {
94 xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
95 return {error, error};
96 }
97
98 TensorShape rhs_broadcast_shape(broadcast_helper.output_batch_shape());
99 rhs_broadcast_shape.AddDim(m);
100 rhs_broadcast_shape.AddDim(n);
101 auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes());
102 if (!rhs_output.ok()) {
103 xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
104 return {error, error};
105 }
106 return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
107 }
108
// Register this kernel as the XLA implementation of MatrixTriangularSolve
// for all XLA-backed devices.
REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp);
110
111 } // namespace
112 } // namespace tensorflow
113