1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
35 
36 namespace tensorflow {
37 
38 typedef Eigen::ThreadPoolDevice CPUDevice;
39 
40 using sparse::SparseTensor;
41 
42 class SparseTensorsMap : public ResourceBase {
43  public:
SparseTensorsMap(const string & name)44   explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {}
45 
DebugString() const46   string DebugString() const override { return "A SparseTensorsMap"; }
47 
48   typedef struct {
49     PersistentTensor indices;
50     PersistentTensor values;
51     gtl::InlinedVector<int64, 8> shape;
52   } PersistentSparseTensor;
53 
AddSparseTensor(OpKernelContext * ctx,const SparseTensor & sp,int64 * handle)54   Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp,
55                          int64* handle) {
56     PersistentTensor persistent_ix;
57     Tensor* ix;
58     TF_RETURN_IF_ERROR(ctx->allocate_persistent(
59         sp.indices().dtype(), sp.indices().shape(), &persistent_ix, &ix));
60     *ix = sp.indices();
61 
62     PersistentTensor persistent_values;
63     Tensor* values;
64     TF_RETURN_IF_ERROR(ctx->allocate_persistent(sp.indices().dtype(),
65                                                 sp.indices().shape(),
66                                                 &persistent_values, &values));
67     *values = sp.values();
68     {
69       mutex_lock l(mu_);
70       int64 unique_st_handle = counter_++;  // increment is guarded on purpose
71       sp_tensors_[unique_st_handle] = PersistentSparseTensor{
72           persistent_ix, persistent_values,
73           gtl::InlinedVector<int64, 8>(sp.shape().begin(), sp.shape().end())};
74       *handle = unique_st_handle;
75     }
76     return Status::OK();
77   }
78 
RetrieveAndClearSparseTensors(OpKernelContext * ctx,const TTypes<int64>::ConstVec & handles,std::vector<SparseTensor> * sparse_tensors)79   Status RetrieveAndClearSparseTensors(
80       OpKernelContext* ctx, const TTypes<int64>::ConstVec& handles,
81       std::vector<SparseTensor>* sparse_tensors) {
82     sparse_tensors->clear();
83     sparse_tensors->reserve(handles.size());
84     {
85       mutex_lock l(mu_);
86       for (size_t i = 0; i < handles.size(); ++i) {
87         const int64 handle = handles(i);
88         auto sp_iter = sp_tensors_.find(handle);
89         if (sp_iter == sp_tensors_.end()) {
90           return errors::InvalidArgument(
91               "Unable to find SparseTensor: ", handle, " in map: ", name_);
92         }
93         const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx);
94         const Tensor* values = sp_iter->second.values.AccessTensor(ctx);
95         const auto& shape = sp_iter->second.shape;
96         SparseTensor tensor;
97         TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor));
98         sparse_tensors->push_back(std::move(tensor));
99         sp_tensors_.erase(sp_iter);
100       }
101     }
102 
103     return Status::OK();
104   }
105 
106  protected:
~SparseTensorsMap()107   ~SparseTensorsMap() override {}
108 
109  private:
110   string name_;
111 
112   mutex mu_;
113   int64 counter_ GUARDED_BY(mu_);
114   std::unordered_map<int64, PersistentSparseTensor> sp_tensors_ GUARDED_BY(mu_);
115 };
116 
// Base class for kernels that read from or write to a SparseTensorsMap.
// Lazily looks up (or creates) the shared map in the resource manager on
// first use and caches an owning reference for the kernel's lifetime.
class SparseTensorAccessingOp : public OpKernel {
 public:
  typedef std::function<Status(SparseTensorsMap**)> CreatorCallback;

  explicit SparseTensorAccessingOp(OpKernelConstruction* context)
      : OpKernel(context), sparse_tensors_map_(nullptr) {}

 protected:
  ~SparseTensorAccessingOp() override {
    // Release the reference acquired by LookupOrCreate in GetMap().
    if (sparse_tensors_map_) sparse_tensors_map_->Unref();
  }

