/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

using sparse::SparseTensor;

class SparseTensorsMap : public ResourceBase {
 public:
  explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {}

  string DebugString() const override { return "A SparseTensorsMap"; }

  typedef struct {
    PersistentTensor indices;
    PersistentTensor values;
    gtl::InlinedVector<int64, 8> shape;
  } PersistentSparseTensor;

  Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp,
                         int64* handle) {
    PersistentTensor persistent_ix;
    Tensor* ix;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        sp.indices().dtype(), sp.indices().shape(), &persistent_ix, &ix));
    *ix = sp.indices();

    PersistentTensor persistent_values;
    Tensor* values;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(sp.values().dtype(),
                                                sp.values().shape(),
                                                &persistent_values, &values));
    *values = sp.values();
    {
      mutex_lock l(mu_);
      int64 unique_st_handle = counter_++;  // increment is guarded on purpose
      sp_tensors_[unique_st_handle] = PersistentSparseTensor{
          persistent_ix, persistent_values,
          gtl::InlinedVector<int64, 8>(sp.shape().begin(), sp.shape().end())};
      *handle = unique_st_handle;
    }
    return Status::OK();
  }

  Status RetrieveAndClearSparseTensors(
      OpKernelContext* ctx, const TTypes<int64>::ConstVec& handles,
      std::vector<SparseTensor>* sparse_tensors) {
    sparse_tensors->clear();
    sparse_tensors->reserve(handles.size());
    {
      mutex_lock l(mu_);
      for (size_t i = 0; i < handles.size(); ++i) {
        const int64 handle = handles(i);
        auto sp_iter = sp_tensors_.find(handle);
        if (sp_iter == sp_tensors_.end()) {
          return errors::InvalidArgument(
              "Unable to find SparseTensor: ", handle, " in map: ", name_);
        }
        const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx);
        const Tensor* values = sp_iter->second.values.AccessTensor(ctx);
        const auto& shape = sp_iter->second.shape;
        SparseTensor tensor;
        TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor));
        sparse_tensors->push_back(std::move(tensor));
        sp_tensors_.erase(sp_iter);
      }
    }

    return Status::OK();
  }

 protected:
  ~SparseTensorsMap() override {}

 private:
  string name_;

  mutex mu_;
  int64 counter_ GUARDED_BY(mu_);
  std::unordered_map<int64, PersistentSparseTensor> sp_tensors_ GUARDED_BY(mu_);
};
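// Illustrative sketch (not part of the kernels below) of the handle
// life-cycle SparseTensorsMap provides, assuming `ctx` is an OpKernelContext,
// `sp` is a valid SparseTensor, and `handles` is a TTypes<int64>::ConstVec of
// previously returned handles: AddSparseTensor() copies the components into
// the map and hands back a fresh int64 handle, and a later
// RetrieveAndClearSparseTensors() both returns and erases the entries, so
// each handle can be consumed at most once.
//
//   SparseTensorsMap* map = new SparseTensorsMap("example_map");
//   int64 handle;
//   TF_CHECK_OK(map->AddSparseTensor(ctx, sp, &handle));
//   std::vector<SparseTensor> restored;
//   TF_CHECK_OK(map->RetrieveAndClearSparseTensors(ctx, handles, &restored));
//   map->Unref();  // The kernels below instead leave ownership with the
//                  // ResourceMgr.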
class SparseTensorAccessingOp : public OpKernel {
 public:
  typedef std::function<Status(SparseTensorsMap**)> CreatorCallback;

  explicit SparseTensorAccessingOp(OpKernelConstruction* context)
      : OpKernel(context), sparse_tensors_map_(nullptr) {}

 protected:
  ~SparseTensorAccessingOp() override {
    if (sparse_tensors_map_) sparse_tensors_map_->Unref();
  }

  Status GetMap(OpKernelContext* ctx, bool is_writing,
                SparseTensorsMap** sparse_tensors_map) {
    mutex_lock l(mu_);

    if (sparse_tensors_map_) {
      *sparse_tensors_map = sparse_tensors_map_;
      return Status::OK();
    }

    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(),
                                   is_writing /* use_node_name_as_default */));

    CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) {
      SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name());
      *c = map;
      return Status::OK();
    };

    TF_RETURN_IF_ERROR(
        cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>(
            cinfo_.container(), cinfo_.name(), &sparse_tensors_map_,
            sparse_tensors_map_creator));

    *sparse_tensors_map = sparse_tensors_map_;
    return Status::OK();
  }

 private:
  ContainerInfo cinfo_;

  mutex mu_;
  SparseTensorsMap* sparse_tensors_map_ PT_GUARDED_BY(mu_);
};
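// GetMap() resolves the shared SparseTensorsMap lazily on first use: the
// ResourceMgr either looks up an existing map registered under this kernel's
// (container, shared_name) pair or creates one via the callback above, so
// kernels constructed with matching attributes exchange handles through the
// same map.  The kernels below call it at the top of Compute(), roughly as in
// this sketch (`context` is the kernel's OpKernelContext):
//
//   SparseTensorsMap* map = nullptr;
//   OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));
//   // `map` is owned by the ResourceMgr; the cached reference is released
//   // in ~SparseTensorAccessingOp(), not here.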
class AddSparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddSparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    TensorShape input_shape_object;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape->vec<int64>().data(),
                                input_shape->NumElements(),
                                &input_shape_object));
    SparseTensor st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 input_shape_object, &st));
    int64 handle;
    OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle));

    Tensor sparse_handle(DT_INT64, TensorShape({}));
    auto sparse_handle_t = sparse_handle.scalar<int64>();

    sparse_handle_t() = handle;

    context->set_output(0, sparse_handle);
  }
};

REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU),
                        AddSparseToTensorsMapOp);
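// AddManySparseToTensorsMapOp splits a rank-R SparseTensor along its first
// (minibatch) dimension into N rank-(R-1) SparseTensors and stores each one
// in the shared map.  A small worked example with placeholder values a, b, c:
// given
//
//   indices = [[0, 1], [1, 0], [1, 2]], values = [a, b, c], shape = [2, 3]
//
// the op stores two SparseTensors of shape [3],
//
//   batch 0: indices = [[1]]        values = [a]
//   batch 1: indices = [[0], [2]]   values = [b, c]
//
// and returns a length-2 int64 vector holding their handles.  Batch entries
// with no values get a handle to an empty SparseTensor of shape [3].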
template <typename T>
class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    int rank = input_shape->NumElements();

    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    TensorShape tensor_input_shape(input_shape->vec<int64>());
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 tensor_input_shape, std_order,
                                                 &input_st));

    auto input_shape_t = input_shape->vec<int64>();
    const int64 N = input_shape_t(0);

    Tensor sparse_handles(DT_INT64, TensorShape({N}));
    auto sparse_handles_t = sparse_handles.vec<int64>();

    OP_REQUIRES_OK(context, input_st.IndicesValid());

    // We can generate the per-entry output shape now; it is shared by all
    // minibatch entries.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape_t.data() + 1,
                                input_shape->NumElements() - 1, &output_shape));

    // Get groups by minibatch dimension
    std::unordered_set<int64> visited;
    sparse::GroupIterable minibatch = input_st.group({0});
    for (const auto& subset : minibatch) {
      const int64 b = subset.group()[0];
      visited.insert(b);
      OP_REQUIRES(
          context, b > -1 && b < N,
          errors::InvalidArgument(
              "Received unexpected column 0 value in input SparseTensor: ", b,
              " < 0 or >= N (= ", N, ")"));

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64 num_entries = values.size();

      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64>();
      auto output_values_t = output_values.vec<T>();

      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      SparseTensor st_i;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(output_indices, output_values,
                                          output_shape, &st_i));
      int64 handle;
      OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle));
      sparse_handles_t(b) = handle;
    }

    // Fill in any gaps; we must provide an empty ST for batch entries
    // the grouper didn't find.
    if (visited.size() < N) {
      Tensor empty_indices(DT_INT64, {0, rank - 1});
      Tensor empty_values(DataTypeToEnum<T>::value, {0});
      SparseTensor empty_st;
      OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values,
                                                   output_shape, &empty_st));

      for (int64 b = 0; b < N; ++b) {
        // We skipped this batch entry.
        if (visited.find(b) == visited.end()) {
          int64 handle;
          OP_REQUIRES_OK(context,
                         map->AddSparseTensor(context, empty_st, &handle));
          sparse_handles_t(b) = handle;
        }
      }
    }

    context->set_output(0, sparse_handles);
  }
};

#define REGISTER_KERNELS(type)                              \
  REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \
                              .Device(DEVICE_CPU)           \
                              .TypeConstraint<type>("T"),   \
                          AddManySparseToTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
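// TakeManySparseFromTensorsMapOp is the inverse of the op above: it pops the
// SparseTensor stored under each of its N input handles, prefixes a minibatch
// dimension onto every retrieved tensor's indices and shape, and concatenates
// the results along dimension 0.  Continuing the worked example from
// AddManySparseToTensorsMapOp, taking back the two handles reproduces
//
//   indices = [[0, 1], [1, 0], [1, 2]], values = [a, b, c], shape = [2, 3]
//
// Because SparseTensor::Concat requires matching shapes in all but the
// primary dimension, the per-batch shapes are first padded up to their
// element-wise maximum, so minibatch entries with differing dense shapes are
// concatenated into the largest common shape.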
template <typename T>
class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    SparseTensorsMap* map = nullptr;
    OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map));

    const Tensor& sparse_handles = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()),
                errors::InvalidArgument(
                    "sparse_handles should be a vector but received shape ",
                    sparse_handles.shape().DebugString()));

    int64 N = sparse_handles.shape().dim_size(0);

    OP_REQUIRES(
        context, N > 0,
        errors::InvalidArgument("Must have at least 1 serialized SparseTensor, "
                                "but input matrix has 0 rows"));

    std::vector<Tensor> indices_to_concat;
    std::vector<Tensor> values_to_concat;
    std::vector<TensorShape> shapes_to_concat;

    const auto& sparse_handles_t = sparse_handles.vec<int64>();

    std::vector<SparseTensor> sparse_tensors;

    OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors(
                                context, sparse_handles_t, &sparse_tensors));

    for (int64 i = 0; i < N; ++i) {
      const SparseTensor& st = sparse_tensors[i];
      const Tensor& output_indices = st.indices();
      const Tensor& output_values = st.values();
      const auto output_shape = st.shape();

      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent an index matrix but received shape ",
                      output_indices.shape().DebugString()));
      OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent a values vector but received shape ",
                      output_values.shape().DebugString()));
      OP_REQUIRES(
          context, DataTypeToEnum<T>::value == output_values.dtype(),
          errors::InvalidArgument(
              "Requested SparseTensor of type ",
              DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
              "].values.dtype() == ", DataTypeString(output_values.dtype())));

      int64 num_entries = output_indices.dim_size(0);
      OP_REQUIRES(context, num_entries == output_values.dim_size(0),
                  errors::InvalidArgument(
                      "Expected row counts of SparseTensor[", i,
                      "].indices and SparseTensor[", i,
                      "].values to match but they do not: ", num_entries,
                      " vs. ", output_values.dim_size(0)));
      int rank = output_indices.dim_size(1);
      OP_REQUIRES(
          context, rank == output_shape.size(),
          errors::InvalidArgument("Expected column counts of SparseTensor[", i,
                                  "].indices to match size of SparseTensor[", i,
                                  "].shape "
                                  "but they do not: ",
                                  rank, " vs. ", output_shape.size()));

      // Now we expand each SparseTensor's indices and shape by
      // prefixing a dimension
      Tensor expanded_indices(
          DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
      Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
      const auto& output_indices_t = output_indices.matrix<int64>();
      auto expanded_indices_t = expanded_indices.matrix<int64>();
      auto expanded_shape_t = expanded_shape.vec<int64>();
      expanded_indices_t.chip<1>(0).setZero();
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
      expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
      expanded_shape_t(0) = 1;
      // TODO: copy shape from TensorShape to &expanded_shape_t(1)
      // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
      for (int d = 0; d < rank; ++d) {
        expanded_shape_t(d + 1) = output_shape[d];
      }
      TensorShape expanded_tensor_shape(expanded_shape_t);

      indices_to_concat.push_back(std::move(expanded_indices));
      values_to_concat.push_back(output_values);
      shapes_to_concat.push_back(std::move(expanded_tensor_shape));
    }

    int rank = -1;
    for (int i = 0; i < N; ++i) {
      if (rank < 0) rank = shapes_to_concat[i].dims();
      OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
                  errors::InvalidArgument(
                      "Inconsistent rank across SparseTensors: rank prior to "
                      "SparseTensor[",
                      i, "] was: ", rank, " but rank of SparseTensor[", i,
                      "] is: ", shapes_to_concat[i].dims()));
    }

    // SparseTensor::Concat requires consistent shape for all but the
    // primary order dimension (dimension 0 in this case). So we get
    // the maximum value across all the input SparseTensors for each
    // dimension and use that.
    TensorShape preconcat_shape(shapes_to_concat[0]);
    for (int i = 0; i < N; ++i) {
      for (int d = 0; d < rank; ++d) {
        preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
                                            shapes_to_concat[i].dim_size(d)));
      }
    }

    // Dimension 0 is the primary dimension.
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);

    std::vector<SparseTensor> tensors_to_concat;
    tensors_to_concat.reserve(N);
    for (int i = 0; i < N; ++i) {
      SparseTensor tensor;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(std::move(indices_to_concat[i]),
                                          std::move(values_to_concat[i]),
                                          preconcat_shape, std_order, &tensor));
      tensors_to_concat.push_back(std::move(tensor));
    }

    auto output = SparseTensor::Concat<T>(tensors_to_concat);
    Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));

    std::copy_n(output.shape().data(), output.dims(),
                final_output_shape.vec<int64>().data());

    context->set_output(0, output.indices());
    context->set_output(1, output.values());
    context->set_output(2, final_output_shape);
  }
};

#define REGISTER_KERNELS(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap")  \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<type>("dtype"),   \
                          TakeManySparseFromTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace tensorflow