/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Scalar, typename Tidx>
class LuOp : public OpKernel {
 public:
  explicit LuOp(OpKernelConstruction* context) : OpKernel(context) {}

 protected:
  using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
  using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;

  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

  using Indices =
      Eigen::Matrix<Tidx, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using IndicesMap = Eigen::Map<Indices>;
  using ConstIndicesMap = Eigen::Map<const Indices>;

 public:
  // Returns the cost per matrix operation. This is used to determine the
  // number of threads to use for parallelizing factorization in batch mode.
  // Cost per unit is assumed to be roughly 1ns, based on comments
  // in core/util/work_sharder.cc.
  // LU decomposition for a square matrix takes roughly (2/3) * (num_rows)^3.
  // TODO(anudhyan): Refine this estimate after taking constant factors into
  // account.
  int64 GetCostPerUnit(const TensorShape& input_matrix_shape) const {
    double num_rows = static_cast<double>(input_matrix_shape.dim_size(0));
    double cost = (2 / 3.0) * MathUtil::IPow(num_rows, 3);
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(context, context->num_inputs() == 1,
                errors::InvalidArgument("Expecting exactly one input, got ",
                                        context->num_inputs()));

    const Tensor& input = context->input(0);
    int input_rank = input.dims();
    OP_REQUIRES(context, input_rank >= 2,
                errors::InvalidArgument(
                    "Input tensor must have rank >= 2, got ", input_rank));

    // If the tensor rank is greater than 2, we consider the inner-most
    // dimensions as matrices, and loop over all the other outer ("batch")
    // dimensions to compute the results.
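    // For example, an input of shape [2, 3, 4, 4] is treated as a batch of
    // 2 * 3 = 6 matrices, each of shape [4, 4]; the permutation output
    // computed below then has shape [2, 3, 4].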
    TensorShape input_matrix_shape;
    TensorShape batch_shape;
    for (int dim = 0; dim < input_rank - 2; ++dim) {
      batch_shape.AddDim(input.dim_size(dim));
    }
    const int64 num_rows = input.dim_size(input_rank - 2);
    const int64 num_cols = input.dim_size(input_rank - 1);

    input_matrix_shape.AppendShape({num_rows, num_cols});
    OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shape),
                errors::InvalidArgument("Input matrix must be square."));

    // packed_triangular_factors is a matrix with the same shape as the input;
    // permutation is a vector.
    TensorShape permutation_shape = batch_shape;
    permutation_shape.AddDim(num_rows);

    TensorShapes output_matrix_shapes({input.shape(), permutation_shape});

    TensorOutputs outputs;
    Tensor* output_packed_triangular_factors = nullptr;
    OP_REQUIRES_OK(
        context,
        context->forward_input_or_allocate_output(
            {0}, 0, input.shape(), &output_packed_triangular_factors));
    outputs.emplace_back(output_packed_triangular_factors);

    Tensor* output_permutation = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, permutation_shape,
                                                     &output_permutation));
    outputs.emplace_back(output_permutation);

    if (num_rows == 0) {
      return;
    }

    // Process the individual matrix problems in parallel using a threadpool.
    auto shard = [this, &input, &num_rows, &num_cols, &outputs,
                  &output_matrix_shapes, context](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        ComputeTensorSlice(context, i, input, num_rows, num_cols, outputs,
                           output_matrix_shapes);
      }
    };
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          batch_shape.num_elements(), GetCostPerUnit(input_matrix_shape),
          shard);
  }

  void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
                          const Tensor& input, int64 num_rows, int64 num_cols,
                          const TensorOutputs& outputs,
                          const TensorShapes& output_matrix_shapes) {
    // TODO(kalakris): Handle alignment if possible. Eigen::Map is
    // unaligned by default.
    ConstMatrixMap input_matrix(
        input.flat<Scalar>().data() + matrix_index * num_rows * num_cols,
        num_rows, num_cols);

    // packed_triangular_factors has shape [num_rows, num_cols]
    MatrixMap packed_triangular_factors(
        outputs[0]->flat<Scalar>().data() + matrix_index * num_rows * num_cols,
        num_rows, num_rows);

    // permutation has shape [num_rows, 1]
    IndicesMap permutation_indices(
        outputs[1]->flat<Tidx>().data() + matrix_index * num_rows, num_rows, 1);

    Eigen::PartialPivLU<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>
        lu_decomposition(input_matrix);

    // Output the packed triangular factors in a dense form.
    // The lower triangular factor L corresponds to the strictly lower
    // triangular part of packed_triangular_factors with an implicit unit
    // diagonal. The upper triangular factor U is the upper triangular part of
    // packed_triangular_factors. The triangular factors satisfy the equation
    //     P * input_matrix = L * U
    // where P is the permutation matrix corresponding to the indices in
    // permutation_indices.
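    // Implementation note: a permutation matrix is orthogonal, so its
    // transpose equals its inverse; transposing Eigen's permutationP()
    // below yields an index vector consistent with the P * input_matrix
    // = L * U convention documented above.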
    packed_triangular_factors = lu_decomposition.matrixLU();
    // Output the permutation matrix used for pivoting.
    Eigen::PermutationMatrix<-1, -1, Tidx> permutation =
        lu_decomposition.permutationP().transpose();
    permutation_indices = permutation.indices();

    // PartialPivLU cannot give strong guarantees on invertibility,
    // but we can at least guard against exact zero pivots. This can occur as
    // a result of basic user mistakes, such as providing integer-valued
    // matrices that are exactly singular, or due to underflow if this
    // code is run with denormals being flushed to zero.
    const RealScalar min_abs_pivot =
        packed_triangular_factors.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input is not invertible."));
  }
};

#define REGISTER_LU(type, idx_type)                                         \
  REGISTER_KERNEL_BUILDER(Name("Lu")                                        \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .TypeConstraint<idx_type>("output_idx_type"), \
                          LuOp<type, idx_type>);

REGISTER_LU(float, int32);
REGISTER_LU(double, int32);
REGISTER_LU(complex64, int32);
REGISTER_LU(complex128, int32);

REGISTER_LU(float, int64);
REGISTER_LU(double, int64);
REGISTER_LU(complex64, int64);
REGISTER_LU(complex128, int64);

}  // namespace tensorflow
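// Usage sketch (illustrative, not part of this file): this kernel implements
// the "Lu" op, exposed in the Python API as tf.linalg.lu. A single row swap
// gives trivial factors, making a convention-independent example:
//   lu, p = tf.linalg.lu([[0., 1.], [1., 0.]])
//   # lu == [[1., 0.], [0., 1.]]  (packed L and U), p == [1, 0]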