  // Returns the shared SparseTensorsMap in *sparse_tensors_map, creating it
  // in the resource manager on first call.  `is_writing` is forwarded as
  // use_node_name_as_default, so a writing op with no explicit shared_name
  // gets a map keyed by its own node name.  Thread-safe: lookup and cache
  // assignment happen under mu_.
  Status GetMap(OpKernelContext* ctx, bool is_writing,
                SparseTensorsMap** sparse_tensors_map) {
    mutex_lock l(mu_);

    // Fast path: the resource was already resolved by an earlier call.
    if (sparse_tensors_map_) {
      *sparse_tensors_map = sparse_tensors_map_;
      return Status::OK();
    }

    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(),
                                   is_writing /* use_node_name_as_default */));

    CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) {
      SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name());
      *c = map;
      return Status::OK();
    };

    // LookupOrCreate returns an owning reference; it is released in the
    // destructor above.
    TF_RETURN_IF_ERROR(
        cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>(
            cinfo_.container(), cinfo_.name(), &sparse_tensors_map_,
            sparse_tensors_map_creator));

    *sparse_tensors_map = sparse_tensors_map_;
    return Status::OK();
  }

 private:
  ContainerInfo cinfo_;  // container/shared-name resolution for the resource

  mutex mu_;
  SparseTensorsMap* sparse_tensors_map_ PT_GUARDED_BY(mu_);
};
162 
163 class AddSparseToTensorsMapOp : public SparseTensorAccessingOp {
164  public:
AddSparseToTensorsMapOp(OpKernelConstruction * context)165   explicit AddSparseToTensorsMapOp(OpKernelConstruction* context)
166       : SparseTensorAccessingOp(context) {}
167 
Compute(OpKernelContext * context)168   void Compute(OpKernelContext* context) override {
169     const Tensor* input_indices;
170     const Tensor* input_values;
171     const Tensor* input_shape;
172     SparseTensorsMap* map;
173 
174     OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
175     OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
176     OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
177     OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));
178 
179     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
180                 errors::InvalidArgument(
181                     "Input indices should be a matrix but received shape ",
182                     input_indices->shape().DebugString()));
183 
184     OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
185                 errors::InvalidArgument(
186                     "Input values should be a vector but received shape ",
187                     input_values->shape().DebugString()));
188 
189     OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
190                 errors::InvalidArgument(
191                     "Input shape should be a vector but received shape ",
192                     input_shape->shape().DebugString()));
193 
194     TensorShape input_shape_object;
195     OP_REQUIRES_OK(context,
196                    TensorShapeUtils::MakeShape(input_shape->vec<int64>().data(),
197                                                input_shape->NumElements(),
198                                                &input_shape_object));
199     SparseTensor st;
200     OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
201                                                  input_shape_object, &st));
202     int64 handle;
203     OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle));
204 
205     Tensor sparse_handle(DT_INT64, TensorShape({}));
206     auto sparse_handle_t = sparse_handle.scalar<int64>();
207 
208     sparse_handle_t() = handle;
209 
210     context->set_output(0, sparse_handle);
211   }
212 };
213 
// CPU-only registration; the op has no value-dtype constraint because the
// values tensor is stored without interpretation.
REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU),
                        AddSparseToTensorsMapOp);
