/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace {

// Partial ordering comparator for Tensor keys containing scalar int64's
struct KeyTensorLess {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::less<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Key equality operator for Tensor keys containing scalar int64's
struct KeyTensorEqual {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::equal_to<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Hash for Tensor keys containing scalar int64's
struct KeyTensorHash {
  std::size_t operator()(const Tensor& key) const {
    return std::hash<int64>{}(key.scalar<int64>()());
  }
};

// Primary template.
template <bool Ordered, typename Data>
struct MapTraits;

// Partial specialization for ordered.
template <typename Data>
struct MapTraits<true, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType = std::map<KeyType, Data, KeyTensorLess>;
};

// Partial specialization for unordered.
template <typename Data>
struct MapTraits<false, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType =
      std::unordered_map<KeyType, Data, KeyTensorHash, KeyTensorEqual>;
};
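
// For illustration only: given a payload type T, the traits above resolve as
//   MapTraits<true, T>::MapType   -> std::map<Tensor, T, KeyTensorLess>
//   MapTraits<false, T>::MapType  -> std::unordered_map<Tensor, T,
//                                        KeyTensorHash, KeyTensorEqual>
// so the Ordered template parameter below selects between an ordered and a
// hashed container while keeping the same scalar-int64 Tensor key type.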

// Wrapper around map/unordered_map implementing the key/value staging area:
// complete tuples live in map_, partially supplied tuples wait in
// incomplete_, removers block until the requested key is available, and
// inserters block when a capacity or memory limit is configured and reached.
template <bool Ordered>
class StagingMap : public ResourceBase {
 public:
  // Public typedefs
  using Tuple = std::vector<Tensor>;
  using OptionalTensor = gtl::optional<Tensor>;
  using OptionalTuple = std::vector<OptionalTensor>;

  using MapType = typename MapTraits<Ordered, OptionalTuple>::MapType;
  using KeyType = typename MapTraits<Ordered, OptionalTuple>::KeyType;

  using IncompleteType = typename MapTraits<false, OptionalTuple>::MapType;

 private:
  // Private variables
  DataTypeVector dtypes_ GUARDED_BY(mu_);
  std::size_t capacity_ GUARDED_BY(mu_);
  std::size_t memory_limit_ GUARDED_BY(mu_);
  std::size_t current_bytes_ GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  tensorflow::condition_variable not_empty_;
  tensorflow::condition_variable full_;
  IncompleteType incomplete_ GUARDED_BY(mu_);
  MapType map_ GUARDED_BY(mu_);

 private:
  // Private methods

  // If the map is configured with a bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_capacity() || has_memory_limit()) {
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_.notify_all();
    }
  }

  // Notify all removers waiting to extract values
  // that data is now available
  void notify_removers() {
    // Notify all removers. They are waiting for specific keys to
    // appear in the map, so we don't know which one to wake up.
    not_empty_.notify_all();
  }

  bool has_capacity() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return capacity_ > 0;
  }

  bool has_memory_limit() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return memory_limit_ > 0;
  }

  bool would_exceed_memory_limit(std::size_t bytes) const
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_memory_limit() && bytes + current_bytes_ > memory_limit_;
  }

  bool is_capacity_full() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_capacity() && map_.size() >= capacity_;
  }

  // Get the number of bytes in the tuple
  std::size_t get_tuple_bytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(),
                           static_cast<std::size_t>(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  // Get the number of bytes in the incomplete tuple
  std::size_t get_tuple_bytes(const OptionalTuple& tuple) {
    return std::accumulate(
        tuple.begin(), tuple.end(), static_cast<std::size_t>(0),
        [](const std::size_t& lhs, const OptionalTensor& rhs) {
          return lhs + (rhs.has_value() ? rhs.value().TotalBytes() : 0);
        });
  }

  // Check that the index is within bounds
  Status check_index(const Tensor& key, std::size_t index)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (index >= dtypes_.size()) {
      return Status(errors::InvalidArgument(
          "Index '", index, "' for key '", key.scalar<int64>()(),
          "' was out of bounds '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

  Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key,
                              const Tensor& indices, Tuple* output,
                              bool copy = false) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Return values at specified indices
    for (std::size_t i = 0; i < findices.dimension(0); ++i) {
      std::size_t index = findices(i);

      TF_RETURN_IF_ERROR(check_index(key, index));

      // Insist on a value present at the specified index
      if (!(*map_tuple)[index].has_value()) {
        return Status(errors::InvalidArgument(
            "Tensor at index '", index, "' for key '", key.scalar<int64>()(),
            "' has already been removed."));
      }

      // Copy the contained tensor into the output
      output->push_back((*map_tuple)[index].value());

      // Clear out the entry if we're not copying (moving)
      if (!copy) {
        (*map_tuple)[index].reset();
      }
    }

    return Status::OK();
  }

  // Check that the optional value at the specified index
  // is uninitialized
  Status check_index_uninitialized(const Tensor& key, std::size_t index,
                                   const OptionalTuple& tuple)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (tuple[index].has_value()) {
      return Status(errors::InvalidArgument(
          "The tensor for index '", index, "' for key '", key.scalar<int64>()(),
          "' was already initialized."));
    }

    return Status::OK();
  }

  // Check that the indices are strictly ordered
  Status check_index_ordering(const Tensor& indices) {
    auto findices = indices.flat<int>();

    // Fewer than two indices are trivially ordered; this also avoids the
    // unsigned wrap-around in the loop bound below for empty indices.
    if (findices.dimension(0) < 2) {
      return Status::OK();
    }

    for (std::size_t i = 0; i < findices.dimension(0) - 1; ++i) {
      if (findices(i) < findices(i + 1)) {
        continue;
      }

      return Status(
          errors::InvalidArgument("Indices are not strictly ordered"));
    }

    return Status::OK();
  }
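
  // Example of the ordering requirement above (illustrative only): for a
  // staging area with four tuple elements, indices [0, 2, 3] are accepted,
  // while [2, 0] or [1, 1] fail check_index_ordering() because each index
  // must be strictly greater than its predecessor.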

  // Check that the bytes are within the memory limit
  Status check_memory_limit(std::size_t bytes) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_memory_limit() && bytes > memory_limit_) {
      return Status(errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'."));
    }

    return Status::OK();
  }

  // Insert incomplete data into the Barrier
  Status put_incomplete(const KeyType& key, const Tensor& indices,
                        OptionalTuple* tuple, tensorflow::mutex_lock* lock)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Check that the tuple fits within the memory limit
    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until we don't exceed the memory limit
    while (would_exceed_memory_limit(tuple_bytes)) {
      full_.wait(*lock);
    }

    // Search for the key in our incomplete set. This must happen after the
    // wait above, since other threads may modify incomplete_ while the lock
    // is released inside full_.wait().
    auto it = incomplete_.find(key);

    // This key isn't present in the incomplete set
    // Create OptionalTuple and insert
    if (it == incomplete_.end()) {
      OptionalTuple empty(dtypes_.size());

      // Initialize the empty tuple with the given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));

        // Assign tuple at this index
        empty[index] = std::move((*tuple)[i]);
      }

      // Insert into incomplete map
      incomplete_.insert({key, std::move(empty)});

      // Increment size
      current_bytes_ += tuple_bytes;
    }
    // Found an entry in the incomplete index.
    // Update it with the given data and insert complete entries
    // into the main map.
    else {
      // Reference existing incomplete tuple
      OptionalTuple& present = it->second;

      // Assign given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));
        TF_RETURN_IF_ERROR(check_index_uninitialized(key, index, present));

        // Assign tuple at this index
        present[index] = std::move((*tuple)[i]);
      }

      // Increment size
      current_bytes_ += tuple_bytes;

      // Do we have values at all tuple elements?
      bool complete =
          std::all_of(present.begin(), present.end(),
                      [](const OptionalTensor& v) { return v.has_value(); });

      // If so, put the tuple in the actual map
      if (complete) {
        OptionalTuple insert_tuple = std::move(it->second);

        // Remove from incomplete
        incomplete_.erase(it);

        TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple));
      }
    }

    return Status::OK();
  }
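
  // Illustrative example of the flow above (not executed code): with
  // dtypes_.size() == 3, a first put() with indices {0, 2} parks its two
  // tensors in incomplete_ under the key; a later put() with indices {1}
  // fills the remaining slot, at which point the tuple is erased from
  // incomplete_ and handed to put_complete(), making it visible to
  // get()/pop()/popitem() waiters.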

  // Does the insertion into the actual staging area
  Status put_complete(const KeyType& key, OptionalTuple* tuple)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // Insert key and tuples into the map
    map_.insert({key, std::move(*tuple)});

    notify_removers();

    return Status::OK();
  }

 public:
  // Public methods
  explicit StagingMap(const DataTypeVector& dtypes, std::size_t capacity,
                      std::size_t memory_limit)
      : dtypes_(dtypes),
        capacity_(capacity),
        memory_limit_(memory_limit),
        current_bytes_(0) {}

  Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Handle incomplete inserts
    if (indices->NumElements() != dtypes_.size()) {
      return put_incomplete(*key, *indices, tuple, &lock);
    }

    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    // Check that tuple_bytes fits within the memory limit
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until there's space for insertion.
    while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) {
      full_.wait(lock);
    }

    // Do the put operation
    TF_RETURN_IF_ERROR(put_complete(*key, tuple));

    // Update the current size
    current_bytes_ += tuple_bytes;

    return Status::OK();
  }

  Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple, true));

    // get() copies the tensors and leaves them in the map, so the number of
    // bytes held by the Staging Area does not change here.

    return Status::OK();
  }

  Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Wait until map is not empty
    while (this->map_.empty()) {
      not_empty_.wait(lock);
    }

    // Move from the first element and erase it
    auto it = map_.begin();

    // Set the output key before copying so error messages inside
    // copy_or_move_tensors() can refer to it.
    *key = it->first;

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status clear() {
    tensorflow::mutex_lock lock(mu_);
    map_.clear();
    incomplete_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded();

    return Status::OK();
  }

  std::size_t incomplete_size() {
    tensorflow::mutex_lock lock(mu_);
    return incomplete_.size();
  }

  std::size_t size() {
    tensorflow::mutex_lock lock(mu_);
    return map_.size();
  }

  string DebugString() const override { return "StagingMap"; }
};

template <bool Ordered>
Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef,
                     StagingMap<Ordered>** map) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](StagingMap<Ordered>** ret) -> Status {
    DataTypeVector dtypes;
    int64 capacity;
    int64 memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "dtypes", &dtypes));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new StagingMap<Ordered>(dtypes, capacity, memory_limit);
    return Status::OK();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<StagingMap<Ordered>>(
      cinfo.container(), cinfo.name(), map, create_fn));
  return Status::OK();
}
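
// All of the Map*/OrderedMap* kernels below resolve their StagingMap through
// LookupOrCreate, keyed by the resolved container and name, so a stage op and
// an unstage/peek/size op with matching attributes operate on the same
// underlying staging area.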

template <bool Ordered>
class MapStageOp : public OpKernel {
 public:
  explicit MapStageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::OptionalTuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, ctx->input_list("values", &values_tensor));

    // Create copy for insertion into Staging Area
    Tensor key(*key_tensor);

    // Create the tuple to store
    for (std::size_t i = 0; i < values_tensor.size(); ++i) {
      tuple.push_back(values_tensor[i]);
    }

    // Store the tuple in the map
    OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU),
                        MapStageOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapStage").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapStageOp<true>);
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapStageOp<true>);
#endif  // TENSORFLOW_USE_SYCL
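
// Note on the GPU/SYCL registrations above and below: the key and indices
// tensors are pinned to host memory because the staging map is a host-side
// data structure that reads them directly; only the staged values may live
// on the device.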

template <bool Ordered>
class MapUnstageOp : public OpKernel {
 public:
  explicit MapUnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<true>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapPeekOp : public OpKernel {
 public:
  explicit MapPeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapPeek").Device(DEVICE_CPU), MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
                        MapPeekOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapPeekOp<true>);
#endif

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_SYCL),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapPeekOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapUnstageNoKeyOp : public OpKernel {
 public:
  explicit MapUnstageNoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Pop a random (key, value) off the map
    typename StagingMap<Ordered>::KeyType key;
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));

    // Allocate a key tensor and assign the key as the first output
    ctx->set_output(0, key);

    // Set the rest of the outputs to the tuple Tensors
    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i + 1, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<true>);
#endif

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageNoKeyOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapSizeOp : public OpKernel {
 public:
  explicit MapSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU),
                        MapSizeOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU).HostMemory("size"),
                        MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_GPU).HostMemory("size"),
    MapSizeOp<true>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_SYCL).HostMemory("size"),
                        MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapSizeOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapIncompleteSizeOp : public OpKernel {
 public:
  explicit MapIncompleteSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->incomplete_size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<true>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapIncompleteSizeOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapClearOp : public OpKernel {
 public:
  explicit MapClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    OP_REQUIRES_OK(ctx, map->clear());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU),
                        MapClearOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_GPU),
                        MapClearOp<true>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_SYCL),
                        MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_SYCL),
                        MapClearOp<true>);
#endif  // TENSORFLOW_USE_SYCL

}  // namespace
}  // namespace tensorflow