/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace {

// Partial Ordering Comparator for Tensor keys containing scalar int64's
struct KeyTensorLess {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::less<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Key Equality operator for Tensor keys containing scalar int64's
struct KeyTensorEqual {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::equal_to<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Hash for Tensor keys containing scalar int64's
struct KeyTensorHash {
  std::size_t operator()(const Tensor& key) const {
    return std::hash<int64>{}(key.scalar<int64>()());
  }
};

// Primary template.
template <bool Ordered, typename Data>
struct MapTraits;

// Partial specialization for ordered.
template <typename Data>
struct MapTraits<true, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType = std::map<KeyType, Data, KeyTensorLess>;
};

// Partial specialization for unordered.
template <typename Data>
struct MapTraits<false, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType =
      std::unordered_map<KeyType, Data, KeyTensorHash, KeyTensorEqual>;
};

// Wrapper around map/unordered_map.
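// StagingMap is a key/value staging area backed by either a std::map
// (Ordered == true) or a std::unordered_map (Ordered == false). Values are
// tuples of optional tensors, so a tuple can be filled in across several
// partial put() calls before it becomes visible to get()/pop(). All state
// is guarded by mu_: not_empty_ wakes consumers waiting on a key, and
// full_ wakes producers waiting for capacity or memory to free up.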
template <bool Ordered>
class StagingMap : public ResourceBase {
 public:
  // Public typedefs
  using Tuple = std::vector<Tensor>;
  using OptionalTensor = gtl::optional<Tensor>;
  using OptionalTuple = std::vector<OptionalTensor>;

  using MapType = typename MapTraits<Ordered, OptionalTuple>::MapType;
  using KeyType = typename MapTraits<Ordered, OptionalTuple>::KeyType;

  using IncompleteType = typename MapTraits<false, OptionalTuple>::MapType;

 private:
  // Private variables
  DataTypeVector dtypes_ GUARDED_BY(mu_);
  std::size_t capacity_ GUARDED_BY(mu_);
  std::size_t memory_limit_ GUARDED_BY(mu_);
  std::size_t current_bytes_ GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  tensorflow::condition_variable not_empty_;
  tensorflow::condition_variable full_;
  IncompleteType incomplete_ GUARDED_BY(mu_);
  MapType map_ GUARDED_BY(mu_);

 private:
  // private methods

  // If map is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_capacity() || has_memory_limit()) {
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_.notify_all();
    }
  }

  // Notify all removers waiting to extract values
  // that data is now available
  void notify_removers() {
    // Notify all removers. This is because they are
    // waiting for specific keys to appear in the map
    // so we don't know which one to wake up.
    not_empty_.notify_all();
  }

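  // A capacity_ or memory_limit_ of zero means the corresponding bound
  // is not enforced.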
  bool has_capacity() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return capacity_ > 0;
  }

  bool has_memory_limit() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return memory_limit_ > 0;
  }

  bool would_exceed_memory_limit(std::size_t bytes) const
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_memory_limit() && bytes + current_bytes_ > memory_limit_;
  }

  bool is_capacity_full() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_capacity() && map_.size() >= capacity_;
  }

  // Get number of bytes in the tuple
  std::size_t get_tuple_bytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(),
                           static_cast<std::size_t>(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  // Get number of bytes in the incomplete tuple
  std::size_t get_tuple_bytes(const OptionalTuple& tuple) {
    return std::accumulate(
        tuple.begin(), tuple.end(), static_cast<std::size_t>(0),
        [](const std::size_t& lhs, const OptionalTensor& rhs) {
          return lhs + (rhs.has_value() ? rhs.value().TotalBytes() : 0);
        });
  }

  // Check that the index is within bounds
  Status check_index(const Tensor& key, std::size_t index)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (index >= dtypes_.size()) {
      return Status(errors::InvalidArgument(
          "Index '", index, "' for key '", key.scalar<int64>()(),
          "' was out of bounds '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

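  // Copies (copy == true) or moves the tensors stored at the requested
  // indices of map_tuple into *output. Moved-from slots are reset so that
  // later accesses can detect that the value has been consumed.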
  Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key,
                              const Tensor& indices, Tuple* output,
                              bool copy = false) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Return values at specified indices
    for (std::size_t i = 0; i < findices.dimension(0); ++i) {
      std::size_t index = findices(i);

      TF_RETURN_IF_ERROR(check_index(key, index));

      // Insist on a value present at the specified index
      if (!(*map_tuple)[index].has_value()) {
        return Status(errors::InvalidArgument(
            "Tensor at index '", index, "' for key '", key.scalar<int64>()(),
            "' has already been removed."));
      }

      // Copy the contained tensor into the output tuple
      output->push_back((*map_tuple)[index].value());

      // Clear out the entry if we're not copying (moving)
      if (!copy) {
        (*map_tuple)[index].reset();
      }
    }

    return Status::OK();
  }

  // Check that the optional value at the specified index
  // is uninitialized
  Status check_index_uninitialized(const Tensor& key, std::size_t index,
                                   const OptionalTuple& tuple)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (tuple[index].has_value()) {
      return Status(errors::InvalidArgument(
          "The tensor for index '", index, "' for key '", key.scalar<int64>()(),
          "' was already initialized '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

  // Check that the indices are strictly ordered
  Status check_index_ordering(const Tensor& indices) {
    auto findices = indices.flat<int>();

    for (std::size_t i = 0; i < findices.dimension(0) - 1; ++i) {
      if (findices(i) < findices(i + 1)) {
        continue;
      }

      return Status(
          errors::InvalidArgument("Indices are not strictly ordered"));
    }

    return Status::OK();
  }

  // Check that the bytes fit within the memory limit
  Status check_memory_limit(std::size_t bytes) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_memory_limit() && bytes > memory_limit_) {
      return Status(errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'."));
    }

    return Status::OK();
  }

  // Insert incomplete data into the Barrier
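  // Values may arrive over several put() calls, each supplying the tensors
  // for a subset of the tuple's indices. Partially filled tuples live in
  // incomplete_; once every index holds a value, the tuple is moved into
  // the main map and becomes visible to get()/pop().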
  Status put_incomplete(const KeyType& key, const Tensor& indices,
                        OptionalTuple* tuple, tensorflow::mutex_lock* lock)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Search for the key in our incomplete set
    auto it = incomplete_.find(key);

    // Check that the tuple fits within the memory limit
    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until we don't exceed the memory limit
    while (would_exceed_memory_limit(tuple_bytes)) {
      full_.wait(*lock);
    }

    // This key isn't present in the incomplete set
    // Create OptionalTuple and insert
    if (it == incomplete_.end()) {
      OptionalTuple empty(dtypes_.size());

      // Initialize empty tuple with given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));

        // Assign tuple at this index
        empty[index] = std::move((*tuple)[i]);
      }

      // Insert into incomplete map
      incomplete_.insert({key, std::move(empty)});

      // Increment size
      current_bytes_ += tuple_bytes;
    }
    // Found an entry in the incomplete index
    // Update with given data and insert complete entries
    // into the main map
    else {
      // Reference existing incomplete tuple
      OptionalTuple& present = it->second;

      // Assign given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));
        TF_RETURN_IF_ERROR(check_index_uninitialized(key, index, present));

        // Assign tuple at this index
        present[index] = std::move((*tuple)[i]);
      }

      // Increment size
      current_bytes_ += tuple_bytes;

      // Do we have values at all tuple elements?
      bool complete =
          std::all_of(present.begin(), present.end(),
                      [](const OptionalTensor& v) { return v.has_value(); });

      // If so, put the tuple in the actual map
      if (complete) {
        OptionalTuple insert_tuple = std::move(it->second);

        // Remove from incomplete
        incomplete_.erase(it);

        TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple));
      }
    }

    return Status::OK();
  }

  // Does the insertion into the actual staging area
  Status put_complete(const KeyType& key, OptionalTuple* tuple)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // Insert key and tuples into the map
    map_.insert({key, std::move(*tuple)});

    notify_removers();

    return Status::OK();
  }

 public:
  // public methods
  explicit StagingMap(const DataTypeVector& dtypes, std::size_t capacity,
                      std::size_t memory_limit)
      : dtypes_(dtypes),
        capacity_(capacity),
        memory_limit_(memory_limit),
        current_bytes_(0) {}

  Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Handle incomplete inserts
    if (indices->NumElements() != dtypes_.size()) {
      return put_incomplete(*key, *indices, tuple, &lock);
    }

    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    // Check that tuple_bytes fits within the memory limit
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until there's space for insertion.
    while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) {
      full_.wait(lock);
    }

    // Do the put operation
    TF_RETURN_IF_ERROR(put_complete(*key, tuple));

    // Update the current size
    current_bytes_ += tuple_bytes;

    return Status::OK();
  }

  Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple, true));

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    return Status::OK();
  }

  Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Wait until map is not empty
    while (this->map_.empty()) {
      not_empty_.wait(lock);
    }

    // Move from the first element and erase it

    auto it = map_.begin();

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    *key = it->first;

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status clear() {
    tensorflow::mutex_lock lock(mu_);
    map_.clear();
    incomplete_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded();

    return Status::OK();
  }

  std::size_t incomplete_size() {
    tensorflow::mutex_lock lock(mu_);
    return incomplete_.size();
  }

  std::size_t size() {
    tensorflow::mutex_lock lock(mu_);
    return map_.size();
  }

  string DebugString() const override { return "StagingMap"; }
};

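// Looks up, or creates on first use, the StagingMap resource shared by the
// map ops for this container/name, configuring it from the op's 'dtypes',
// 'capacity' and 'memory_limit' attributes.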
template <bool Ordered>
Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef,
                     StagingMap<Ordered>** map) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](StagingMap<Ordered>** ret) -> Status {
    DataTypeVector dtypes;
    int64 capacity;
    int64 memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "dtypes", &dtypes));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new StagingMap<Ordered>(dtypes, capacity, memory_limit);
    return Status::OK();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<StagingMap<Ordered>>(
      cinfo.container(), cinfo.name(), map, create_fn));
  return Status::OK();
}

template <bool Ordered>
class MapStageOp : public OpKernel {
 public:
  explicit MapStageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::OptionalTuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, ctx->input_list("values", &values_tensor));

    // Create copy for insertion into Staging Area
    Tensor key(*key_tensor);

    // Create the tuple to store
    for (std::size_t i = 0; i < values_tensor.size(); ++i) {
      tuple.push_back(values_tensor[i]);
    }

    // Store the tuple in the map
    OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU),
                        MapStageOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapStage").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapStageOp<true>);
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapStageOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapUnstageOp : public OpKernel {
 public:
  explicit MapUnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<true>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapPeekOp : public OpKernel {
 public:
  explicit MapPeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapPeek").Device(DEVICE_CPU), MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
                        MapPeekOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapPeekOp<true>);
#endif

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_SYCL),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapPeekOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapUnstageNoKeyOp : public OpKernel {
 public:
  explicit MapUnstageNoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Pop the first (key, value) pair off the map
    typename StagingMap<Ordered>::KeyType key;
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));

    // Set the key as the first output
    ctx->set_output(0, key);

    // Set the rest of the outputs to the tuple Tensors
    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i + 1, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<true>);
#endif

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageNoKeyOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapSizeOp : public OpKernel {
 public:
  explicit MapSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU),
                        MapSizeOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU).HostMemory("size"),
                        MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_GPU).HostMemory("size"),
    MapSizeOp<true>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_SYCL).HostMemory("size"),
                        MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapSizeOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapIncompleteSizeOp : public OpKernel {
 public:
  explicit MapIncompleteSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->incomplete_size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<true>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapIncompleteSizeOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapClearOp : public OpKernel {
 public:
  explicit MapClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    OP_REQUIRES_OK(ctx, map->clear());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU),
                        MapClearOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_GPU),
                        MapClearOp<true>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_SYCL),
                        MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_SYCL),
                        MapClearOp<true>);
#endif  // TENSORFLOW_USE_SYCL

}  // namespace
}  // namespace tensorflow