1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/kernels/data/range_dataset_op.h"

#include <limits>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
22 
23 namespace tensorflow {
24 namespace data {
25 
26 // See documentation in ../../ops/dataset_ops.cc for a high-level
27 // description of the following op.
28 
// Namespace-scope definitions of the `static constexpr` members declared in
// the RangeDatasetOp class (included above via range_dataset_op.h). Before
// C++17 made such members implicitly inline, an odr-used static constexpr data
// member required exactly one out-of-class definition like these.
/* static */ constexpr const char* const RangeDatasetOp::kDatasetType;
/* static */ constexpr const char* const RangeDatasetOp::kStart;
/* static */ constexpr const char* const RangeDatasetOp::kStop;
/* static */ constexpr const char* const RangeDatasetOp::kStep;
/* static */ constexpr const char* const RangeDatasetOp::kOutputTypes;
/* static */ constexpr const char* const RangeDatasetOp::kOutputShapes;
35 
36 namespace {
37 constexpr char kNext[] = "next";
38 constexpr char kHasSplitProvider[] = "has_split_provider";
39 constexpr char kSlash[] = "/";
40 constexpr char kSplitProvider[] = "split_provider";
41 
42 // Class which produces the elements of `range(start, stop, step)`. Threadsafe.
43 class RangeCounter {
44  public:
RangeCounter(int64 start,int64 stop,int64 step)45   RangeCounter(int64 start, int64 stop, int64 step)
46       : start_(start), stop_(stop), step_(step), next_(start) {}
47 
48   // Returns the next value for the counter. Sets `*end_of_counter` to indicate
49   // whether the end of the counter was reached.
GetNext(bool * end_of_counter)50   int64 GetNext(bool* end_of_counter) {
51     mutex_lock l(mu_);
52     if ((step_ > 0 && next_ >= stop_) || (step_ < 0 && next_ <= stop_)) {
53       *end_of_counter = true;
54       return -1;
55     }
56     *end_of_counter = false;
57     int result = next_;
58     next_ += step_;
59     return result;
60   }
61 
Peek() const62   int64 Peek() const {
63     mutex_lock l(mu_);
64     return next_;
65   }
66 
Reset()67   void Reset() {
68     mutex_lock l(mu_);
69     next_ = start_;
70   }
71 
SetNext(int64 value)72   void SetNext(int64 value) {
73     mutex_lock l(mu_);
74     next_ = value;
75   }
76 
77  private:
78   const int64 start_;
79   const int64 stop_;
80   const int64 step_;
81   mutable mutex mu_;
82   int64 next_ TF_GUARDED_BY(mu_);
83 };
84 }  // namespace
85 
86 // Split provider where splits are individual outputs from RangeDataset.
87 // For example, the "splits" of range(0, 10, 2) will be {0, 2, 4, 6, 8}.
88 // The split tensors are scalars of type DT_INT64.
89 class RangeDatasetOp::RangeSplitProvider : public SplitProvider {
90  public:
RangeSplitProvider(int64 start,int64 stop,int64 step)91   RangeSplitProvider(int64 start, int64 stop, int64 step)
92       : counter_(start, stop, step) {}
93 
GetNext(Tensor * split,bool * end_of_splits)94   Status GetNext(Tensor* split, bool* end_of_splits) override {
95     int64 next = counter_.GetNext(end_of_splits);
96     if (*end_of_splits) {
97       return Status::OK();
98     }
99     *split = Tensor(DT_INT64, TensorShape{});
100     split->scalar<int64>()() = next;
101     return Status::OK();
102   }
103 
Reset()104   Status Reset() override {
105     counter_.Reset();
106     return Status::OK();
107   }
108 
Save(std::function<std::string (std::string)> key_name_fn,IteratorStateWriter * writer)109   Status Save(std::function<std::string(std::string)> key_name_fn,
110               IteratorStateWriter* writer) override {
111     TF_RETURN_IF_ERROR(
112         writer->WriteScalar(key_name_fn(kNext), counter_.Peek()));
113     return Status::OK();
114   }
115 
Restore(std::function<std::string (std::string)> key_name_fn,IteratorStateReader * reader)116   Status Restore(std::function<std::string(std::string)> key_name_fn,
117                  IteratorStateReader* reader) override {
118     int64 next;
119     TF_RETURN_IF_ERROR(reader->ReadScalar(key_name_fn(kNext), &next));
120     counter_.SetNext(next);
121     return Status::OK();
122   }
123 
124  private:
125   RangeCounter counter_;
126 };
127 
128 class RangeDatasetOp::Dataset : public DatasetBase {
129  public:
Dataset(OpKernelContext * ctx,int64 start,int64 stop,int64 step,DataTypeVector output_dtypes)130   Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step,
131           DataTypeVector output_dtypes)
132       : DatasetBase(DatasetContext(ctx)),
133         start_(start),
134         stop_(stop),
135         step_(step),
136         output_dtypes_(output_dtypes) {}
137 
MakeIteratorInternal(const string & prefix) const138   std::unique_ptr<IteratorBase> MakeIteratorInternal(
139       const string& prefix) const override {
140     return absl::make_unique<Iterator>(Iterator::Params{
141         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
142   }
143 
output_dtypes() const144   const DataTypeVector& output_dtypes() const override {
145     return output_dtypes_;
146   }
147 
output_shapes() const148   const std::vector<PartialTensorShape>& output_shapes() const override {
149     static std::vector<PartialTensorShape>* shapes =
150         new std::vector<PartialTensorShape>({PartialTensorShape({})});
151     return *shapes;
152   }
153 
DebugString() const154   string DebugString() const override {
155     name_utils::DatasetDebugStringParams params;
156     params.set_args(start_, stop_, step_);
157     return name_utils::DatasetDebugString(kDatasetType, params);
158   }
159 
Cardinality() const160   int64 Cardinality() const override {
161     if (step_ > 0) {
162       return std::max(int64{0}, (stop_ - start_ - 1) / step_ + 1);
163     } else {
164       return std::max(int64{0}, (start_ - stop_ - 1) / -step_ + 1);
165     }
166   }
167 
MakeSplitProvider(std::unique_ptr<SplitProvider> * split_provider) const168   Status MakeSplitProvider(
169       std::unique_ptr<SplitProvider>* split_provider) const override {
170     *split_provider =
171         absl::make_unique<RangeSplitProvider>(start_, stop_, step_);
172     return Status::OK();
173   }
174 
InputDatasets(std::vector<const DatasetBase * > * inputs) const175   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
176     inputs->clear();
177     return Status::OK();
178   }
179 
CheckExternalState() const180   Status CheckExternalState() const override { return Status::OK(); }
181 
182  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const183   Status AsGraphDefInternal(SerializationContext* ctx,
184                             DatasetGraphDefBuilder* b,
185                             Node** output) const override {
186     Node* start = nullptr;
187     Node* stop = nullptr;
188     Node* step = nullptr;
189     TF_RETURN_IF_ERROR(b->AddScalar(start_, &start));
190     TF_RETURN_IF_ERROR(b->AddScalar(stop_, &stop));
191     TF_RETURN_IF_ERROR(b->AddScalar(step_, &step));
192     TF_RETURN_IF_ERROR(b->AddDataset(this, {start, stop, step}, output));
193     return Status::OK();
194   }
195 
196  private:
197   class Iterator : public DatasetIterator<Dataset> {
198    public:
Iterator(const Params & params)199     explicit Iterator(const Params& params)
200         : DatasetIterator<Dataset>(params) {}
201 
Initialize(IteratorContext * ctx)202     Status Initialize(IteratorContext* ctx) override {
203       split_provider_ = ctx->split_provider();
204       if (!split_provider_) {
205         counter_ = absl::make_unique<RangeCounter>(
206             dataset()->start_, dataset()->stop_, dataset()->step_);
207       }
208       return Status::OK();
209     }
210 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)211     Status GetNextInternal(IteratorContext* ctx,
212                            std::vector<Tensor>* out_tensors,
213                            bool* end_of_sequence) override {
214       int64 value;
215       if (split_provider_ != nullptr) {
216         Tensor split;
217         TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
218         if (*end_of_sequence) {
219           return Status::OK();
220         }
221         value = split.scalar<int64>()();
222       } else {
223         value = counter_->GetNext(end_of_sequence);
224         if (*end_of_sequence) {
225           return Status::OK();
226         }
227       }
228       out_tensors->reserve(1);
229       switch (dataset()->output_dtypes()[0]) {
230 #define HANDLE_TYPE(type)                                \
231   case DataTypeToEnum<type>::value: {                    \
232     out_tensors->emplace_back(static_cast<type>(value)); \
233     break;                                               \
234   }
235         TF_CALL_NUMBER_TYPES(HANDLE_TYPE);
236 #undef HANDLE_TYPE
237         default:
238           return errors::InvalidArgument(
239               "Unsupported data type: ",
240               DataTypeString(dataset()->output_dtypes()[0]));
241       }
242       return Status::OK();
243     }
244 
245    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const246     std::shared_ptr<model::Node> CreateNode(
247         IteratorContext* ctx, model::Node::Args args) const override {
248       return model::MakeSourceNode(std::move(args));
249     }
250 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)251     Status SaveInternal(SerializationContext* ctx,
252                         IteratorStateWriter* writer) override {
253       if (split_provider_) {
254         TF_RETURN_IF_ERROR(
255             writer->WriteScalar(full_name(kHasSplitProvider), true));
256         TF_RETURN_IF_ERROR(split_provider_->Save(
257             [this](const std::string& key) {
258               return SplitProviderKeyNameFn(key);
259             },
260             writer));
261       } else {
262         TF_RETURN_IF_ERROR(
263             writer->WriteScalar(full_name(kNext), counter_->Peek()));
264       }
265       return Status::OK();
266     }
267 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)268     Status RestoreInternal(IteratorContext* ctx,
269                            IteratorStateReader* reader) override {
270       if (reader->Contains(full_name(kHasSplitProvider))) {
271         TF_RETURN_IF_ERROR(split_provider_->Restore(
272             [this](const std::string& key) {
273               return SplitProviderKeyNameFn(key);
274             },
275             reader));
276       } else {
277         int64 next;
278         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNext), &next));
279         counter_->SetNext(next);
280       }
281       return Status::OK();
282     }
283 
SplitProviderKeyNameFn(const std::string & key)284     std::string SplitProviderKeyNameFn(const std::string& key) {
285       return full_name(absl::StrCat(kSplitProvider, kSlash, key));
286     }
287 
288    private:
289     std::unique_ptr<RangeCounter> counter_;
290     std::shared_ptr<SplitProvider> split_provider_;
291   };
292 
293   const int64 start_;
294   const int64 stop_;
295   const int64 step_;
296   const DataTypeVector output_dtypes_;
297 };
298 
RangeDatasetOp(OpKernelConstruction * ctx)299 RangeDatasetOp::RangeDatasetOp(OpKernelConstruction* ctx)
300     : DatasetOpKernel(ctx) {
301   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
302 }
303 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)304 void RangeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
305   int64 start;
306   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kStart, &start));
307 
308   int64 stop;
309   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kStop, &stop));
310 
311   int64 step;
312   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kStep, &step));
313   OP_REQUIRES(ctx, step != 0,
314               errors::InvalidArgument("step must be a non-zero integer."));
315 
316   *output = new Dataset(ctx, start, stop, step, output_types_);
317 }
318 
// Registers RangeDatasetOp as the CPU kernel for the "RangeDataset" op. The
// anonymous namespace keeps whatever symbols the macro emits local to this
// translation unit.
namespace {
REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU),
                        RangeDatasetOp);
}  // namespace
323 
324 }  // namespace data
325 }  // namespace tensorflow
326