1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <algorithm> 19 #include <numeric> 20 #include <unordered_map> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/register_types.h" 29 #include "tensorflow/core/framework/resource_mgr.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_util.h" 32 #include "tensorflow/core/framework/types.h" 33 #include "tensorflow/core/lib/gtl/inlined_vector.h" 34 #include "tensorflow/core/util/sparse/sparse_tensor.h" 35 36 namespace tensorflow { 37 38 typedef Eigen::ThreadPoolDevice CPUDevice; 39 40 using sparse::SparseTensor; 41 42 class SparseTensorsMap : public ResourceBase { 43 public: SparseTensorsMap(const string & name)44 explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {} 45 DebugString() const46 string DebugString() const override { return "A SparseTensorsMap"; } 47 48 typedef struct { 49 PersistentTensor indices; 50 PersistentTensor values; 51 gtl::InlinedVector<int64, 8> shape; 52 } PersistentSparseTensor; 53 AddSparseTensor(OpKernelContext * ctx,const SparseTensor & sp,int64 * handle)54 Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp, 55 int64* handle) { 56 PersistentTensor persistent_ix; 57 Tensor* ix; 58 TF_RETURN_IF_ERROR(ctx->allocate_persistent( 59 sp.indices().dtype(), sp.indices().shape(), &persistent_ix, &ix)); 60 *ix = sp.indices(); 61 62 PersistentTensor persistent_values; 63 Tensor* values; 64 TF_RETURN_IF_ERROR(ctx->allocate_persistent(sp.indices().dtype(), 65 sp.indices().shape(), 66 &persistent_values, &values)); 67 *values = sp.values(); 68 { 69 mutex_lock l(mu_); 70 int64 unique_st_handle = counter_++; // increment is guarded on purpose 71 sp_tensors_[unique_st_handle] = PersistentSparseTensor{ 72 persistent_ix, persistent_values, 73 gtl::InlinedVector<int64, 8>(sp.shape().begin(), sp.shape().end())}; 74 *handle = unique_st_handle; 75 } 76 return Status::OK(); 77 } 78 RetrieveAndClearSparseTensors(OpKernelContext * ctx,const TTypes<int64>::ConstVec & handles,std::vector<SparseTensor> * sparse_tensors)79 Status RetrieveAndClearSparseTensors( 80 OpKernelContext* ctx, const TTypes<int64>::ConstVec& handles, 81 std::vector<SparseTensor>* sparse_tensors) { 82 sparse_tensors->clear(); 83 sparse_tensors->reserve(handles.size()); 84 { 85 mutex_lock l(mu_); 86 for (size_t i = 0; i < handles.size(); ++i) { 87 const int64 handle = handles(i); 88 auto sp_iter = sp_tensors_.find(handle); 89 if (sp_iter == sp_tensors_.end()) { 90 return errors::InvalidArgument( 91 "Unable to find SparseTensor: ", handle, " in map: ", name_); 92 } 93 const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx); 94 const Tensor* values = sp_iter->second.values.AccessTensor(ctx); 95 const auto& shape = sp_iter->second.shape; 96 SparseTensor tensor; 97 TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor)); 98 sparse_tensors->push_back(std::move(tensor)); 99 sp_tensors_.erase(sp_iter); 100 } 101 } 102 103 return Status::OK(); 104 } 105 106 protected: ~SparseTensorsMap()107 ~SparseTensorsMap() override {} 108 109 private: 110 string name_; 111 112 mutex mu_; 113 int64 counter_ TF_GUARDED_BY(mu_); 114 std::unordered_map<int64, PersistentSparseTensor> sp_tensors_ 115 TF_GUARDED_BY(mu_); 116 }; 117 118 class SparseTensorAccessingOp : public OpKernel { 119 public: 120 typedef std::function<Status(SparseTensorsMap**)> CreatorCallback; 121 SparseTensorAccessingOp(OpKernelConstruction * context)122 explicit SparseTensorAccessingOp(OpKernelConstruction* context) 123 : OpKernel(context), sparse_tensors_map_(nullptr) {} 124 125 protected: ~SparseTensorAccessingOp()126 ~SparseTensorAccessingOp() override { 127 if (sparse_tensors_map_) sparse_tensors_map_->Unref(); 128 } 129 GetMap(OpKernelContext * ctx,bool is_writing,SparseTensorsMap ** sparse_tensors_map)130 Status GetMap(OpKernelContext* ctx, bool is_writing, 131 SparseTensorsMap** sparse_tensors_map) { 132 mutex_lock l(mu_); 133 134 if (sparse_tensors_map_) { 135 *sparse_tensors_map = sparse_tensors_map_; 136 return Status::OK(); 137 } 138 139 TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(), 140 is_writing /* use_node_name_as_default */)); 141 142 CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) { 143 SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name()); 144 *c = map; 145 return Status::OK(); 146 }; 147 148 TF_RETURN_IF_ERROR( 149 cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>( 150 cinfo_.container(), cinfo_.name(), &sparse_tensors_map_, 151 sparse_tensors_map_creator)); 152 153 *sparse_tensors_map = sparse_tensors_map_; 154 return Status::OK(); 155 } 156 157 private: 158 ContainerInfo cinfo_; 159 160 mutex mu_; 161 SparseTensorsMap* sparse_tensors_map_ TF_PT_GUARDED_BY(mu_); 162 }; 163 164 class AddSparseToTensorsMapOp : public SparseTensorAccessingOp { 165 public: AddSparseToTensorsMapOp(OpKernelConstruction * context)166 explicit AddSparseToTensorsMapOp(OpKernelConstruction* context) 167 : SparseTensorAccessingOp(context) {} 168 Compute(OpKernelContext * context)169 void Compute(OpKernelContext* context) override { 170 const Tensor* input_indices; 171 const Tensor* input_values; 172 const Tensor* input_shape; 173 SparseTensorsMap* map; 174 175 OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); 176 OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); 177 OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); 178 OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map)); 179 180 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()), 181 errors::InvalidArgument( 182 "Input indices should be a matrix but received shape ", 183 input_indices->shape().DebugString())); 184 185 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()), 186 errors::InvalidArgument( 187 "Input values should be a vector but received shape ", 188 input_values->shape().DebugString())); 189 190 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()), 191 errors::InvalidArgument( 192 "Input shape should be a vector but received shape ", 193 input_shape->shape().DebugString())); 194 195 TensorShape input_shape_object; 196 OP_REQUIRES_OK(context, 197 TensorShapeUtils::MakeShape(input_shape->vec<int64>().data(), 198 input_shape->NumElements(), 199 &input_shape_object)); 200 SparseTensor st; 201 OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values, 202 input_shape_object, &st)); 203 int64 handle; 204 OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle)); 205 206 Tensor sparse_handle(DT_INT64, TensorShape({})); 207 auto sparse_handle_t = sparse_handle.scalar<int64>(); 208 209 sparse_handle_t() = handle; 210 211 context->set_output(0, sparse_handle); 212 } 213 }; 214 215 REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU), 216 AddSparseToTensorsMapOp); 217 218 template <typename T> 219 class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp { 220 public: AddManySparseToTensorsMapOp(OpKernelConstruction * context)221 explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context) 222 : SparseTensorAccessingOp(context) {} 223 Compute(OpKernelContext * context)224 void Compute(OpKernelContext* context) override { 225 const Tensor* input_indices; 226 const Tensor* input_values; 227 const Tensor* input_shape; 228 SparseTensorsMap* map; 229 230 OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); 231 OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); 232 OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); 233 OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map)); 234 235 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()), 236 errors::InvalidArgument( 237 "Input indices should be a matrix but received shape ", 238 input_indices->shape().DebugString())); 239 240 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()), 241 errors::InvalidArgument( 242 "Input values should be a vector but received shape ", 243 input_values->shape().DebugString())); 244 245 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()), 246 errors::InvalidArgument( 247 "Input shape should be a vector but received shape ", 248 input_shape->shape().DebugString())); 249 250 int rank = input_shape->NumElements(); 251 252 OP_REQUIRES( 253 context, rank > 1, 254 errors::InvalidArgument( 255 "Rank of input SparseTensor should be > 1, but saw rank: ", rank)); 256 257 TensorShape tensor_input_shape(input_shape->vec<int64>()); 258 gtl::InlinedVector<int64, 8> std_order(rank); 259 std::iota(std_order.begin(), std_order.end(), 0); 260 SparseTensor input_st; 261 OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values, 262 tensor_input_shape, std_order, 263 &input_st)); 264 265 auto input_shape_t = input_shape->vec<int64>(); 266 const int64 N = input_shape_t(0); 267 268 Tensor sparse_handles(DT_INT64, TensorShape({N})); 269 auto sparse_handles_t = sparse_handles.vec<int64>(); 270 271 OP_REQUIRES_OK(context, input_st.IndicesValid()); 272 273 // We can generate the output shape proto string now, for all 274 // minibatch entries. 275 TensorShape output_shape; 276 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( 277 input_shape_t.data() + 1, 278 input_shape->NumElements() - 1, &output_shape)); 279 280 // Get groups by minibatch dimension 281 std::unordered_set<int64> visited; 282 sparse::GroupIterable minibatch = input_st.group({0}); 283 for (const auto& subset : minibatch) { 284 const int64 b = subset.group()[0]; 285 visited.insert(b); 286 OP_REQUIRES( 287 context, b > -1 && b < N, 288 errors::InvalidArgument( 289 "Received unexpected column 0 value in input SparseTensor: ", b, 290 " < 0 or >= N (= ", N, ")")); 291 292 const auto indices = subset.indices(); 293 const auto values = subset.values<T>(); 294 const int64 num_entries = values.size(); 295 296 Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1}); 297 Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries}); 298 299 auto output_indices_t = output_indices.matrix<int64>(); 300 auto output_values_t = output_values.vec<T>(); 301 302 for (int i = 0; i < num_entries; ++i) { 303 for (int d = 1; d < rank; ++d) { 304 output_indices_t(i, d - 1) = indices(i, d); 305 } 306 output_values_t(i) = values(i); 307 } 308 309 SparseTensor st_i; 310 OP_REQUIRES_OK(context, 311 SparseTensor::Create(output_indices, output_values, 312 output_shape, &st_i)); 313 int64 handle; 314 OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle)); 315 sparse_handles_t(b) = handle; 316 } 317 318 // Fill in any gaps; we must provide an empty ST for batch entries 319 // the grouper didn't find. 320 if (visited.size() < N) { 321 Tensor empty_indices(DT_INT64, {0, rank - 1}); 322 Tensor empty_values(DataTypeToEnum<T>::value, {0}); 323 SparseTensor empty_st; 324 OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values, 325 output_shape, &empty_st)); 326 327 for (int64 b = 0; b < N; ++b) { 328 // We skipped this batch entry. 329 if (visited.find(b) == visited.end()) { 330 int64 handle; 331 OP_REQUIRES_OK(context, 332 map->AddSparseTensor(context, empty_st, &handle)); 333 sparse_handles_t(b) = handle; 334 } 335 } 336 } 337 338 context->set_output(0, sparse_handles); 339 } 340 }; 341 342 #define REGISTER_KERNELS(type) \ 343 REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \ 344 .Device(DEVICE_CPU) \ 345 .TypeConstraint<type>("T"), \ 346 AddManySparseToTensorsMapOp<type>) 347 348 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 349 #undef REGISTER_KERNELS 350 351 template <typename T> 352 class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp { 353 public: TakeManySparseFromTensorsMapOp(OpKernelConstruction * context)354 explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context) 355 : SparseTensorAccessingOp(context) {} 356 Compute(OpKernelContext * context)357 void Compute(OpKernelContext* context) override { 358 SparseTensorsMap* map = nullptr; 359 OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map)); 360 361 const Tensor& sparse_handles = context->input(0); 362 363 OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()), 364 errors::InvalidArgument( 365 "sparse_handles should be a vector but received shape ", 366 sparse_handles.shape().DebugString())); 367 368 int64 N = sparse_handles.shape().dim_size(0); 369 370 OP_REQUIRES( 371 context, N > 0, 372 errors::InvalidArgument("Must have at least 1 serialized SparseTensor, " 373 "but input matrix has 0 rows")); 374 375 std::vector<Tensor> indices_to_concat; 376 std::vector<Tensor> values_to_concat; 377 std::vector<TensorShape> shapes_to_concat; 378 379 const auto& sparse_handles_t = sparse_handles.vec<int64>(); 380 381 std::vector<SparseTensor> sparse_tensors; 382 383 OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors( 384 context, sparse_handles_t, &sparse_tensors)); 385 386 for (int64 i = 0; i < N; ++i) { 387 const SparseTensor& st = sparse_tensors[i]; 388 const Tensor& output_indices = st.indices(); 389 const Tensor& output_values = st.values(); 390 const auto output_shape = st.shape(); 391 392 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()), 393 errors::InvalidArgument( 394 "Expected sparse_handles[", i, 395 "] to represent an index matrix but received shape ", 396 output_indices.shape().DebugString())); 397 OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()), 398 errors::InvalidArgument( 399 "Expected sparse_handles[", i, 400 "] to represent a values vector but received shape ", 401 output_values.shape().DebugString())); 402 OP_REQUIRES( 403 context, DataTypeToEnum<T>::value == output_values.dtype(), 404 errors::InvalidArgument( 405 "Requested SparseTensor of type ", 406 DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i, 407 "].values.dtype() == ", DataTypeString(output_values.dtype()))); 408 409 int64 num_entries = output_indices.dim_size(0); 410 OP_REQUIRES(context, num_entries == output_values.dim_size(0), 411 errors::InvalidArgument( 412 "Expected row counts of SparseTensor[", i, 413 "].indices and SparseTensor[", i, 414 "].values to match but they do not: ", num_entries, 415 " vs. ", output_values.dim_size(0))); 416 int rank = output_indices.dim_size(1); 417 OP_REQUIRES( 418 context, rank == output_shape.size(), 419 errors::InvalidArgument("Expected column counts of SparseTensor[", i, 420 "].indices to match size of SparseTensor[", i, 421 "].shape " 422 "but they do not: ", 423 rank, " vs. ", output_shape.size())); 424 425 // Now we expand each SparseTensors' indices and shape by 426 // prefixing a dimension 427 Tensor expanded_indices( 428 DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)})); 429 Tensor expanded_shape(DT_INT64, TensorShape({1 + rank})); 430 const auto& output_indices_t = output_indices.matrix<int64>(); 431 auto expanded_indices_t = expanded_indices.matrix<int64>(); 432 auto expanded_shape_t = expanded_shape.vec<int64>(); 433 expanded_indices_t.chip<1>(0).setZero(); 434 Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1); 435 Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank); 436 expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t; 437 expanded_shape_t(0) = 1; 438 // TODO: copy shape from TensorShape to &expanded_shape_t(1) 439 // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1)); 440 for (int i = 0; i < rank; ++i) { 441 expanded_shape_t(i + 1) = output_shape[i]; 442 } 443 TensorShape expanded_tensor_shape(expanded_shape_t); 444 445 indices_to_concat.push_back(std::move(expanded_indices)); 446 values_to_concat.push_back(output_values); 447 shapes_to_concat.push_back(std::move(expanded_tensor_shape)); 448 } 449 450 int rank = -1; 451 for (int i = 0; i < N; ++i) { 452 if (rank < 0) rank = shapes_to_concat[i].dims(); 453 OP_REQUIRES(context, rank == shapes_to_concat[i].dims(), 454 errors::InvalidArgument( 455 "Inconsistent rank across SparseTensors: rank prior to " 456 "SparseTensor[", 457 i, "] was: ", rank, " but rank of SparseTensor[", i, 458 "] is: ", shapes_to_concat[i].dims())); 459 } 460 461 // SparseTensor::Concat requires consistent shape for all but the 462 // primary order dimension (dimension 0 in this case). So we get 463 // the maximum value across all the input SparseTensors for each 464 // dimension and use that. 465 TensorShape preconcat_shape(shapes_to_concat[0]); 466 for (int i = 0; i < N; ++i) { 467 for (int d = 0; d < rank; ++d) { 468 preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d), 469 shapes_to_concat[i].dim_size(d))); 470 } 471 } 472 473 // Dimension 0 is the primary dimension. 474 gtl::InlinedVector<int64, 8> std_order(rank); 475 std::iota(std_order.begin(), std_order.end(), 0); 476 477 std::vector<SparseTensor> tensors_to_concat; 478 tensors_to_concat.reserve(N); 479 for (int i = 0; i < N; ++i) { 480 SparseTensor tensor; 481 OP_REQUIRES_OK(context, 482 SparseTensor::Create(std::move(indices_to_concat[i]), 483 std::move(values_to_concat[i]), 484 preconcat_shape, std_order, &tensor)); 485 tensors_to_concat.push_back(std::move(tensor)); 486 } 487 488 auto output = SparseTensor::Concat<T>(tensors_to_concat); 489 Tensor final_output_shape(DT_INT64, TensorShape({output.dims()})); 490 491 std::copy_n(output.shape().data(), output.dims(), 492 final_output_shape.vec<int64>().data()); 493 494 context->set_output(0, output.indices()); 495 context->set_output(1, output.values()); 496 context->set_output(2, final_output_shape); 497 } 498 }; 499 500 #define REGISTER_KERNELS(type) \ 501 REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \ 502 .Device(DEVICE_CPU) \ 503 .TypeConstraint<type>("dtype"), \ 504 TakeManySparseFromTensorsMapOp<type>) 505 506 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 507 #undef REGISTER_KERNELS 508 509 } // namespace tensorflow 510