1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/kernels/data/split_utils.h"

#include <utility>
16 
17 namespace tensorflow {
18 namespace data {
namespace {
// Key under which IndexSplitProvider checkpoints its current index.
constexpr char kIndex[] = "index";

// Key under which ShardingSplitProvider checkpoints its pending skip count.
constexpr char kNumToSkip[] = "num_to_skip";

// Prefix pieces used by ShardingSplitProvider to build the checkpoint keys of
// the split provider it wraps: "<kSplitProvider><kSlash><key>".
constexpr char kSplitProvider[] = "split_provider";
constexpr char kSlash[] = "/";
}  // namespace
25 
IndexSplitProvider(int64 n)26 IndexSplitProvider::IndexSplitProvider(int64 n) : i_(0), n_(n) {}
27 
GetNext(Tensor * split,bool * end_of_splits)28 Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
29   mutex_lock l(mu_);
30   if (i_ >= n_) {
31     *end_of_splits = true;
32     return Status::OK();
33   }
34   *end_of_splits = false;
35   *split = Tensor(DT_INT64, TensorShape{});
36   split->scalar<int64>()() = i_++;
37   return Status::OK();
38 }
39 
Reset()40 Status IndexSplitProvider::Reset() {
41   mutex_lock l(mu_);
42   i_ = 0;
43   return Status::OK();
44 }
45 
Save(std::function<std::string (std::string)> full_name,IteratorStateWriter * writer)46 Status IndexSplitProvider::Save(
47     std::function<std::string(std::string)> full_name,
48     IteratorStateWriter* writer) {
49   mutex_lock l(mu_);
50   return writer->WriteScalar(full_name(kIndex), i_);
51 }
52 
Restore(std::function<std::string (std::string)> full_name,IteratorStateReader * reader)53 Status IndexSplitProvider::Restore(
54     std::function<std::string(std::string)> full_name,
55     IteratorStateReader* reader) {
56   mutex_lock l(mu_);
57   return reader->ReadScalar(full_name(kIndex), &i_);
58 }
59 
ShardingSplitProvider(int64 num_shards,int64 shard_index,std::shared_ptr<SplitProvider> split_provider)60 ShardingSplitProvider::ShardingSplitProvider(
61     int64 num_shards, int64 shard_index,
62     std::shared_ptr<SplitProvider> split_provider)
63     : num_shards_(num_shards),
64       shard_index_(shard_index),
65       split_provider_(split_provider),
66       num_to_skip_(shard_index_) {}
67 
GetNext(Tensor * split,bool * end_of_splits)68 Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
69   mutex_lock l(mu_);
70   while (num_to_skip_ > 0) {
71     TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
72     if (*end_of_splits) {
73       return Status::OK();
74     }
75     num_to_skip_--;
76   }
77   num_to_skip_ = num_shards_ - 1;
78   TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
79   return Status::OK();
80 }
81 
Reset()82 Status ShardingSplitProvider::Reset() {
83   mutex_lock l(mu_);
84   TF_RETURN_IF_ERROR(split_provider_->Reset());
85   num_to_skip_ = shard_index_;
86   return Status::OK();
87 }
88 
Save(std::function<std::string (std::string)> full_name,IteratorStateWriter * writer)89 Status ShardingSplitProvider::Save(
90     std::function<std::string(std::string)> full_name,
91     IteratorStateWriter* writer) {
92   mutex_lock l(mu_);
93   TF_RETURN_IF_ERROR(split_provider_->Save(
94       [&](const std::string& key) {
95         return full_name(absl::StrCat(kSplitProvider, kSlash, key));
96       },
97       writer));
98   return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
99 }
100 
Restore(std::function<std::string (std::string)> full_name,IteratorStateReader * reader)101 Status ShardingSplitProvider::Restore(
102     std::function<std::string(std::string)> full_name,
103     IteratorStateReader* reader) {
104   mutex_lock l(mu_);
105   TF_RETURN_IF_ERROR(split_provider_->Restore(
106       [&](const std::string& key) {
107         return full_name(absl::StrCat(kSplitProvider, kSlash, key));
108       },
109       reader));
110   TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_));
111   return Status::OK();
112 }
113 
114 }  // namespace data
115 }  // namespace tensorflow
116