// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy
// of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
// ==============================================================================

// TensorFlow kernels and Ops for constructing WALS normal equations.
// TODO(agarwal,rmlarsen): Add security checks to the code.

#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

// This is only used for std::this_thread::get_id().
#include <thread>  // NOLINT

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"

using tensorflow::DEVICE_CPU;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT64;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TensorShapeUtils;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {

// TODO(ataei): Consider using RowMajor maps.
typedef Eigen::Map<
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    EigenMatrixFloatMap;
typedef Eigen::Map<
    const Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    ConstEigenMatrixInt64Map;
typedef Eigen::Map<
    const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    ConstEigenMatrixFloatMap;

// Kernel that accumulates, for each row (or column, when transposed) of a
// sparse input block, the partial left-hand side matrix and right-hand side
// vector of the WALS normal equations.
class WALSComputePartialLhsAndRhsOp : public OpKernel {
 public:
  explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature(
                       {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64,
                        DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL},
                       {DT_FLOAT, DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& factors = context->input(0);
    const Tensor& factor_weights = context->input(1);
    const Tensor& unobserved_weights = context->input(2);
    const Tensor& input_weights = context->input(3);
    const Tensor& input_indices = context->input(4);
    const Tensor& input_values = context->input(5);
    const Tensor& entry_weights = context->input(6);
    const Tensor& input_block_size = context->input(7);
    const Tensor& input_is_transpose = context->input(8);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()),
                InvalidArgument("Input factors should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(factor_weights.shape()),
                InvalidArgument("Input factor_weights should be a vector."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(unobserved_weights.shape()),
        InvalidArgument("Input unobserved_weights should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_weights.shape()),
                InvalidArgument("Input input_weights should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
                InvalidArgument("Input input_indices should be a matrix."));
    OP_REQUIRES(
        context, input_indices.dim_size(1) == 2,
        InvalidArgument("Input input_indices should have shape (?, 2)."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
                InvalidArgument("Input input_values should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(entry_weights.shape()),
                InvalidArgument("Input entry_weights should be a vector."));
    OP_REQUIRES(context, input_indices.dim_size(0) == input_values.dim_size(0),
                InvalidArgument("Input input_values' length should match the "
                                "first dimension of Input input_indices."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()),
                InvalidArgument("Input input_block_size should be a scalar."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(input_is_transpose.shape()),
        InvalidArgument("Input input_is_transpose should be a scalar."));
    OP_REQUIRES(
        context,
        ((input_weights.dim_size(0) > 0 &&
          factor_weights.dim_size(0) == factors.dim_size(0) &&
          entry_weights.dim_size(0) == 0) ||
         (input_weights.dim_size(0) == 0 && factor_weights.dim_size(0) == 0 &&
          entry_weights.dim_size(0) == input_indices.dim_size(0))),
        InvalidArgument("To specify the weights for observed entries, either "
                        "(1) entry_weights must be set or (2) input_weights "
                        "and factor_weights must be set, but not both."));
    // TODO(yifanchen): Deprecate the support of input_weights and
    // factor_weights.
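
    // Summary of the accumulation performed below: for each observed entry i
    // assigned to this row/column (`input_index`), with value v_i, factor
    // vector f_i from the opposite dimension, and weight w_i,
    //   output_lhs[input_index, :, :] += w_i * f_i * f_i^T
    //   output_rhs[input_index, :]    += v_i * (w_0 + w_i) * f_i
    // where w_i is entry_weights(i) when entry_weights is provided, and
    // input_weights(input_index) * factor_weights(factor_index) otherwise.
    // The remaining terms of the full WALS normal equations (such as the
    // unobserved-weight Gramian and regularization) are not added here; hence
    // the "Partial" in the op name.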

    const int64 factor_dim = factors.dim_size(1);
    const int64 factors_size = factors.dim_size(0);
    const int64 num_nonzero_elements = input_indices.dim_size(0);
    const int64 block_size = input_block_size.scalar<int64>()();
    const auto& factor_weights_vec = factor_weights.vec<float>();
    const auto& input_weights_vec = input_weights.vec<float>();
    const float w_0 = unobserved_weights.scalar<float>()();
    const auto& input_values_vec = input_values.vec<float>();
    const auto& entry_weights_vec = entry_weights.vec<float>();

    ConstEigenMatrixFloatMap factors_mat(factors.matrix<float>().data(),
                                         factor_dim, factors_size);
    ConstEigenMatrixInt64Map indices_mat(input_indices.matrix<int64>().data(),
                                         2, num_nonzero_elements);

    Tensor* output_lhs_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({block_size, factor_dim, factor_dim}),
                       &output_lhs_tensor));
    auto output_lhs_t = output_lhs_tensor->tensor<float, 3>();
    output_lhs_t.setZero();
    Tensor* output_rhs_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({block_size, factor_dim}),
                                &output_rhs_tensor));
    EigenMatrixFloatMap rhs_mat(output_rhs_tensor->matrix<float>().data(),
                                factor_dim, block_size);
    rhs_mat.setZero();
    const bool is_transpose = input_is_transpose.scalar<bool>()();

    auto get_input_index = [is_transpose, &indices_mat](int64 i) {
      return is_transpose ? indices_mat(1, i) : indices_mat(0, i);
    };
    auto get_factor_index = [is_transpose, &indices_mat](int64 i) {
      return is_transpose ? indices_mat(0, i) : indices_mat(1, i);
    };

    const bool use_entry_weights = entry_weights_vec.size() > 0;

    // TODO(rmlarsen): In principle, we should be using the SparseTensor class
    // and machinery for iterating over groups, but the fact that class
    // SparseTensor makes a complete copy of the matrix makes me reluctant to
    // use it.
    std::vector<int64> perm(num_nonzero_elements);
    std::iota(perm.begin(), perm.end(), 0);

    typedef std::pair<int64, int64> Shard;
    std::vector<Shard> shards;
    auto worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    int64 shard_total = 0;
    // Compute a permutation such that get_input_index(perm[i]) is sorted; use
    // stable_sort to preserve spatial locality.
    std::stable_sort(perm.begin(), perm.end(),
                     [&get_input_index](int64 i, int64 j) {
                       return get_input_index(i) < get_input_index(j);
                     });

    // Compute the start and end of runs with identical input_index.
    // These are the shards of work that can be processed in parallel
    // without locking.
    int64 start = 0;
    int64 end = 0;
    while (end < num_nonzero_elements) {
      start = end;
      while (end < num_nonzero_elements &&
             get_input_index(perm[start]) == get_input_index(perm[end])) {
        ++end;
      }
      shards.emplace_back(start, end);
      shard_total += end - start;
    }
    CHECK_EQ(shard_total, num_nonzero_elements);
    CHECK_LE(shards.size(), num_nonzero_elements);
    CHECK_GT(shards.size(), 0);

    // Batch the rank-one updates into a rank-k update to lower memory traffic.
    const int kMaxBatchSize = 128;

    // Since we do not have an easy way of generating thread ids within the
    // range [0, num_threads), we instead key an std::unordered_map of
    // batching matrices by a hash of the thread id and initialize each matrix
    // on its first use.
    // However, this might have a performance penalty, as memory allocation can
    // cause the OS kernel to enter a critical section and temporarily disable
    // parallelism, and the unordered_map must be protected with a read/write
    // mutex.
    //
    // TODO(jpoulson): Simplify after the thread rank can be queried.
    std::unordered_map<size_t, Eigen::MatrixXf> factor_batch_map;
    mutex map_mutex;

    BlockingCounter counter(shards.size());
    // Lambda encapsulating the per-shard computation.
    auto work = [&](const Shard& shard) {
      const std::thread::id thread_id = std::this_thread::get_id();
      const size_t id_hash = std::hash<std::thread::id>()(thread_id);
      // If this thread's unique factors_mat.rows() x kMaxBatchSize
      // batching matrix has not yet been created, then emplace it into the
      // map using the hash of the thread id as the key.
      //
      // TODO(jpoulson): Switch to try_emplace once C++17 is supported.
      // TODO(b/72952120): Check whether the 3 lock-unlock pairs can be
      // consolidated into just one.
      map_mutex.lock();
      const auto key_count = factor_batch_map.count(id_hash);
      map_mutex.unlock();
      if (!key_count) {
        map_mutex.lock();
        factor_batch_map.emplace(
            std::piecewise_construct, std::forward_as_tuple(id_hash),
            std::forward_as_tuple(factors_mat.rows(), kMaxBatchSize));
        map_mutex.unlock();
      }
      map_mutex.lock();
      auto& factor_batch = factor_batch_map[id_hash];
      map_mutex.unlock();

      CHECK_GE(shard.first, 0);
      CHECK_LE(shard.second, perm.size());
      CHECK_LE(shard.first, shard.second);
      const int64 input_index = get_input_index(perm[shard.first]);
      const float input_weight =
          use_entry_weights ? 1.0 : input_weights_vec(input_index);
      // Accumulate the rhs and lhs terms in the normal equations
      // for the non-zero elements in the row or column of the sparse matrix
      // corresponding to input_index.
      int num_batched = 0;
      EigenMatrixFloatMap lhs_mat(output_lhs_tensor->flat<float>().data() +
                                      input_index * factor_dim * factor_dim,
                                  factor_dim, factor_dim);
      auto lhs_symm = lhs_mat.selfadjointView<Eigen::Lower>();
      for (int64 p = shard.first; p < shard.second; ++p) {
        const int64 i = perm[p];
        // Check that all entries in the shard have the same input index.
        CHECK_EQ(input_index, get_input_index(i));
        const int64 factor_index = get_factor_index(i);
        const float input_value = input_values_vec(i);
        const float weight =
            use_entry_weights ? entry_weights_vec(i)
                              : input_weight * factor_weights_vec(factor_index);
        CHECK_GE(weight, 0);
        factor_batch.col(num_batched) =
            factors_mat.col(factor_index) * std::sqrt(weight);
        ++num_batched;
        if (num_batched == kMaxBatchSize) {
          lhs_symm.rankUpdate(factor_batch);
          num_batched = 0;
        }

        rhs_mat.col(input_index) +=
            input_value * (w_0 + weight) * factors_mat.col(factor_index);
      }
      if (num_batched != 0) {
        auto factor_block =
            factor_batch.block(0, 0, factors_mat.rows(), num_batched);
        lhs_symm.rankUpdate(factor_block);
      }
      // Copy the lower triangular part of the normal equation matrix to its
      // upper triangular part.
      lhs_mat = lhs_symm;
      counter.DecrementCount();
    };
    for (size_t i = 1; i < shards.size(); ++i) {
      worker_threads.workers->Schedule(std::bind(work, shards[i]));
    }
    // Execute the first shard inline on the calling thread.
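    // Running it here lets the calling thread contribute work instead of
    // blocking immediately in counter.Wait() below (shards is non-empty, as
    // enforced by the CHECK_GT above).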
    work(shards[0]);
    counter.Wait();
  }
};

REGISTER_KERNEL_BUILDER(Name("WALSComputePartialLhsAndRhs").Device(DEVICE_CPU),
                        WALSComputePartialLhsAndRhsOp);

}  // namespace tensorflow