1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "third_party/eigen3/Eigen/Core"
17 #include "third_party/eigen3/Eigen/LU"
18 #include "tensorflow/core/framework/kernel_def_builder.h"
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/tensor_shape.h"
21 #include "tensorflow/core/lib/math/math_util.h"
22 #include "tensorflow/core/platform/types.h"
23 #include "tensorflow/core/util/work_sharder.h"
24 
25 namespace tensorflow {
26 
27 typedef Eigen::ThreadPoolDevice CPUDevice;
28 
29 template <typename Scalar, typename Tidx>
30 class LuOp : public OpKernel {
31  public:
LuOp(OpKernelConstruction * context)32   explicit LuOp(OpKernelConstruction* context) : OpKernel(context) {}
33 
34  protected:
35   using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
36   using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
37 
38   using Matrix =
39       Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
40   using ConstMatrixMap = Eigen::Map<const Matrix>;
41   using MatrixMap = Eigen::Map<Matrix>;
42 
43   using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
44 
45   using Indices =
46       Eigen::Matrix<Tidx, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
47   using IndicesMap = Eigen::Map<Indices>;
48   using ConstIndicesMap = Eigen::Map<const Indices>;
49 
50  public:
51   // Returns the cost per matrix operation. This is used to determine the
52   // number of threads to use for parallelizing factorization in batch mode.
53   // Cost per unit is assumed to be roughly 1ns, based on comments
54   // in core/util/work_sharder.cc.
55   // LU decomposition for a square matrix takes roughly (2/3) * (num_rows)^3.
56   // TODO(anudhyan): Refine this estimate after taking constant factors into
57   // account.
GetCostPerUnit(const TensorShape & input_matrix_shape) const58   int64 GetCostPerUnit(const TensorShape& input_matrix_shape) const {
59     double num_rows = static_cast<double>(input_matrix_shape.dim_size(0));
60     double cost = (2 / 3.0) * MathUtil::IPow(num_rows, 3);
61     return cost >= static_cast<double>(kint64max) ? kint64max
62                                                   : static_cast<int64>(cost);
63   }
64 
Compute(OpKernelContext * context)65   void Compute(OpKernelContext* context) override {
66     OP_REQUIRES(context, context->num_inputs() == 1,
67                 errors::InvalidArgument("Expecting exactly one input, got ",
68                                         context->num_inputs()));
69 
70     const Tensor& input = context->input(0);
71     int input_rank = input.dims();
72     OP_REQUIRES(context, input_rank >= 2,
73                 errors::InvalidArgument(
74                     "Input tensor must have rank >= 2, got ", input_rank));
75 
76     // If the tensor rank is greater than 2, we consider the inner-most
77     // dimensions as matrices, and loop over all the other outer ("batch")
78     // dimensions to compute the results.
79     TensorShape input_matrix_shape;
80     TensorShape batch_shape;
81     for (int dim = 0; dim < input_rank - 2; ++dim) {
82       batch_shape.AddDim(input.dim_size(dim));
83     }
84     const int64 num_rows = input.dim_size(input_rank - 2);
85     const int64 num_cols = input.dim_size(input_rank - 1);
86 
87     input_matrix_shape.AppendShape({num_rows, num_cols});
88     OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shape),
89                 errors::InvalidArgument("Input matrix must be square."));
90 
91     // packed_triangular_factors is a matrix with the same shape as the input;
92     // permutation is a vector.
93     TensorShape permutation_shape = batch_shape;
94     permutation_shape.AddDim(num_rows);
95 
96     TensorShapes output_matrix_shapes({input.shape(), permutation_shape});
97 
98     TensorOutputs outputs;
99     Tensor* output_packed_triangular_factors = nullptr;
100     OP_REQUIRES_OK(
101         context, context->forward_input_or_allocate_output(
102                      {0}, 0, input.shape(), &output_packed_triangular_factors));
103     outputs.emplace_back(output_packed_triangular_factors);
104 
105     Tensor* output_permutation = nullptr;
106     OP_REQUIRES_OK(context, context->allocate_output(1, permutation_shape,
107                                                      &output_permutation));
108     outputs.emplace_back(output_permutation);
109 
110     if (num_rows == 0) {
111       return;
112     }
113 
114     // Process the individual matrix problems in parallel using a threadpool.
115     auto shard = [this, &input, &num_rows, &num_cols, &outputs,
116                   &output_matrix_shapes, context](int64 begin, int64 end) {
117       for (int64 i = begin; i < end; ++i) {
118         ComputeTensorSlice(context, i, input, num_rows, num_cols, outputs,
119                            output_matrix_shapes);
120       }
121     };
122     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
123     Shard(worker_threads.num_threads, worker_threads.workers,
124           batch_shape.num_elements(), GetCostPerUnit(input_matrix_shape),
125           shard);
126   }
127 
ComputeTensorSlice(OpKernelContext * context,int64 matrix_index,const Tensor & input,int64 num_rows,int64 num_cols,const TensorOutputs & outputs,const TensorShapes & output_matrix_shapes)128   void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
129                           const Tensor& input, int64 num_rows, int64 num_cols,
130                           const TensorOutputs& outputs,
131                           const TensorShapes& output_matrix_shapes) {
132     // TODO(kalakris): Handle alignment if possible. Eigen::Map is
133     // unaligned by default.
134     ConstMatrixMap input_matrix(
135         input.flat<Scalar>().data() + matrix_index * num_rows * num_cols,
136         num_rows, num_cols);
137 
138     // packed_triangular_factors has shape [num_rows, num_cols]
139     MatrixMap packed_triangular_factors(
140         outputs[0]->flat<Scalar>().data() + matrix_index * num_rows * num_cols,
141         num_rows, num_rows);
142 
143     // permutation has shape [num_rows, 1]
144     IndicesMap permutation_indices(
145         outputs[1]->flat<Tidx>().data() + matrix_index * num_rows, num_rows, 1);
146 
147     Eigen::PartialPivLU<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>
148         lu_decomposition(input_matrix);
149 
150     // Output the packed triangular factors in a dense form.
151     // The lower triangular factor L corresponds to the strictly lower
152     // triangular part of packed_triangular_factors with an implicit unit
153     // diagonal. The upper triangular factor U is the upper triangular part of
154     // packed_triangular_factors. The triangular factors satisfy the equation
155     //     P * input_matrix = L * U
156     // where P is the permutation matrix corresponding to the indices in
157     // permutation_indices.
158     packed_triangular_factors = lu_decomposition.matrixLU();
159     // Output the permutation matrix used for pivoting.
160     Eigen::PermutationMatrix<-1, -1, Tidx> permutation =
161         lu_decomposition.permutationP().transpose();
162     permutation_indices = permutation.indices();
163 
164     // PartialPivLU cannot give strong guarantees on invertibility,
165     // but we can at least guard against exact zero pivots. This can occur as
166     // a result of basic user mistakes such providing integer valued
167     // matrices that are exactly singular, or due to underflow if this
168     // code is run with denormals being flushed to zero.
169     const RealScalar min_abs_pivot =
170         packed_triangular_factors.diagonal().cwiseAbs().minCoeff();
171     OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
172                 errors::InvalidArgument("Input is not invertible."));
173   }
174 };
175 
176 #define REGISTER_LU(type, idx_type)                                         \
177   REGISTER_KERNEL_BUILDER(Name("Lu")                                        \
178                               .Device(DEVICE_CPU)                           \
179                               .TypeConstraint<type>("T")                    \
180                               .TypeConstraint<idx_type>("output_idx_type"), \
181                           LuOp<type, idx_type>);
182 
183 REGISTER_LU(float, int32);
184 REGISTER_LU(double, int32);
185 REGISTER_LU(complex64, int32);
186 REGISTER_LU(complex128, int32);
187 
188 REGISTER_LU(float, int64);
189 REGISTER_LU(double, int64);
190 REGISTER_LU(complex64, int64);
191 REGISTER_LU(complex128, int64);
192 
193 }  // namespace tensorflow
194