1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
16 
17 #include "tensorflow/core/framework/partial_tensor_shape.h"
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/graph/graph.h"
20 #include "tensorflow/core/kernels/data/dataset_utils.h"
21 #include "tensorflow/core/kernels/data/name_utils.h"
22 #include "tensorflow/core/kernels/data/split_utils.h"
23 #include "tensorflow/core/util/batch_util.h"
24 
25 namespace tensorflow {
26 namespace data {
27 
28 // See documentation in ../../ops/dataset_ops.cc for a high-level
29 // description of the following op.
30 
31 /* static */ constexpr const char* const TensorSliceDatasetOp::kDatasetType;
32 /* static */ constexpr const char* const TensorSliceDatasetOp::kComponents;
33 /* static */ constexpr const char* const TensorSliceDatasetOp::kToutputTypes;
34 /* static */ constexpr const char* const TensorSliceDatasetOp::kOutputShapes;
35 
36 constexpr char kCurIndex[] = "i";
37 
38 class TensorSliceDatasetOp::Dataset : public DatasetBase {
39  public:
Dataset(OpKernelContext * ctx,std::vector<Tensor> tensors)40   explicit Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
41       : DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
42     for (const Tensor& t : tensors_) {
43       dtypes_.push_back(t.dtype());
44       gtl::InlinedVector<int64, 4> element_dim_sizes;
45       // Handle scalar here. Check that everyone matches here? Or fail
46       // at runtime?
47       for (int i = 1; i < t.dims(); ++i) {
48         element_dim_sizes.push_back(t.dim_size(i));
49       }
50       partial_shapes_.emplace_back(element_dim_sizes);
51       shapes_.emplace_back(std::move(element_dim_sizes));
52     }
53   }
54 
MakeIteratorInternal(const string & prefix) const55   std::unique_ptr<IteratorBase> MakeIteratorInternal(
56       const string& prefix) const override {
57     return absl::make_unique<Iterator>(Iterator::Params{
58         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
59   }
60 
MakeSplitProvider(std::unique_ptr<SplitProvider> * split_provider) const61   Status MakeSplitProvider(
62       std::unique_ptr<SplitProvider>* split_provider) const override {
63     *split_provider =
64         absl::make_unique<IndexSplitProvider>(tensors_[0].dim_size(0));
65     return Status::OK();
66   }
67 
output_dtypes() const68   const DataTypeVector& output_dtypes() const override { return dtypes_; }
69 
output_shapes() const70   const std::vector<PartialTensorShape>& output_shapes() const override {
71     return partial_shapes_;
72   }
73 
DebugString() const74   string DebugString() const override {
75     return name_utils::DatasetDebugString(kDatasetType);
76   }
77 
Cardinality() const78   int64 Cardinality() const override { return tensors_[0].dim_size(0); }
79 
InputDatasets(std::vector<const DatasetBase * > * inputs) const80   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
81     return Status::OK();
82   }
83 
CheckExternalState() const84   Status CheckExternalState() const override { return Status::OK(); }
85 
86  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const87   Status AsGraphDefInternal(SerializationContext* ctx,
88                             DatasetGraphDefBuilder* b,
89                             Node** output) const override {
90     std::vector<Node*> components;
91     components.reserve(tensors_.size());
92     for (const Tensor& t : tensors_) {
93       Node* node;
94       if (ctx->serialize_data_tensors()) {
95         TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
96       } else {
97         TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
98         DCHECK_NE(ctx->input_list(), nullptr);
99         ctx->input_list()->emplace_back(node->name(), t);
100       }
101       components.emplace_back(node);
102     }
103     AttrValue dtypes;
104     b->BuildAttrValue(dtypes_, &dtypes);
105     TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}},
106                                      {{kToutputTypes, dtypes}}, output));
107     return Status::OK();
108   }
109 
110  private:
111   class Iterator : public DatasetIterator<Dataset> {
112    public:
Iterator(const Params & params)113     explicit Iterator(const Params& params)
114         : DatasetIterator<Dataset>(params) {}
115 
Initialize(IteratorContext * ctx)116     Status Initialize(IteratorContext* ctx) override {
117       split_provider_ = ctx->split_provider();
118       if (split_provider_ == nullptr) {
119         split_provider_ = std::make_shared<IndexSplitProvider>(
120             dataset()->tensors_[0].dim_size(0));
121       }
122       return Status::OK();
123     }
124 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)125     Status GetNextInternal(IteratorContext* ctx,
126                            std::vector<Tensor>* out_tensors,
127                            bool* end_of_sequence) override {
128       Tensor split;
129       TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
130       if (*end_of_sequence) {
131         return Status::OK();
132       }
133       int64 index = split.scalar<int64>()();
134       out_tensors->clear();
135       out_tensors->reserve(dataset()->tensors_.size());
136       for (size_t i = 0; i < dataset()->tensors_.size(); ++i) {
137         const Tensor& t = dataset()->tensors_[i];
138         out_tensors->emplace_back(ctx->allocator({}), t.dtype(),
139                                   dataset()->shapes_[i]);
140         TF_RETURN_IF_ERROR(
141             batch_util::CopySliceToElement(t, &out_tensors->back(), index));
142       }
143       *end_of_sequence = false;
144       return Status::OK();
145     }
146 
147    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const148     std::shared_ptr<model::Node> CreateNode(
149         IteratorContext* ctx, model::Node::Args args) const override {
150       return model::MakeSourceNode(std::move(args));
151     }
152 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)153     Status SaveInternal(SerializationContext* ctx,
154                         IteratorStateWriter* writer) override {
155       return split_provider_->Save(
156           [this](const std::string& key) { return full_name(key); }, writer);
157     }
158 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)159     Status RestoreInternal(IteratorContext* ctx,
160                            IteratorStateReader* reader) override {
161       return split_provider_->Restore(
162           [this](const std::string& key) { return full_name(key); }, reader);
163     }
164 
165    private:
166     std::shared_ptr<SplitProvider> split_provider_;
167   };
168 
169   const std::vector<Tensor> tensors_;
170   DataTypeVector dtypes_;
171   std::vector<TensorShape> shapes_;
172   std::vector<PartialTensorShape> partial_shapes_;
173 };
174 
TensorSliceDatasetOp(OpKernelConstruction * ctx)175 TensorSliceDatasetOp::TensorSliceDatasetOp(OpKernelConstruction* ctx)
176     : DatasetOpKernel(ctx) {
177   OP_REQUIRES_OK(ctx, ctx->GetAttr(kToutputTypes, &output_types_));
178   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
179 }
180 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)181 void TensorSliceDatasetOp::MakeDataset(OpKernelContext* ctx,
182                                        DatasetBase** output) {
183   OpInputList inputs;
184   OP_REQUIRES_OK(ctx, ctx->input_list(kComponents, &inputs));
185   std::vector<Tensor> components;
186   components.reserve(inputs.size());
187   OP_REQUIRES(
188       ctx, inputs[0].dims() > 0,
189       errors::InvalidArgument("All components must be at least 1-dimensional"));
190   const int64 num_slices = inputs[0].dim_size(0);
191   for (const Tensor& t : inputs) {
192     components.push_back(t);
193     OP_REQUIRES(ctx, t.dims() > 0,
194                 errors::InvalidArgument(
195                     "All components must be at least 1-dimensional"));
196     OP_REQUIRES(
197         ctx, t.dim_size(0) == num_slices,
198         errors::InvalidArgument(
199             "All components must have the same size in the 0th dimension"));
200   }
201   *output = new Dataset(ctx, std::move(components));
202   OP_REQUIRES_OK(ctx,
203                  VerifyTypesMatch((*output)->output_dtypes(), output_types_));
204   OP_REQUIRES_OK(
205       ctx, VerifyShapesCompatible((*output)->output_shapes(), output_shapes_));
206 }
207 
208 namespace {
209 REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU),
210                         TensorSliceDatasetOp);
211 }  // namespace
212 }  // namespace data
213 }  // namespace tensorflow
214