1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/data/dataset_test_base.h"
17
18 namespace tensorflow {
19 namespace data {
20 namespace {
21
// Name given to the node that is built for the op under test.
constexpr char kNodeName[] = "zip_dataset";
// Op type of the kernel under test; the dataset's type string is expected to
// match it.
constexpr char kOpName[] = "ZipDataset";
24
25 struct RangeDatasetParam {
26 int64 start;
27 int64 end;
28 int64 step;
29 };
30
31 class ZipDatasetOpTest : public DatasetOpsTestBase {
32 protected:
33 // Creates `RangeDataset` variant tensors from the input vector of
34 // `RangeDatasetParam`.
CreateRangeDatasetTensors(const std::vector<RangeDatasetParam> & params,std::vector<Tensor> * const dataset_tensors)35 Status CreateRangeDatasetTensors(const std::vector<RangeDatasetParam> ¶ms,
36 std::vector<Tensor> *const dataset_tensors) {
37 for (int i = 0; i < params.size(); ++i) {
38 DatasetBase *range_dataset;
39 TF_RETURN_IF_ERROR(CreateRangeDataset<int64>(
40 params[i].start, params[i].end, params[i].step,
41 strings::StrCat("range_", i), &range_dataset));
42 Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
43 TF_RETURN_IF_ERROR(
44 StoreDatasetInVariantTensor(range_dataset, &dataset_tensor));
45 dataset_tensors->emplace_back(std::move(dataset_tensor));
46 }
47 return Status::OK();
48 }
49
50 // Creates a new ZipDataset op kernel.
CreateZipDatasetKernel(const DataTypeVector & dtypes,const std::vector<PartialTensorShape> & output_shapes,int n,std::unique_ptr<OpKernel> * op_kernel)51 Status CreateZipDatasetKernel(
52 const DataTypeVector &dtypes,
53 const std::vector<PartialTensorShape> &output_shapes, int n,
54 std::unique_ptr<OpKernel> *op_kernel) {
55 std::vector<string> input_datasets;
56 input_datasets.reserve(n);
57 for (int i = 0; i < n; ++i) {
58 // Create the placeholder names for the input components of `ZipDataset`.
59 input_datasets.emplace_back(strings::StrCat("input_dataset_", i));
60 }
61 node_def_ = test::function::NDef(
62 kNodeName, kOpName, input_datasets,
63 {{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}});
64 TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel));
65 return Status::OK();
66 }
67
68 // Creates a new ZipDataset op kernel context.
CreateZipDatasetContext(OpKernel * const op_kernel,gtl::InlinedVector<TensorValue,4> * const inputs,std::unique_ptr<OpKernelContext> * context)69 Status CreateZipDatasetContext(
70 OpKernel *const op_kernel,
71 gtl::InlinedVector<TensorValue, 4> *const inputs,
72 std::unique_ptr<OpKernelContext> *context) {
73 TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
74 TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
75 return Status::OK();
76 }
77
78 private:
79 NodeDef node_def_;
80 };
81
82 struct TestParam {
83 std::vector<RangeDatasetParam> input_range_dataset_params;
84 std::vector<Tensor> expected_outputs;
85 std::vector<int> breakpoints;
86 };
87
TestCase1()88 TestParam TestCase1() {
89 // Test case 1: the input datasets with same number of outputs.
90 return {/*input_range_dataset_params*/
91 {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}},
92 /*expected_outputs*/
93 {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {0}),
94 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {10}),
95 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1}),
96 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {11}),
97 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
98 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {12})},
99 /*breakpoints*/ {0, 1, 4}};
100 }
101
TestCase2()102 TestParam TestCase2() {
103 // Test case 2: the input datasets with different number of outputs.
104 return {/*input_range_dataset_params*/
105 {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}},
106 /*expected_outputs*/
107 {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {0}),
108 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {10}),
109 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {1}),
110 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {11}),
111 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {2}),
112 DatasetOpsTestBase::CreateTensor<int64>(TensorShape{}, {12})},
113 /*breakpoints*/ {0, 1, 4}};
114 }
115
116 class ZipDatasetOpTestHelper : public ZipDatasetOpTest {
117 public:
~ZipDatasetOpTestHelper()118 ~ZipDatasetOpTestHelper() override {
119 if (dataset_) dataset_->Unref();
120 }
121
122 protected:
CreateDatasetFromTestCase(const TestParam & test_case)123 Status CreateDatasetFromTestCase(const TestParam &test_case) {
124 std::vector<Tensor> range_dataset_tensors;
125 range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
126 TF_RETURN_IF_ERROR(CreateRangeDatasetTensors(
127 test_case.input_range_dataset_params, &range_dataset_tensors));
128 gtl::InlinedVector<TensorValue, 4> inputs;
129 inputs.reserve(range_dataset_tensors.size());
130 for (auto &tensor : range_dataset_tensors) {
131 inputs.emplace_back(&tensor);
132 }
133 int num_tensors_per_slice = test_case.input_range_dataset_params.size();
134 TF_RETURN_IF_ERROR(CreateZipDatasetKernel({DT_INT64},
135 {{num_tensors_per_slice}},
136 inputs.size(), &dataset_kernel_));
137 TF_RETURN_IF_ERROR(CreateZipDatasetContext(dataset_kernel_.get(), &inputs,
138 &dataset_kernel_ctx_));
139 TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(),
140 dataset_kernel_ctx_.get(), &dataset_));
141 return Status::OK();
142 }
143
CreateIteratorFromTestCase(const TestParam & test_case)144 Status CreateIteratorFromTestCase(const TestParam &test_case) {
145 TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case));
146 TF_RETURN_IF_ERROR(
147 CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_));
148 TF_RETURN_IF_ERROR(
149 dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_));
150 return Status::OK();
151 }
152
153 std::unique_ptr<OpKernel> dataset_kernel_;
154 std::unique_ptr<OpKernelContext> dataset_kernel_ctx_;
155 DatasetBase *dataset_ = nullptr; // owned by this class.
156 std::unique_ptr<IteratorContext> iterator_ctx_;
157 std::unique_ptr<IteratorBase> iterator_;
158 };
159
160 class ParameterizedDatasetTest
161 : public ZipDatasetOpTestHelper,
162 public ::testing::WithParamInterface<TestParam> {};
163
// Drains the iterator and checks that it yields exactly the expected tensors,
// in order.
TEST_P(ParameterizedDatasetTest, GetNext) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  auto expected_outputs_it = test_case.expected_outputs.begin();
  const auto expected_outputs_end = test_case.expected_outputs.end();
  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  while (!end_of_sequence) {
    // Fatal on purpose: if `GetNext` fails, `end_of_sequence` may never become
    // true and a non-fatal check would let this loop spin forever.
    TF_ASSERT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
                                    &end_of_sequence));
    if (end_of_sequence) break;
    for (const auto &tensor : out_tensors) {
      // Fatal on purpose: the iterator is dereferenced on the next line, which
      // must not happen once it has reached the end.
      ASSERT_NE(expected_outputs_it, expected_outputs_end);
      TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it));
      ++expected_outputs_it;
    }
  }
  // Every expected tensor must have been produced.
  EXPECT_EQ(expected_outputs_it, expected_outputs_end);
}
187
// The dataset built for test case 1 must report `kOpName` as its type string.
TEST_F(ZipDatasetOpTestHelper, DatasetName) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  const TestParam test_case = TestCase1();
  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));

  EXPECT_EQ(dataset_->type_string(), kOpName);
}
196
// The dataset's output dtypes must match the dtypes of the first expected
// element (one tensor per zipped input).
TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));

  // Number of components per element equals the number of zipped inputs.
  const int num_components = test_case.input_range_dataset_params.size();
  DataTypeVector expected_output_dtypes;
  expected_output_dtypes.reserve(num_components);
  for (int component = 0; component < num_components; ++component) {
    const Tensor &expected = test_case.expected_outputs[component];
    expected_output_dtypes.emplace_back(expected.dtype());
  }

  TF_EXPECT_OK(
      VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
}
214
// The dataset's output shapes must be compatible with the shapes of the first
// expected element (one tensor per zipped input).
TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));

  // Number of components per element equals the number of zipped inputs.
  const int num_components = test_case.input_range_dataset_params.size();
  std::vector<PartialTensorShape> expected_output_shapes;
  expected_output_shapes.reserve(num_components);
  for (int component = 0; component < num_components; ++component) {
    const Tensor &expected = test_case.expected_outputs[component];
    expected_output_shapes.emplace_back(expected.shape());
  }

  TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
                                      expected_output_shapes));
}
232
// The dataset's cardinality must equal the number of expected elements, i.e.
// the number of expected tensors divided by the number of zipped inputs.
TEST_P(ParameterizedDatasetTest, Cardinality) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));

  const int num_tensors_per_slice =
      test_case.input_range_dataset_params.size();
  const auto expected_cardinality =
      test_case.expected_outputs.size() / num_tensors_per_slice;
  EXPECT_EQ(dataset_->Cardinality(), expected_cardinality);
}
244
// Saving the dataset built for test case 1 through a variant-tensor writer,
// and flushing that writer, must both succeed.
TEST_F(ZipDatasetOpTestHelper, DatasetSave) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  const TestParam test_case = TestCase1();
  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));

  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));

  VariantTensorData saved_state;
  VariantTensorDataWriter state_writer(&saved_state);
  TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &state_writer));
  TF_ASSERT_OK(state_writer.Flush());
}
258
// The iterator's output dtypes must match the dtypes of the first expected
// element (one tensor per zipped input).
TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  // Number of components per element equals the number of zipped inputs.
  const int num_components = test_case.input_range_dataset_params.size();
  DataTypeVector expected_output_dtypes;
  expected_output_dtypes.reserve(num_components);
  for (int component = 0; component < num_components; ++component) {
    const Tensor &expected = test_case.expected_outputs[component];
    expected_output_dtypes.emplace_back(expected.dtype());
  }

  TF_EXPECT_OK(
      VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
}
276
// The iterator's output shapes must be compatible with the shapes of the first
// expected element (one tensor per zipped input).
TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  // Number of components per element equals the number of zipped inputs.
  const int num_components = test_case.input_range_dataset_params.size();
  std::vector<PartialTensorShape> expected_output_shapes;
  expected_output_shapes.reserve(num_components);
  for (int component = 0; component < num_components; ++component) {
    const Tensor &expected = test_case.expected_outputs[component];
    expected_output_shapes.emplace_back(expected.shape());
  }

  TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
                                      expected_output_shapes));
}
294
// An iterator created with the prefix "Iterator" must report "Iterator::Zip".
TEST_F(ZipDatasetOpTestHelper, IteratorOutputPrefix) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  const TestParam test_case = TestCase1();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  EXPECT_EQ(iterator_->prefix(), "Iterator::Zip");
}
302
// Saves and restores the iterator before advancing to each breakpoint and
// checks that the produced tensors still match the expected outputs.
TEST_P(ParameterizedDatasetTest, Roundtrip) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  const TestParam &test_case = GetParam();
  auto expected_outputs_it = test_case.expected_outputs.begin();
  const auto expected_outputs_end = test_case.expected_outputs.end();
  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));

  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));

  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  // Number of `GetNext` calls issued so far, across all breakpoints.
  int cur_iteration = 0;
  for (const int breakpoint : test_case.breakpoints) {
    // Round-trip the iterator state through a fresh variant tensor.
    VariantTensorData data;
    VariantTensorDataWriter writer(&data);
    TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer));
    TF_EXPECT_OK(writer.Flush());
    VariantTensorDataReader reader(&data);
    TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader));

    // Advance the restored iterator up to the current breakpoint.
    while (cur_iteration < breakpoint) {
      // Fatal on purpose: after a failed `GetNext`, `out_tensors` and
      // `end_of_sequence` would be stale and the checks below meaningless.
      TF_ASSERT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
                                      &end_of_sequence));
      if (!end_of_sequence) {
        for (const auto &tensor : out_tensors) {
          // Fatal on purpose: the iterator is dereferenced on the next line,
          // which must not happen once it has reached the end.
          ASSERT_NE(expected_outputs_it, expected_outputs_end);
          TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it));
          ++expected_outputs_it;
        }
      }
      ++cur_iteration;
    }

    // Only a breakpoint at or past the cardinality must have exhausted the
    // iterator.
    if (breakpoint >= dataset_->Cardinality()) {
      EXPECT_TRUE(end_of_sequence);
      EXPECT_EQ(expected_outputs_it, expected_outputs_end);
    } else {
      EXPECT_FALSE(end_of_sequence);
    }
  }
}
346
347 INSTANTIATE_TEST_SUITE_P(
348 ZipDatasetOpTest, ParameterizedDatasetTest,
349 ::testing::ValuesIn(std::vector<TestParam>({TestCase1(), TestCase2()})));
350
351 } // namespace
352 } // namespace data
353 } // namespace tensorflow
354