// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License.  You may obtain a copy
// of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
// License for the specific language governing permissions and limitations under
// the License.
// ==============================================================================

// TensorFlow kernels and Ops for constructing WALS normal equations.
// TODO(agarwal,rmlarsen): Add security checks to the code.

#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

// This is only used for std::this_thread::get_id()
#include <thread>  // NOLINT

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"

using tensorflow::DEVICE_CPU;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT64;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TensorShapeUtils;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {

// TODO(ataei): Consider using RowMajor maps.
typedef Eigen::Map<
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    EigenMatrixFloatMap;
typedef Eigen::Map<
    const Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    ConstEigenMatrixInt64Map;
typedef Eigen::Map<
    const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    ConstEigenMatrixFloatMap;

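// Computes, for one block of rows (or columns, when is_transpose is set) of
// the input sparse matrix, the partial terms of the WALS normal equations:
// a factor_dim x factor_dim Gramian per block row (output 0) and a
// factor_dim-length right-hand-side vector per block row (output 1).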
class WALSComputePartialLhsAndRhsOp : public OpKernel {
 public:
  explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature(
                       {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64,
                        DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL},
                       {DT_FLOAT, DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& factors = context->input(0);
    const Tensor& factor_weights = context->input(1);
    const Tensor& unobserved_weights = context->input(2);
    const Tensor& input_weights = context->input(3);
    const Tensor& input_indices = context->input(4);
    const Tensor& input_values = context->input(5);
    const Tensor& entry_weights = context->input(6);
    const Tensor& input_block_size = context->input(7);
    const Tensor& input_is_transpose = context->input(8);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()),
                InvalidArgument("Input factors should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(factor_weights.shape()),
                InvalidArgument("Input factor_weights should be a vector."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(unobserved_weights.shape()),
        InvalidArgument("Input unobserved_weights should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_weights.shape()),
                InvalidArgument("Input input_weights should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
                InvalidArgument("Input input_indices should be a matrix."));
    OP_REQUIRES(
        context, input_indices.dim_size(1) == 2,
        InvalidArgument("Input input_indices should have shape (?, 2)."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
                InvalidArgument("Input input_values should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(entry_weights.shape()),
                InvalidArgument("Input entry_weights should be a vector."));
    OP_REQUIRES(context, input_indices.dim_size(0) == input_values.dim_size(0),
                InvalidArgument("The length of input_values should match the "
                                "first dimension of input_indices."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()),
                InvalidArgument("Input input_block_size should be a scalar."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(input_is_transpose.shape()),
        InvalidArgument("Input input_is_transpose should be a scalar."));
    OP_REQUIRES(
        context,
        ((input_weights.dim_size(0) > 0 &&
          factor_weights.dim_size(0) == factors.dim_size(0) &&
          entry_weights.dim_size(0) == 0) ||
         (input_weights.dim_size(0) == 0 && factor_weights.dim_size(0) == 0 &&
          entry_weights.dim_size(0) == input_indices.dim_size(0))),
        InvalidArgument("To specify the weights for observed entries, either "
                        "(1) entry_weights must be set or (2) input_weights "
                        "and factor_weights must be set, but not both."));
    // TODO(yifanchen): Deprecate the support of input_weights and
    // factor_weights.

    const int64 factor_dim = factors.dim_size(1);
    const int64 factors_size = factors.dim_size(0);
    const int64 num_nonzero_elements = input_indices.dim_size(0);
    const int64 block_size = input_block_size.scalar<int64>()();
    const auto& factor_weights_vec = factor_weights.vec<float>();
    const auto& input_weights_vec = input_weights.vec<float>();
    const float w_0 = unobserved_weights.scalar<float>()();
    const auto& input_values_vec = input_values.vec<float>();
    const auto& entry_weights_vec = entry_weights.vec<float>();

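    // TensorFlow tensors are row-major, so mapping their buffers as
    // column-major with swapped dimensions yields transposed views: each
    // column of factors_mat is one factor vector, and each column of
    // indices_mat is one (row, col) index pair.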
    ConstEigenMatrixFloatMap factors_mat(factors.matrix<float>().data(),
                                         factor_dim, factors_size);
    ConstEigenMatrixInt64Map indices_mat(input_indices.matrix<int64>().data(),
                                         2, num_nonzero_elements);

    Tensor* output_lhs_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({block_size, factor_dim, factor_dim}),
                       &output_lhs_tensor));
    auto output_lhs_t = output_lhs_tensor->tensor<float, 3>();
    output_lhs_t.setZero();
    Tensor* output_rhs_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({block_size, factor_dim}),
                                &output_rhs_tensor));
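    // Column i of rhs_mat aliases row i of the row-major output_rhs_tensor,
    // i.e. the rhs vector for input index i within the block.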
    EigenMatrixFloatMap rhs_mat(output_rhs_tensor->matrix<float>().data(),
                                factor_dim, block_size);
    rhs_mat.setZero();
    const bool is_transpose = input_is_transpose.scalar<bool>()();

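    // input_index selects which lhs/rhs block is updated; factor_index selects
    // which factor vector contributes. When is_transpose is set, the two index
    // columns swap roles.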
    auto get_input_index = [is_transpose, &indices_mat](int64 i) {
      return is_transpose ? indices_mat(1, i) : indices_mat(0, i);
    };
    auto get_factor_index = [is_transpose, &indices_mat](int64 i) {
      return is_transpose ? indices_mat(0, i) : indices_mat(1, i);
    };

    const bool use_entry_weights = entry_weights_vec.size() > 0;

    // TODO(rmlarsen): In principle, we should be using the SparseTensor class
    // and machinery for iterating over groups, but the fact that class
    // SparseTensor makes a complete copy of the matrix makes me reluctant to
    // use it.
    std::vector<int64> perm(num_nonzero_elements);
    std::iota(perm.begin(), perm.end(), 0);

    typedef std::pair<int64, int64> Shard;
    std::vector<Shard> shards;
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    int64 shard_total = 0;
    // Compute a permutation such that get_input_index(perm[i]) is sorted; use
    // stable_sort to preserve spatial locality.
    std::stable_sort(perm.begin(), perm.end(),
                     [&get_input_index](int64 i, int64 j) {
                       return get_input_index(i) < get_input_index(j);
                     });

    // Compute the start and end of runs with identical input_index.
    // These are the shards of work that can be processed in parallel
    // without locking.
    int64 start = 0;
    int64 end = 0;
    while (end < num_nonzero_elements) {
      start = end;
      while (end < num_nonzero_elements &&
             get_input_index(perm[start]) == get_input_index(perm[end])) {
        ++end;
      }
      shards.emplace_back(start, end);
      shard_total += end - start;
    }
    CHECK_EQ(shard_total, num_nonzero_elements);
    CHECK_LE(shards.size(), num_nonzero_elements);
    CHECK_GT(shards.size(), 0);

    // Batch the rank-one updates into a rank-k update to lower memory traffic.
    const int kMaxBatchSize = 128;

    // Since we do not have an easy way of generating thread ids within the
    // range [0, num_threads), we instead use a std::unordered_map of matrices,
    // keyed by a hash of the thread id, and initialize each matrix on its
    // first use. However, this might have a performance penalty, as memory
    // allocation can cause the OS kernel to enter a critical section and
    // temporarily disable parallelism, and the unordered_map must be protected
    // with a read/write mutex.
    //
    // TODO(jpoulson): Simplify after the thread rank can be queried.
    std::unordered_map<size_t, Eigen::MatrixXf> factor_batch_map;
    mutex map_mutex;

    BlockingCounter counter(shards.size());
    // Lambda encapsulating the per-shard computation.
    auto work = [&](const Shard& shard) {
      const std::thread::id thread_id = std::this_thread::get_id();
      const size_t id_hash = std::hash<std::thread::id>()(thread_id);
      // If this thread's unique factors_mat.rows() x kMaxBatchSize
      // batching matrix has not yet been created, then emplace it into the
      // map using the hash of the thread id as the key.
      //
      // TODO(jpoulson): Switch to try_emplace once C++17 is supported.
      // TODO(b/72952120): Check whether the 3 lock-unlock pairs can be
      // consolidated into just one.
      map_mutex.lock();
      const auto key_count = factor_batch_map.count(id_hash);
      map_mutex.unlock();
      if (!key_count) {
        map_mutex.lock();
        factor_batch_map.emplace(
            std::piecewise_construct, std::forward_as_tuple(id_hash),
            std::forward_as_tuple(factors_mat.rows(), kMaxBatchSize));
        map_mutex.unlock();
      }
      map_mutex.lock();
      auto& factor_batch = factor_batch_map[id_hash];
      map_mutex.unlock();

      CHECK_GE(shard.first, 0);
      CHECK_LE(shard.second, perm.size());
      CHECK_LE(shard.first, shard.second);
      const int64 input_index = get_input_index(perm[shard.first]);
      const float input_weight =
          use_entry_weights ? 1.0f : input_weights_vec(input_index);
      // Accumulate the rhs and lhs terms in the normal equations
      // for the non-zero elements in the row or column of the sparse matrix
      // corresponding to input_index.
      int num_batched = 0;
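      // lhs_mat aliases the input_index-th factor_dim x factor_dim slice of
      // the lhs output tensor; only its lower triangle is accumulated via
      // lhs_symm and mirrored to the upper triangle at the end of the shard.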
      EigenMatrixFloatMap lhs_mat(output_lhs_tensor->flat<float>().data() +
                                      input_index * factor_dim * factor_dim,
                                  factor_dim, factor_dim);
      auto lhs_symm = lhs_mat.selfadjointView<Eigen::Lower>();
      for (int64 p = shard.first; p < shard.second; ++p) {
        const int64 i = perm[p];
        // Check that all entries in the shard have the same input index.
        CHECK_EQ(input_index, get_input_index(i));
        const int64 factor_index = get_factor_index(i);
        const float input_value = input_values_vec(i);
        const float weight =
            use_entry_weights ? entry_weights_vec(i)
                              : input_weight * factor_weights_vec(factor_index);
        CHECK_GE(weight, 0);
        factor_batch.col(num_batched) =
            factors_mat.col(factor_index) * std::sqrt(weight);
        ++num_batched;
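        // Flush a full batch: rankUpdate adds factor_batch * factor_batch^T
        // (the weighted sum of f_j * f_j^T for the batched columns) to the
        // lower triangle of the normal-equation matrix.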
        if (num_batched == kMaxBatchSize) {
          lhs_symm.rankUpdate(factor_batch);
          num_batched = 0;
        }

        rhs_mat.col(input_index) +=
            input_value * (w_0 + weight) * factors_mat.col(factor_index);
      }
      if (num_batched != 0) {
        auto factor_block =
            factor_batch.block(0, 0, factors_mat.rows(), num_batched);
        lhs_symm.rankUpdate(factor_block);
      }
      // Copy lower triangular to upper triangular part of normal equation
      // matrix.
      lhs_mat = lhs_symm;
      counter.DecrementCount();
    };
    for (size_t i = 1; i < shards.size(); ++i) {
      worker_threads.workers->Schedule(std::bind(work, shards[i]));
    }
    // Execute the first shard inline on the current thread.
    work(shards[0]);
    counter.Wait();
  }
};

REGISTER_KERNEL_BUILDER(Name("WALSComputePartialLhsAndRhs").Device(DEVICE_CPU),
                        WALSComputePartialLhsAndRhsOp);

}  // namespace tensorflow