/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/dataset_test_base.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kNodeName[] = "zip_dataset";
constexpr char kOpName[] = "ZipDataset";

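// Parameters of a RangeDataset: the values in [start, end) taken with the
// given step.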
struct RangeDatasetParam {
  int64 start;
  int64 end;
  int64 step;
};

class ZipDatasetOpTest : public DatasetOpsTestBase {
 protected:
  // Creates `RangeDataset` variant tensors from the input vector of
  // `RangeDatasetParam`.
  Status CreateRangeDatasetTensors(const std::vector<RangeDatasetParam> &params,
                                   std::vector<Tensor> *const dataset_tensors) {
    for (int i = 0; i < params.size(); ++i) {
      DatasetBase *range_dataset;
      TF_RETURN_IF_ERROR(CreateRangeDataset<int64>(
          params[i].start, params[i].end, params[i].step,
          strings::StrCat("range_", i), &range_dataset));
      Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
      TF_RETURN_IF_ERROR(
          StoreDatasetInVariantTensor(range_dataset, &dataset_tensor));
      dataset_tensors->emplace_back(std::move(dataset_tensor));
    }
    return Status::OK();
  }

  // Creates a new ZipDataset op kernel.
  Status CreateZipDatasetKernel(
      const DataTypeVector &dtypes,
      const std::vector<PartialTensorShape> &output_shapes, int n,
      std::unique_ptr<OpKernel> *op_kernel) {
    std::vector<string> input_datasets;
    input_datasets.reserve(n);
    for (int i = 0; i < n; ++i) {
      // Create the placeholder names for the input components of `ZipDataset`.
      input_datasets.emplace_back(strings::StrCat("input_dataset_", i));
    }
    node_def_ = test::function::NDef(
        kNodeName, kOpName, input_datasets,
        {{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}});
    TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel));
    return Status::OK();
  }

  // Creates a new ZipDataset op kernel context.
  Status CreateZipDatasetContext(
      OpKernel *const op_kernel,
      gtl::InlinedVector<TensorValue, 4> *const inputs,
      std::unique_ptr<OpKernelContext> *context) {
    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
    TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
    return Status::OK();
  }

 private:
  NodeDef node_def_;
};

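// Holds the parameters and expectations for a single zip dataset test case:
// the input `RangeDataset` parameters, the expected zipped outputs, and the
// breakpoints used by the save/restore (Roundtrip) test.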
struct TestParam {
  std::vector<RangeDatasetParam> input_range_dataset_params;
  std::vector<Tensor> expected_outputs;
  std::vector<int> breakpoints;
};

TestParam TestCase1() {
  // Test case 1: the input datasets have the same number of outputs.
  return {/*input_range_dataset_params*/
          {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}},
          /*expected_outputs*/
          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {0}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {10}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {11}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {12})},
          /*breakpoints*/ {0, 1, 4}};
}

TestParam TestCase2() {
  // Test case 2: the input datasets have different numbers of outputs.
  return {/*input_range_dataset_params*/
          {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}},
          /*expected_outputs*/
          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {0}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {10}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {11}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {12})},
          /*breakpoints*/ {0, 1, 4}};
}

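// Test helper that builds the ZipDataset kernel, dataset, and iterator from a
// `TestParam` and unrefs the dataset on destruction.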
class ZipDatasetOpTestHelper : public ZipDatasetOpTest {
 public:
  ~ZipDatasetOpTestHelper() override {
    if (dataset_) dataset_->Unref();
  }

 protected:
  Status CreateDatasetFromTestCase(const TestParam &test_case) {
    std::vector<Tensor> range_dataset_tensors;
    range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
    TF_RETURN_IF_ERROR(CreateRangeDatasetTensors(
        test_case.input_range_dataset_params, &range_dataset_tensors));
    gtl::InlinedVector<TensorValue, 4> inputs;
    inputs.reserve(range_dataset_tensors.size());
    for (auto &tensor : range_dataset_tensors) {
      inputs.emplace_back(&tensor);
    }
    int num_tensors_per_slice = test_case.input_range_dataset_params.size();
    TF_RETURN_IF_ERROR(CreateZipDatasetKernel({DT_INT64},
                                              {{num_tensors_per_slice}},
                                              inputs.size(), &dataset_kernel_));
    TF_RETURN_IF_ERROR(CreateZipDatasetContext(dataset_kernel_.get(), &inputs,
                                               &dataset_kernel_ctx_));
    TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(),
                                     dataset_kernel_ctx_.get(), &dataset_));
    return Status::OK();
  }

  Status CreateIteratorFromTestCase(const TestParam &test_case) {
    TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case));
    TF_RETURN_IF_ERROR(
        CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_));
    TF_RETURN_IF_ERROR(
        dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_));
    return Status::OK();
  }

  std::unique_ptr<OpKernel> dataset_kernel_;
  std::unique_ptr<OpKernelContext> dataset_kernel_ctx_;
  DatasetBase *dataset_ = nullptr;  // owned by this class.
  std::unique_ptr<IteratorContext> iterator_ctx_;
  std::unique_ptr<IteratorBase> iterator_;
};

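// Runs each parameterized test below once per `TestParam` in the
// INSTANTIATE_TEST_SUITE_P list at the bottom of this file.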
class ParameterizedDatasetTest
    : public ZipDatasetOpTestHelper,
      public ::testing::WithParamInterface<TestParam> {};

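// Iterates through the zipped dataset to the end of the sequence and checks
// that the produced tensors match the expected outputs in order.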
TEST_P(ParameterizedDatasetTest, GetNext) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  auto expected_outputs_it = test_case.expected_outputs.begin();
  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  while (!end_of_sequence) {
    TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
                                    &end_of_sequence));
    if (!end_of_sequence) {
      for (const auto &tensor : out_tensors) {
        EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
        TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it));
        expected_outputs_it++;
      }
    }
  }
  EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
}

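// Checks that the dataset's type string matches the ZipDataset op name.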
TEST_F(ZipDatasetOpTestHelper, DatasetName) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));

  EXPECT_EQ(dataset_->type_string(), kOpName);
}

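// Checks that the dataset output dtypes match the dtypes of the expected
// output tensors.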
TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));

  DataTypeVector expected_output_dtypes;
  expected_output_dtypes.reserve(num_tensors_per_slice);
  for (int i = 0; i < num_tensors_per_slice; ++i) {
    expected_output_dtypes.emplace_back(test_case.expected_outputs[i].dtype());
  }

  TF_EXPECT_OK(
      VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
}

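// Checks that the dataset output shapes are compatible with the shapes of the
// expected output tensors.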
TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));

  std::vector<PartialTensorShape> expected_output_shapes;
  expected_output_shapes.reserve(num_tensors_per_slice);
  for (int i = 0; i < num_tensors_per_slice; ++i) {
    expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
  }

  TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
                                      expected_output_shapes));
}

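// Checks that the dataset cardinality equals the number of expected zipped
// elements, i.e. the length of the shortest input dataset.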
TEST_P(ParameterizedDatasetTest, Cardinality) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));

  EXPECT_EQ(dataset_->Cardinality(),
            test_case.expected_outputs.size() / num_tensors_per_slice);
}

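// Checks that the dataset can be serialized to a VariantTensorData writer
// without error.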
TEST_F(ZipDatasetOpTestHelper, DatasetSave) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));

  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
  VariantTensorData data;
  VariantTensorDataWriter writer(&data);
  TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &writer));
  TF_ASSERT_OK(writer.Flush());
}

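// Checks that the iterator output dtypes match the dtypes of the expected
// output tensors.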
TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  DataTypeVector expected_output_dtypes;
  expected_output_dtypes.reserve(num_tensors_per_slice);
  for (int i = 0; i < num_tensors_per_slice; ++i) {
    expected_output_dtypes.emplace_back(test_case.expected_outputs[i].dtype());
  }

  TF_EXPECT_OK(
      VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
}

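// Checks that the iterator output shapes are compatible with the shapes of
// the expected output tensors.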
TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  std::vector<PartialTensorShape> expected_output_shapes;
  expected_output_shapes.reserve(num_tensors_per_slice);
  for (int i = 0; i < num_tensors_per_slice; ++i) {
    expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
  }

  TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
                                      expected_output_shapes));
}

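// Checks that the iterator prefix follows the "Iterator::Zip" naming
// convention.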
TEST_F(ZipDatasetOpTestHelper, IteratorOutputPrefix) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  TF_ASSERT_OK(CreateIteratorFromTestCase(TestCase1()));
  EXPECT_EQ(iterator_->prefix(), "Iterator::Zip");
}

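// Checks that the iterator state can be saved and restored at each breakpoint
// and that iteration then continues to produce the remaining expected outputs.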
TEST_P(ParameterizedDatasetTest, Roundtrip) {
  int thread_num = 2, cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  auto expected_outputs_it = test_case.expected_outputs.begin();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));

  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  int cur_iteration = 0;
  for (int breakpoint : test_case.breakpoints) {
    VariantTensorData data;
    VariantTensorDataWriter writer(&data);
    TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer));
    TF_EXPECT_OK(writer.Flush());
    VariantTensorDataReader reader(&data);
    TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader));

    while (cur_iteration < breakpoint) {
      TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
                                      &end_of_sequence));
      if (!end_of_sequence) {
        for (auto &tensor : out_tensors) {
          EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
          TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it));
          expected_outputs_it++;
        }
      }
      cur_iteration++;
    }

    if (breakpoint >= dataset_->Cardinality()) {
      EXPECT_TRUE(end_of_sequence);
      EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
    } else {
      EXPECT_FALSE(end_of_sequence);
    }
  }
}

INSTANTIATE_TEST_SUITE_P(
    ZipDatasetOpTest, ParameterizedDatasetTest,
    ::testing::ValuesIn(std::vector<TestParam>({TestCase1(), TestCase2()})));

}  // namespace
}  // namespace data
}  // namespace tensorflow