/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

using sparse::SparseTensor;

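// SparseTensorsMap is a resource that stores SparseTensors keyed by
// monotonically increasing int64 handles.  AddSparseTensor copies a
// SparseTensor's indices and values into persistent storage and returns a
// fresh handle; RetrieveAndClearSparseTensors looks up a batch of handles,
// reconstructs the stored SparseTensors, and erases them from the map, so
// each handle is valid for exactly one retrieval.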
class SparseTensorsMap : public ResourceBase {
 public:
  explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {}

  string DebugString() const override { return "A SparseTensorsMap"; }

  typedef struct {
    PersistentTensor indices;
    PersistentTensor values;
    gtl::InlinedVector<int64, 8> shape;
  } PersistentSparseTensor;

  Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp,
                         int64* handle) {
    PersistentTensor persistent_ix;
    Tensor* ix;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        sp.indices().dtype(), sp.indices().shape(), &persistent_ix, &ix));
    *ix = sp.indices();

    PersistentTensor persistent_values;
    Tensor* values;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(sp.values().dtype(),
                                                sp.values().shape(),
                                                &persistent_values, &values));
    *values = sp.values();
    {
      mutex_lock l(mu_);
      int64 unique_st_handle = counter_++;  // increment is guarded on purpose
      sp_tensors_[unique_st_handle] = PersistentSparseTensor{
          persistent_ix, persistent_values,
          gtl::InlinedVector<int64, 8>(sp.shape().begin(), sp.shape().end())};
      *handle = unique_st_handle;
    }
    return Status::OK();
  }

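  // Looks up each handle in `handles`, reconstructs the corresponding
  // SparseTensor into `sparse_tensors`, and erases the entry from the map.
  // Returns InvalidArgument for any unknown handle, including one that was
  // already consumed by an earlier call.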
  Status RetrieveAndClearSparseTensors(
      OpKernelContext* ctx, const TTypes<int64>::ConstVec& handles,
      std::vector<SparseTensor>* sparse_tensors) {
    sparse_tensors->clear();
    sparse_tensors->reserve(handles.size());
    {
      mutex_lock l(mu_);
      for (size_t i = 0; i < handles.size(); ++i) {
        const int64 handle = handles(i);
        auto sp_iter = sp_tensors_.find(handle);
        if (sp_iter == sp_tensors_.end()) {
          return errors::InvalidArgument(
              "Unable to find SparseTensor: ", handle, " in map: ", name_);
        }
        const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx);
        const Tensor* values = sp_iter->second.values.AccessTensor(ctx);
        const auto& shape = sp_iter->second.shape;
        SparseTensor tensor;
        TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor));
        sparse_tensors->push_back(std::move(tensor));
        sp_tensors_.erase(sp_iter);
      }
    }

    return Status::OK();
  }

 protected:
  ~SparseTensorsMap() override {}

 private:
  string name_;

  mutex mu_;
  int64 counter_ TF_GUARDED_BY(mu_);
  std::unordered_map<int64, PersistentSparseTensor> sp_tensors_
      TF_GUARDED_BY(mu_);
};

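// Base class for the kernels below.  GetMap lazily looks up (or, when
// is_writing is true, creates) the shared SparseTensorsMap resource via the
// resource manager and caches the reference for the lifetime of the kernel.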
class SparseTensorAccessingOp : public OpKernel {
 public:
  typedef std::function<Status(SparseTensorsMap**)> CreatorCallback;

  explicit SparseTensorAccessingOp(OpKernelConstruction* context)
      : OpKernel(context), sparse_tensors_map_(nullptr) {}

 protected:
  ~SparseTensorAccessingOp() override {
    if (sparse_tensors_map_) sparse_tensors_map_->Unref();
  }

  Status GetMap(OpKernelContext* ctx, bool is_writing,
                SparseTensorsMap** sparse_tensors_map) {
    mutex_lock l(mu_);

    if (sparse_tensors_map_) {
      *sparse_tensors_map = sparse_tensors_map_;
      return Status::OK();
    }

    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(),
                                   is_writing /* use_node_name_as_default */));

    CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) {
      SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name());
      *c = map;
      return Status::OK();
    };

    TF_RETURN_IF_ERROR(
        cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>(
            cinfo_.container(), cinfo_.name(), &sparse_tensors_map_,
            sparse_tensors_map_creator));

    *sparse_tensors_map = sparse_tensors_map_;
    return Status::OK();
  }

 private:
  ContainerInfo cinfo_;

  mutex mu_;
  SparseTensorsMap* sparse_tensors_map_ TF_PT_GUARDED_BY(mu_);
};

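// AddSparseToTensorsMapOp stores a single SparseTensor (given by its indices,
// values, and shape inputs) in the shared map and outputs a scalar int64
// handle that can later be consumed by TakeManySparseFromTensorsMap.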
class AddSparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddSparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    TensorShape input_shape_object;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(input_shape->vec<int64>().data(),
                                               input_shape->NumElements(),
                                               &input_shape_object));
    SparseTensor st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 input_shape_object, &st));
    int64 handle;
    OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle));

    Tensor sparse_handle(DT_INT64, TensorShape({}));
    auto sparse_handle_t = sparse_handle.scalar<int64>();

    sparse_handle_t() = handle;

    context->set_output(0, sparse_handle);
  }
};

REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU),
                        AddSparseToTensorsMapOp);

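// AddManySparseToTensorsMapOp splits a rank-R SparseTensor (R > 1) along its
// first (minibatch) dimension into N rank-(R-1) SparseTensors, stores each
// one in the shared map, and outputs a length-N vector of handles.  Batch
// entries with no values are stored as empty SparseTensors.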
template <typename T>
class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    int rank = input_shape->NumElements();

    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    TensorShape tensor_input_shape(input_shape->vec<int64>());
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 tensor_input_shape, std_order,
                                                 &input_st));

    auto input_shape_t = input_shape->vec<int64>();
    const int64 N = input_shape_t(0);

    Tensor sparse_handles(DT_INT64, TensorShape({N}));
    auto sparse_handles_t = sparse_handles.vec<int64>();

    OP_REQUIRES_OK(context, input_st.IndicesValid());

    // The output shape (the input shape with the minibatch dimension
    // stripped) is the same for all minibatch entries, so compute it once.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape_t.data() + 1,
                                input_shape->NumElements() - 1, &output_shape));

    // Get groups by minibatch dimension
    std::unordered_set<int64> visited;
    sparse::GroupIterable minibatch = input_st.group({0});
    for (const auto& subset : minibatch) {
      const int64 b = subset.group()[0];
      visited.insert(b);
      OP_REQUIRES(
          context, b > -1 && b < N,
          errors::InvalidArgument(
              "Received unexpected column 0 value in input SparseTensor: ", b,
              " < 0 or >= N (= ", N, ")"));

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64 num_entries = values.size();

      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64>();
      auto output_values_t = output_values.vec<T>();

      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      SparseTensor st_i;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(output_indices, output_values,
                                          output_shape, &st_i));
      int64 handle;
      OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle));
      sparse_handles_t(b) = handle;
    }

    // Fill in any gaps; we must provide an empty ST for batch entries
    // the grouper didn't find.
    if (static_cast<int64>(visited.size()) < N) {
      Tensor empty_indices(DT_INT64, {0, rank - 1});
      Tensor empty_values(DataTypeToEnum<T>::value, {0});
      SparseTensor empty_st;
      OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values,
                                                   output_shape, &empty_st));

      for (int64 b = 0; b < N; ++b) {
        // We skipped this batch entry.
        if (visited.find(b) == visited.end()) {
          int64 handle;
          OP_REQUIRES_OK(context,
                         map->AddSparseTensor(context, empty_st, &handle));
          sparse_handles_t(b) = handle;
        }
      }
    }

    context->set_output(0, sparse_handles);
  }
};

#define REGISTER_KERNELS(type)                              \
  REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \
                              .Device(DEVICE_CPU)           \
                              .TypeConstraint<type>("T"),   \
                          AddManySparseToTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

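// TakeManySparseFromTensorsMapOp removes the N SparseTensors named by the
// input handle vector from the shared map and concatenates them along a new
// first (minibatch) dimension, emitting the combined indices, values, and
// shape.  All referenced SparseTensors must have the same rank and dtype.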
template <typename T>
class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    SparseTensorsMap* map = nullptr;
    OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map));

    const Tensor& sparse_handles = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()),
                errors::InvalidArgument(
                    "sparse_handles should be a vector but received shape ",
                    sparse_handles.shape().DebugString()));

    int64 N = sparse_handles.shape().dim_size(0);

    OP_REQUIRES(
        context, N > 0,
        errors::InvalidArgument("Must have at least 1 SparseTensor handle, "
                                "but the input vector is empty"));

    std::vector<Tensor> indices_to_concat;
    std::vector<Tensor> values_to_concat;
    std::vector<TensorShape> shapes_to_concat;

    const auto& sparse_handles_t = sparse_handles.vec<int64>();

    std::vector<SparseTensor> sparse_tensors;

    OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors(
                                context, sparse_handles_t, &sparse_tensors));

    for (int64 i = 0; i < N; ++i) {
      const SparseTensor& st = sparse_tensors[i];
      const Tensor& output_indices = st.indices();
      const Tensor& output_values = st.values();
      const auto output_shape = st.shape();

      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent an index matrix but received shape ",
                      output_indices.shape().DebugString()));
      OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent a values vector but received shape ",
                      output_values.shape().DebugString()));
      OP_REQUIRES(
          context, DataTypeToEnum<T>::value == output_values.dtype(),
          errors::InvalidArgument(
              "Requested SparseTensor of type ",
              DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
              "].values.dtype() == ", DataTypeString(output_values.dtype())));

      int64 num_entries = output_indices.dim_size(0);
      OP_REQUIRES(context, num_entries == output_values.dim_size(0),
                  errors::InvalidArgument(
                      "Expected row counts of SparseTensor[", i,
                      "].indices and SparseTensor[", i,
                      "].values to match but they do not: ", num_entries,
                      " vs. ", output_values.dim_size(0)));
      int rank = output_indices.dim_size(1);
      OP_REQUIRES(
          context, rank == output_shape.size(),
          errors::InvalidArgument("Expected column counts of SparseTensor[", i,
                                  "].indices to match size of SparseTensor[", i,
                                  "].shape "
                                  "but they do not: ",
                                  rank, " vs. ", output_shape.size()));

      // Now we expand each SparseTensor's indices and shape by
      // prefixing a dimension.
      Tensor expanded_indices(
          DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
      Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
      const auto& output_indices_t = output_indices.matrix<int64>();
      auto expanded_indices_t = expanded_indices.matrix<int64>();
      auto expanded_shape_t = expanded_shape.vec<int64>();
      expanded_indices_t.chip<1>(0).setZero();
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
      expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
      expanded_shape_t(0) = 1;
      // TODO: copy shape from TensorShape to &expanded_shape_t(1)
      // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
      for (int d = 0; d < rank; ++d) {
        expanded_shape_t(d + 1) = output_shape[d];
      }
      TensorShape expanded_tensor_shape(expanded_shape_t);

      indices_to_concat.push_back(std::move(expanded_indices));
      values_to_concat.push_back(output_values);
      shapes_to_concat.push_back(std::move(expanded_tensor_shape));
    }

    int rank = -1;
    for (int i = 0; i < N; ++i) {
      if (rank < 0) rank = shapes_to_concat[i].dims();
      OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
                  errors::InvalidArgument(
                      "Inconsistent rank across SparseTensors: rank prior to "
                      "SparseTensor[",
                      i, "] was: ", rank, " but rank of SparseTensor[", i,
                      "] is: ", shapes_to_concat[i].dims()));
    }

    // SparseTensor::Concat requires consistent shape for all but the
    // primary order dimension (dimension 0 in this case).  So we get
    // the maximum value across all the input SparseTensors for each
    // dimension and use that.
    TensorShape preconcat_shape(shapes_to_concat[0]);
    for (int i = 0; i < N; ++i) {
      for (int d = 0; d < rank; ++d) {
        preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
                                            shapes_to_concat[i].dim_size(d)));
      }
    }

    // Dimension 0 is the primary dimension.
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);

    std::vector<SparseTensor> tensors_to_concat;
    tensors_to_concat.reserve(N);
    for (int i = 0; i < N; ++i) {
      SparseTensor tensor;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(std::move(indices_to_concat[i]),
                                          std::move(values_to_concat[i]),
                                          preconcat_shape, std_order, &tensor));
      tensors_to_concat.push_back(std::move(tensor));
    }

    auto output = SparseTensor::Concat<T>(tensors_to_concat);
    Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));

    std::copy_n(output.shape().data(), output.dims(),
                final_output_shape.vec<int64>().data());

    context->set_output(0, output.indices());
    context->set_output(1, output.values());
    context->set_output(2, final_output_shape);
  }
};

#define REGISTER_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<type>("dtype"),  \
                          TakeManySparseFromTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace tensorflow