216 
// Splits a rank-R (R > 1) SparseTensor along its first (minibatch)
// dimension into N rank-(R-1) SparseTensors, adds each one to the shared
// SparseTensorsMap, and outputs the N int64 handles.  Minibatch entries
// with no nonzero values get an empty SparseTensor.
template <typename T>
class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    // Structural validation of the three components.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    int rank = input_shape->NumElements();

    // Need at least a minibatch dimension plus one data dimension.
    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    // Build the input SparseTensor in standard (row-major) dimension order.
    TensorShape tensor_input_shape(input_shape->vec<int64>());
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 tensor_input_shape, std_order,
                                                 &input_st));

    auto input_shape_t = input_shape->vec<int64>();
    const int64 N = input_shape_t(0);  // minibatch size

    Tensor sparse_handles(DT_INT64, TensorShape({N}));
    auto sparse_handles_t = sparse_handles.vec<int64>();

    // Grouping below assumes canonically ordered, in-bounds indices.
    OP_REQUIRES_OK(context, input_st.IndicesValid());

    // We can generate the output shape proto string now, for all
    // minibatch entries: it is the input shape minus dimension 0.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape_t.data() + 1,
                                input_shape->NumElements() - 1, &output_shape));

    // Get groups by minibatch dimension
    std::unordered_set<int64> visited;  // batch indices seen in the input
    sparse::GroupIterable minibatch = input_st.group({0});
    for (const auto& subset : minibatch) {
      const int64 b = subset.group()[0];
      visited.insert(b);
      OP_REQUIRES(
          context, b > -1 && b < N,
          errors::InvalidArgument(
              "Received unexpected column 0 value in input SparseTensor: ", b,
              " < 0 or >= N (= ", N, ")"));

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64 num_entries = values.size();

      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64>();
      auto output_values_t = output_values.vec<T>();

      // Copy this group's entries, dropping the minibatch column (d == 0).
      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      SparseTensor st_i;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(output_indices, output_values,
                                          output_shape, &st_i));
      int64 handle;
      OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle));
      sparse_handles_t(b) = handle;
    }

    // Fill in any gaps; we must provide an empty ST for batch entries
    // the grouper didn't find.
    if (visited.size() < N) {
      Tensor empty_indices(DT_INT64, {0, rank - 1});
      Tensor empty_values(DataTypeToEnum<T>::value, {0});
      SparseTensor empty_st;
      OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values,
                                                   output_shape, &empty_st));

      for (int64 b = 0; b < N; ++b) {
        // We skipped this batch entry.
        if (visited.find(b) == visited.end()) {
          int64 handle;
          // Each empty entry still gets its own distinct handle because
          // retrieval from the map is destructive (single-use).
          OP_REQUIRES_OK(context,
                         map->AddSparseTensor(context, empty_st, &handle));
          sparse_handles_t(b) = handle;
        }
      }
    }

    context->set_output(0, sparse_handles);
  }
};
340 
// Register AddManySparseToTensorsMap for every supported value dtype T.
#define REGISTER_KERNELS(type)                              \
  REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \
                              .Device(DEVICE_CPU)           \
                              .TypeConstraint<type>("T"),   \
                          AddManySparseToTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
