1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/take_dataset_op.h"
16 
17 #include "tensorflow/core/framework/partial_tensor_shape.h"
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/kernels/data/name_utils.h"
20 
21 namespace tensorflow {
22 namespace data {
23 
24 /* static */ constexpr const char* const TakeDatasetOp::kDatasetType;
25 /* static */ constexpr const char* const TakeDatasetOp::kInputDataset;
26 /* static */ constexpr const char* const TakeDatasetOp::kCount;
27 /* static */ constexpr const char* const TakeDatasetOp::kOutputTypes;
28 /* static */ constexpr const char* const TakeDatasetOp::kOutputShapes;
29 
30 constexpr char kCurIndex[] = "i";
31 constexpr char kInputImplEmpty[] = "input_impl_empty";
32 constexpr char kEmptyTake[] = "EmptyTake";
33 constexpr char kFiniteTake[] = "FiniteTake";
34 
TakeDataset(OpKernelContext * ctx,int64 count,const DatasetBase * input)35 TakeDataset::TakeDataset(OpKernelContext* ctx, int64 count,
36                          const DatasetBase* input)
37     : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
38   input_->Ref();
39 }
40 
TakeDataset(DatasetContext::Params params,int64 count,const DatasetBase * input)41 TakeDataset::TakeDataset(DatasetContext::Params params, int64 count,
42                          const DatasetBase* input)
43     : DatasetBase(DatasetContext(std::move(params))),
44       count_(count),
45       input_(input) {
46   input_->Ref();
47 }
48 
~TakeDataset()49 TakeDataset::~TakeDataset() { input_->Unref(); }
50 
output_dtypes() const51 const DataTypeVector& TakeDataset::output_dtypes() const {
52   return input_->output_dtypes();
53 }
54 
output_shapes() const55 const std::vector<PartialTensorShape>& TakeDataset::output_shapes() const {
56   return input_->output_shapes();
57 }
58 
DebugString() const59 string TakeDataset::DebugString() const {
60   return name_utils::DatasetDebugString(TakeDatasetOp::kDatasetType);
61 }
62 
Cardinality() const63 int64 TakeDataset::Cardinality() const {
64   int64 n = input_->Cardinality();
65   if (n == kUnknownCardinality) {
66     return kUnknownCardinality;
67   }
68   if (n == kInfiniteCardinality) {
69     return count_;
70   } else if (count_ == kInfiniteCardinality) {
71     return n;
72   }
73 
74   return std::min(n, count_);
75 }
76 
InputDatasets(std::vector<const DatasetBase * > * inputs) const77 Status TakeDataset::InputDatasets(
78     std::vector<const DatasetBase*>* inputs) const {
79   inputs->push_back(input_);
80   return Status::OK();
81 }
82 
CheckExternalState() const83 Status TakeDataset::CheckExternalState() const {
84   return input_->CheckExternalState();
85 }
86 
87 class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
88  public:
EmptyIterator(const Params & params)89   explicit EmptyIterator(const Params& params)
90       : DatasetIterator<TakeDataset>(params) {}
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)91   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
92                          bool* end_of_sequence) override {
93     *end_of_sequence = true;
94     return Status::OK();
95   }
96 
97  protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const98   std::shared_ptr<model::Node> CreateNode(
99       IteratorContext* ctx, model::Node::Args args) const override {
100     return model::MakeKnownRatioNode(std::move(args),
101                                      /*ratio=*/1);
102   }
103 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)104   Status SaveInternal(SerializationContext* ctx,
105                       IteratorStateWriter* writer) override {
106     return Status::OK();
107   }
108 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)109   Status RestoreInternal(IteratorContext* ctx,
110                          IteratorStateReader* reader) override {
111     return Status::OK();
112   }
113 };
114 
115 class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
116  public:
FiniteIterator(const Params & params)117   explicit FiniteIterator(const Params& params)
118       : DatasetIterator<TakeDataset>(params), i_(0) {}
119 
Initialize(IteratorContext * ctx)120   Status Initialize(IteratorContext* ctx) override {
121     return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
122   }
123 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)124   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
125                          bool* end_of_sequence) override {
126     mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
127     if (!input_impl_) {
128       *end_of_sequence = true;
129       return Status::OK();
130     }
131     while (dataset()->count_ < 0 || i_ < dataset()->count_) {
132       TF_RETURN_IF_ERROR(
133           input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
134       if (!*end_of_sequence) {
135         ++i_;
136         return Status::OK();
137       }
138       break;
139     }
140     *end_of_sequence = true;
141     input_impl_.reset();
142     return Status::OK();
143   }
144 
145  protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const146   std::shared_ptr<model::Node> CreateNode(
147       IteratorContext* ctx, model::Node::Args args) const override {
148     return model::MakeKnownRatioNode(std::move(args),
149                                      /*ratio=*/1);
150   }
151 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)152   Status SaveInternal(SerializationContext* ctx,
153                       IteratorStateWriter* writer) override {
154     mutex_lock l(mu_);
155     TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
156     if (input_impl_) {
157       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
158     } else {
159       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
160     }
161     return Status::OK();
162   }
163 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)164   Status RestoreInternal(IteratorContext* ctx,
165                          IteratorStateReader* reader) override {
166     mutex_lock l(mu_);
167     TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i_));
168     if (!reader->Contains(full_name(kInputImplEmpty))) {
169       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
170     } else {
171       input_impl_.reset();
172     }
173     return Status::OK();
174   }
175 
176  private:
177   mutex mu_;
178   int64 i_ TF_GUARDED_BY(mu_);
179   std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
180 };
181 
182 // See documentation in ../../ops/dataset_ops.cc for a high-level
183 // description of the following op.
MakeIteratorInternal(const string & prefix) const184 std::unique_ptr<IteratorBase> TakeDataset::MakeIteratorInternal(
185     const string& prefix) const {
186   if (count_ == 0) {
187     return absl::make_unique<EmptyIterator>(EmptyIterator::Params{
188         this, name_utils::IteratorPrefix(kEmptyTake, prefix)});
189   } else {
190     return absl::make_unique<FiniteIterator>(FiniteIterator::Params{
191         this, name_utils::IteratorPrefix(kFiniteTake, prefix)});
192   }
193 }
194 
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const195 Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx,
196                                        DatasetGraphDefBuilder* b,
197                                        Node** output) const {
198   Node* input_graph_node = nullptr;
199   TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
200   Node* count = nullptr;
201   TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
202   TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
203   return Status::OK();
204 }
205 
TakeDatasetOp(OpKernelConstruction * ctx)206 TakeDatasetOp::TakeDatasetOp(OpKernelConstruction* ctx)
207     : UnaryDatasetOpKernel(ctx) {}
208 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)209 void TakeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
210                                 DatasetBase** output) {
211   // Create a new TakeDatasetOp::Dataset, and return it as the output.
212   int64 count;
213   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
214   *output = new TakeDataset(ctx, count, input);
215 }
216 
namespace {
// Registers TakeDatasetOp as the CPU kernel for the "TakeDataset" op.
REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
}  // namespace
220 }  // namespace data
221 }  // namespace tensorflow
222