1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/concatenate_dataset_op.h"
16 
17 #include "tensorflow/core/framework/partial_tensor_shape.h"
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/kernels/data/name_utils.h"
20 
21 namespace tensorflow {
22 namespace data {
23 
24 // See documentation in ../../ops/dataset_ops.cc for a high-level
25 // description of the following op.
26 
27 /* static */ constexpr const char* const ConcatenateDatasetOp::kDatasetType;
28 /* static */ constexpr const char* const ConcatenateDatasetOp::kInputDataset;
29 /* static */ constexpr const char* const ConcatenateDatasetOp::kAnotherDataset;
30 /* static */ constexpr const char* const ConcatenateDatasetOp::kOutputTypes;
31 /* static */ constexpr const char* const ConcatenateDatasetOp::kOutputShapes;
32 
33 constexpr char kIndex[] = "i";
34 constexpr char kInputImplUninitialized[] = "input_impl_uninitialized";
35 
36 class ConcatenateDatasetOp::Dataset : public DatasetBase {
37  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,const DatasetBase * to_concatenate)38   explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
39                    const DatasetBase* to_concatenate)
40       : DatasetBase(DatasetContext(ctx)),
41         input_(input),
42         to_concatenate_(to_concatenate) {
43     input_->Ref();
44     to_concatenate_->Ref();
45 
46     auto os_input = input->output_shapes();
47     auto os_concatenate = to_concatenate->output_shapes();
48     for (int i = 0; i < os_input.size(); i++) {
49       output_shapes_.push_back(
50           MostSpecificCompatibleShape(os_input[i], os_concatenate[i]));
51     }
52   }
~Dataset()53   ~Dataset() override {
54     input_->Unref();
55     to_concatenate_->Unref();
56   }
57 
MakeIteratorInternal(const string & prefix) const58   std::unique_ptr<IteratorBase> MakeIteratorInternal(
59       const string& prefix) const override {
60     return absl::make_unique<Iterator>(Iterator::Params{
61         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
62   }
63 
output_dtypes() const64   const DataTypeVector& output_dtypes() const override {
65     return input_->output_dtypes();
66   }
67 
output_shapes() const68   const std::vector<PartialTensorShape>& output_shapes() const override {
69     return output_shapes_;
70   }
71 
DebugString() const72   string DebugString() const override {
73     return name_utils::DatasetDebugString(kDatasetType);
74   }
75 
Cardinality() const76   int64 Cardinality() const override {
77     int64 n1 = input_->Cardinality();
78     int64 n2 = to_concatenate_->Cardinality();
79     if (n1 == kInfiniteCardinality || n2 == kInfiniteCardinality) {
80       return kInfiniteCardinality;
81     }
82     if (n1 == kUnknownCardinality || n2 == kUnknownCardinality) {
83       return kUnknownCardinality;
84     }
85     return n1 + n2;
86   }
87 
InputDatasets(std::vector<const DatasetBase * > * inputs) const88   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
89     inputs->push_back(input_);
90     inputs->push_back(to_concatenate_);
91     return Status::OK();
92   }
93 
CheckExternalState() const94   Status CheckExternalState() const override {
95     TF_RETURN_IF_ERROR(input_->CheckExternalState());
96     return to_concatenate_->CheckExternalState();
97   }
98 
99  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const100   Status AsGraphDefInternal(SerializationContext* ctx,
101                             DatasetGraphDefBuilder* b,
102                             Node** output) const override {
103     Node* input_graph = nullptr;
104     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
105     Node* to_concatenate_graph = nullptr;
106     TF_RETURN_IF_ERROR(
107         b->AddInputDataset(ctx, to_concatenate_, &to_concatenate_graph));
108     TF_RETURN_IF_ERROR(
109         b->AddDataset(this, {input_graph, to_concatenate_graph}, output));
110     return Status::OK();
111   }
112 
113  private:
114   class Iterator : public DatasetIterator<Dataset> {
115    public:
Iterator(const Params & params)116     explicit Iterator(const Params& params)
117         : DatasetIterator<Dataset>(params), i_(0) {}
118 
Initialize(IteratorContext * ctx)119     Status Initialize(IteratorContext* ctx) override {
120       return dataset()->input_->MakeIterator(
121           ctx, this, strings::StrCat(prefix(), "[0]"), &input_impl_);
122     }
123 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)124     Status GetNextInternal(IteratorContext* ctx,
125                            std::vector<Tensor>* out_tensors,
126                            bool* end_of_sequence) override {
127       mutex_lock l(mu_);
128       if (!input_impl_) {
129         *end_of_sequence = true;
130         return Status::OK();
131       }
132       while (i_ < 2) {
133         TF_RETURN_IF_ERROR(
134             input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
135         if (!*end_of_sequence) {
136           return Status::OK();
137         }
138         if (++i_ < 2) {
139           TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
140               ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_));
141         }
142       }
143       *end_of_sequence = true;
144       input_impl_.reset();
145       return Status::OK();
146     }
147 
148    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const149     std::shared_ptr<model::Node> CreateNode(
150         IteratorContext* ctx, model::Node::Args args) const override {
151       return model::MakeKnownRatioNode(std::move(args),
152                                        /*ratio=*/1);
153     }
154 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)155     Status SaveInternal(SerializationContext* ctx,
156                         IteratorStateWriter* writer) override {
157       mutex_lock l(mu_);
158       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), i_));
159       if (input_impl_) {
160         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
161       } else {
162         TF_RETURN_IF_ERROR(
163             writer->WriteScalar(full_name(kInputImplUninitialized), ""));
164       }
165       return Status::OK();
166     }
167 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)168     Status RestoreInternal(IteratorContext* ctx,
169                            IteratorStateReader* reader) override {
170       mutex_lock l(mu_);
171       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &i_));
172       if (reader->Contains(full_name(kInputImplUninitialized))) {
173         input_impl_.reset();
174         return Status::OK();
175       }
176       if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
177         return errors::InvalidArgument("i_ must be in range [0, 2].");
178       if (i_ == 1) {
179         TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
180             ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_));
181       } else if (i_ == 2) {
182         input_impl_.reset();
183       }
184       if (input_impl_) {
185         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
186       }
187       return Status::OK();
188     }
189 
190    private:
191     mutex mu_;
192     int64 i_ TF_GUARDED_BY(mu_);
193     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
194   };
195 
MostSpecificCompatibleShape(const PartialTensorShape & ts1,const PartialTensorShape & ts2)196   static PartialTensorShape MostSpecificCompatibleShape(
197       const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
198     if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
199       return PartialTensorShape();
200     PartialTensorShape output_tensorshape({});
201     auto dims1 = ts1.dim_sizes();
202     auto dims2 = ts2.dim_sizes();
203     for (int d = 0; d < ts1.dims(); d++) {
204       if (dims1[d] == dims2[d])
205         output_tensorshape.AddDim(dims1[d]);
206       else
207         output_tensorshape.AddDim(-1);
208     }
209     return output_tensorshape;
210   }
211 
212   const DatasetBase* input_;
213   const DatasetBase* to_concatenate_;
214   std::vector<PartialTensorShape> output_shapes_;
215 };
216 
ConcatenateDatasetOp(OpKernelConstruction * ctx)217 ConcatenateDatasetOp::ConcatenateDatasetOp(OpKernelConstruction* ctx)
218     : BinaryDatasetOpKernel(ctx) {}
219 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase * to_concatenate,DatasetBase ** output)220 void ConcatenateDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
221                                        DatasetBase* to_concatenate,
222                                        DatasetBase** output) {
223   OP_REQUIRES(ctx, input->output_dtypes() == to_concatenate->output_dtypes(),
224               errors::InvalidArgument(
225                   "input dataset and dataset to concatenate"
226                   " have different output_types %s and %s",
227                   (DataTypeVectorString(input->output_dtypes()),
228                    DataTypeVectorString(to_concatenate->output_dtypes()))));
229   *output = new Dataset(ctx, input, to_concatenate);
230 }
231 
232 namespace {
233 REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU),
234                         ConcatenateDatasetOp);
235 }  // namespace
236 }  // namespace data
237 }  // namespace tensorflow
238