/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_

#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace data {

// Helpful functions to test Dataset op kernels.
40 class DatasetOpsTestBase : public ::testing::Test {
41  public:
DatasetOpsTestBase()42   DatasetOpsTestBase()
43       : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
44         device_type_(DEVICE_CPU) {
45     allocator_ = device_->GetAllocator(AllocatorAttributes());
46   }
47 
~DatasetOpsTestBase()48   ~DatasetOpsTestBase() {}
49 
50   // The method validates whether the two tensors have the same shape, dtype,
51   // and value.
52   static Status ExpectEqual(const Tensor& a, const Tensor& b);
53 
54   // Creates a tensor with the specified dtype, shape, and value.
55   template <typename T>
CreateTensor(TensorShape input_shape,const gtl::ArraySlice<T> & input_data)56   static Tensor CreateTensor(TensorShape input_shape,
57                              const gtl::ArraySlice<T>& input_data) {
58     Tensor tensor(DataTypeToEnum<T>::value, input_shape);
59     test::FillValues<T>(&tensor, input_data);
60     return tensor;
61   }
62 
63   // Creates a new op kernel based on the node definition.
64   Status CreateOpKernel(const NodeDef& node_def,
65                         std::unique_ptr<OpKernel>* op_kernel);
66 
67   // Creates a new dataset.
68   Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
69                        DatasetBase** const dataset);
70 
71   // Creates a new RangeDataset op kernel. `T` specifies the output dtype of the
72   // op kernel.
73   template <typename T>
CreateRangeDatasetOpKernel(StringPiece node_name,std::unique_ptr<OpKernel> * range_op_kernel)74   Status CreateRangeDatasetOpKernel(
75       StringPiece node_name, std::unique_ptr<OpKernel>* range_op_kernel) {
76     DataTypeVector dtypes({tensorflow::DataTypeToEnum<T>::value});
77     std::vector<PartialTensorShape> shapes({{}});
78     NodeDef node_def = test::function::NDef(
79         node_name, "RangeDataset", {"start", "stop", "step"},
80         {{"output_types", dtypes}, {"output_shapes", shapes}});
81 
82     TF_RETURN_IF_ERROR(CreateOpKernel(node_def, range_op_kernel));
83     return Status::OK();
84   }
85 
86   // Creates a new RangeDataset dataset. `T` specifies the output dtype of the
87   // RangeDataset op kernel.
88   template <typename T>
CreateRangeDataset(int64 start,int64 end,int64 step,StringPiece node_name,DatasetBase ** range_dataset)89   Status CreateRangeDataset(int64 start, int64 end, int64 step,
90                             StringPiece node_name,
91                             DatasetBase** range_dataset) {
92     std::unique_ptr<OpKernel> range_kernel;
93     TF_RETURN_IF_ERROR(CreateRangeDatasetOpKernel<T>(node_name, &range_kernel));
94     gtl::InlinedVector<TensorValue, 4> range_inputs;
95     TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
96         &range_inputs, range_kernel->input_types(), TensorShape({}), {start}));
97     TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
98         &range_inputs, range_kernel->input_types(), TensorShape({}), {end}));
99     TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
100         &range_inputs, range_kernel->input_types(), TensorShape({}), {step}));
101     std::unique_ptr<OpKernelContext> range_context;
102     TF_RETURN_IF_ERROR(CreateOpKernelContext(range_kernel.get(), &range_inputs,
103                                              &range_context));
104     TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, range_inputs));
105     TF_RETURN_IF_ERROR(RunOpKernel(range_kernel.get(), range_context.get()));
106     TF_RETURN_IF_ERROR(
107         GetDatasetFromContext(range_context.get(), 0, range_dataset));
108     return Status::OK();
109   }
110 
111   // Creates a new TensorSliceDataset op kernel.
112   Status CreateTensorSliceDatasetKernel(
113       StringPiece node_name, const DataTypeVector& dtypes,
114       const std::vector<PartialTensorShape>& shapes,
115       std::unique_ptr<OpKernel>* tensor_slice_dataset_kernel);
116 
117   // Creates a new TensorSliceDataset.
118   Status CreateTensorSliceDataset(StringPiece node_name,
119                                   std::vector<Tensor>* const components,
120                                   DatasetBase** tensor_slice_dataset);
121 
122   // Fetches the dataset from the operation context.
123   Status GetDatasetFromContext(OpKernelContext* context, int output_index,
124                                DatasetBase** const dataset);
125 
126  protected:
127   // Creates a thread pool for parallel tasks.
128   Status InitThreadPool(int thread_num);
129 
130   // Initializes the runtime for computing the dataset operation and registers
131   // the input function definitions. `InitThreadPool()' needs to be called
132   // before this method if we want to run the tasks in parallel.
133   Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib,
134                                     int cpu_num);
135 
136   // Runs an operation producing outputs.
137   Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context);
138 
139   // Checks that the size of `inputs` matches the requirement of the op kernel.
140   Status CheckOpKernelInput(const OpKernel& kernel,
141                             const gtl::InlinedVector<TensorValue, 4>& inputs);
142 
143   // Creates a new context for running the dataset operation.
144   Status CreateOpKernelContext(OpKernel* kernel,
145                                gtl::InlinedVector<TensorValue, 4>* inputs,
146                                std::unique_ptr<OpKernelContext>* context);
147 
148   // Creates a new iterator context for iterating the dataset.
149   Status CreateIteratorContext(
150       OpKernelContext* const op_context,
151       std::unique_ptr<IteratorContext>* iterator_context);
152 
153   // Creates a new serialization context for serializing the dataset and
154   // iterator.
155   Status CreateSerializationContext(
156       std::unique_ptr<SerializationContext>* context);
157 
158   // Adds an arrayslice of data into the input vector. `input_types` describes
159   // the required data type for each input tensor. `shape` and `data` describes
160   // the shape and values of the current input tensor. `T` specifies the dtype
161   // of the input data.
162   template <typename T>
AddDatasetInputFromArray(gtl::InlinedVector<TensorValue,4> * inputs,DataTypeVector input_types,const TensorShape & shape,const gtl::ArraySlice<T> & data)163   Status AddDatasetInputFromArray(gtl::InlinedVector<TensorValue, 4>* inputs,
164                                   DataTypeVector input_types,
165                                   const TensorShape& shape,
166                                   const gtl::ArraySlice<T>& data) {
167     TF_RETURN_IF_ERROR(
168         AddDatasetInput(inputs, input_types, DataTypeToEnum<T>::v(), shape));
169     test::FillValues<T>(inputs->back().tensor, data);
170     return Status::OK();
171   }
172 
173  private:
174   // Adds an empty tensor with the specified dtype and shape to the input
175   // vector.
176   Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
177                          DataTypeVector input_types, DataType dtype,
178                          const TensorShape& shape);
179 
180  protected:
181   std::unique_ptr<Device> device_;
182   DeviceType device_type_;
183   Allocator* allocator_;  // Owned by `AllocatorFactoryRegistry`.
184   std::vector<AllocatorAttributes> allocator_attrs_;
185   std::unique_ptr<ScopedStepContainer> step_container_;
186 
187   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
188   FunctionLibraryRuntime* flr_;  // Owned by `pflr_`.
189   std::unique_ptr<FunctionHandleCache> function_handle_cache_;
190   std::function<void(std::function<void()>)> runner_;
191   std::unique_ptr<DeviceMgr> device_mgr_;
192   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
193   std::unique_ptr<OpKernelContext::Params> params_;
194   std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
195       slice_reader_cache_;
196   std::unique_ptr<thread::ThreadPool> thread_pool_;
197   std::vector<std::unique_ptr<Tensor>> tensors_;  // Owns tensors.
198   mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs.
199 };

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_