349 
350 template <typename T>
351 class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
352  public:
TakeManySparseFromTensorsMapOp(OpKernelConstruction * context)353   explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context)
354       : SparseTensorAccessingOp(context) {}
355 
Compute(OpKernelContext * context)356   void Compute(OpKernelContext* context) override {
357     SparseTensorsMap* map = nullptr;
358     OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map));
359 
360     const Tensor& sparse_handles = context->input(0);
361 
362     OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()),
363                 errors::InvalidArgument(
364                     "sparse_handles should be a vector but received shape ",
365                     sparse_handles.shape().DebugString()));
366 
367     int64 N = sparse_handles.shape().dim_size(0);
368 
369     OP_REQUIRES(
370         context, N > 0,
371         errors::InvalidArgument("Must have at least 1 serialized SparseTensor, "
372                                 "but input matrix has 0 rows"));
373 
374     std::vector<Tensor> indices_to_concat;
375     std::vector<Tensor> values_to_concat;
376     std::vector<TensorShape> shapes_to_concat;
377 
378     const auto& sparse_handles_t = sparse_handles.vec<int64>();
379 
380     std::vector<SparseTensor> sparse_tensors;
381 
382     OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors(
383                                 context, sparse_handles_t, &sparse_tensors));
384 
385     for (int64 i = 0; i < N; ++i) {
386       const SparseTensor& st = sparse_tensors[i];
387       const Tensor& output_indices = st.indices();
388       const Tensor& output_values = st.values();
389       const auto output_shape = st.shape();
390 
391       OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
392                   errors::InvalidArgument(
393                       "Expected sparse_handles[", i,
394                       "] to represent an index matrix but received shape ",
395                       output_indices.shape().DebugString()));
396       OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
397                   errors::InvalidArgument(
398                       "Expected sparse_handles[", i,
399                       "] to represent a values vector but received shape ",
400                       output_values.shape().DebugString()));
401       OP_REQUIRES(
402           context, DataTypeToEnum<T>::value == output_values.dtype(),
403           errors::InvalidArgument(
404               "Requested SparseTensor of type ",
405               DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
406               "].values.dtype() == ", DataTypeString(output_values.dtype())));
407 
408       int64 num_entries = output_indices.dim_size(0);
409       OP_REQUIRES(context, num_entries == output_values.dim_size(0),
410                   errors::InvalidArgument(
411                       "Expected row counts of SparseTensor[", i,
412                       "].indices and SparseTensor[", i,
413                       "].values to match but they do not: ", num_entries,
414                       " vs. ", output_values.dim_size(0)));
415       int rank = output_indices.dim_size(1);
416       OP_REQUIRES(
417           context, rank == output_shape.size(),
418           errors::InvalidArgument("Expected column counts of SparseTensor[", i,
419                                   "].indices to match size of SparseTensor[", i,
420                                   "].shape "
421                                   "but they do not: ",
422                                   rank, " vs. ", output_shape.size()));
423 
424       // Now we expand each SparseTensors' indices and shape by
425       // prefixing a dimension
426       Tensor expanded_indices(
427           DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
428       Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
429       const auto& output_indices_t = output_indices.matrix<int64>();
430       auto expanded_indices_t = expanded_indices.matrix<int64>();
431       auto expanded_shape_t = expanded_shape.vec<int64>();
432       expanded_indices_t.chip<1>(0).setZero();
433       Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
434       Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
435       expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
436       expanded_shape_t(0) = 1;
437       // TODO: copy shape from TensorShape to &expanded_shape_t(1)
438       // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
439       for (int i = 0; i < rank; ++i) {
440         expanded_shape_t(i + 1) = output_shape[i];
441       }
442       TensorShape expanded_tensor_shape(expanded_shape_t);
443 
444       indices_to_concat.push_back(std::move(expanded_indices));
445       values_to_concat.push_back(output_values);
446       shapes_to_concat.push_back(std::move(expanded_tensor_shape));
447     }
448 
449     int rank = -1;
450     for (int i = 0; i < N; ++i) {
451       if (rank < 0) rank = shapes_to_concat[i].dims();
452       OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
453                   errors::InvalidArgument(
454                       "Inconsistent rank across SparseTensors: rank prior to "
455                       "SparseTensor[",
456                       i, "] was: ", rank, " but rank of SparseTensor[", i,
457                       "] is: ", shapes_to_concat[i].dims()));
458     }
459 
460     // SparseTensor::Concat requires consistent shape for all but the
461     // primary order dimension (dimension 0 in this case).  So we get
462     // the maximum value across all the input SparseTensors for each
463     // dimension and use that.
464     TensorShape preconcat_shape(shapes_to_concat[0]);
465     for (int i = 0; i < N; ++i) {
466       for (int d = 0; d < rank; ++d) {
467         preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
468                                             shapes_to_concat[i].dim_size(d)));
469       }
470     }
471 
472     // Dimension 0 is the primary dimension.
473     gtl::InlinedVector<int64, 8> std_order(rank);
474     std::iota(std_order.begin(), std_order.end(), 0);
475 
476     std::vector<SparseTensor> tensors_to_concat;
477     tensors_to_concat.reserve(N);
478     for (int i = 0; i < N; ++i) {
479       SparseTensor tensor;
480       OP_REQUIRES_OK(context,
481                      SparseTensor::Create(std::move(indices_to_concat[i]),
482                                           std::move(values_to_concat[i]),
483                                           preconcat_shape, std_order, &tensor));
484       tensors_to_concat.push_back(std::move(tensor));
485     }
486 
487     auto output = SparseTensor::Concat<T>(tensors_to_concat);
488     Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));
489 
490     std::copy_n(output.shape().data(), output.dims(),
491                 final_output_shape.vec<int64>().data());
492 
493     context->set_output(0, output.indices());
494     context->set_output(1, output.values());
495     context->set_output(2, final_output_shape);
496   }
497 };
498 
// Register TakeManySparseFromTensorsMap for every supported value dtype.
#define REGISTER_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<type>("dtype"),  \
                          TakeManySparseFromTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
507 
508 }  // namespace tensorflow
509