/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_

#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace data {

// Helpful functions to test Dataset op kernels.
40 class DatasetOpsTestBase : public ::testing::Test { 41 public: DatasetOpsTestBase()42 DatasetOpsTestBase() 43 : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")), 44 device_type_(DEVICE_CPU) { 45 allocator_ = device_->GetAllocator(AllocatorAttributes()); 46 } 47 ~DatasetOpsTestBase()48 ~DatasetOpsTestBase() {} 49 50 // The method validates whether the two tensors have the same shape, dtype, 51 // and value. 52 static Status ExpectEqual(const Tensor& a, const Tensor& b); 53 54 // Creates a tensor with the specified dtype, shape, and value. 55 template <typename T> CreateTensor(TensorShape input_shape,const gtl::ArraySlice<T> & input_data)56 static Tensor CreateTensor(TensorShape input_shape, 57 const gtl::ArraySlice<T>& input_data) { 58 Tensor tensor(DataTypeToEnum<T>::value, input_shape); 59 test::FillValues<T>(&tensor, input_data); 60 return tensor; 61 } 62 63 // Creates a new op kernel based on the node definition. 64 Status CreateOpKernel(const NodeDef& node_def, 65 std::unique_ptr<OpKernel>* op_kernel); 66 67 // Creates a new dataset. 68 Status CreateDataset(OpKernel* kernel, OpKernelContext* context, 69 DatasetBase** const dataset); 70 71 // Creates a new RangeDataset op kernel. `T` specifies the output dtype of the 72 // op kernel. 73 template <typename T> CreateRangeDatasetOpKernel(StringPiece node_name,std::unique_ptr<OpKernel> * range_op_kernel)74 Status CreateRangeDatasetOpKernel( 75 StringPiece node_name, std::unique_ptr<OpKernel>* range_op_kernel) { 76 DataTypeVector dtypes({tensorflow::DataTypeToEnum<T>::value}); 77 std::vector<PartialTensorShape> shapes({{}}); 78 NodeDef node_def = test::function::NDef( 79 node_name, "RangeDataset", {"start", "stop", "step"}, 80 {{"output_types", dtypes}, {"output_shapes", shapes}}); 81 82 TF_RETURN_IF_ERROR(CreateOpKernel(node_def, range_op_kernel)); 83 return Status::OK(); 84 } 85 86 // Creates a new RangeDataset dataset. `T` specifies the output dtype of the 87 // RangeDataset op kernel. 
88 template <typename T> CreateRangeDataset(int64 start,int64 end,int64 step,StringPiece node_name,DatasetBase ** range_dataset)89 Status CreateRangeDataset(int64 start, int64 end, int64 step, 90 StringPiece node_name, 91 DatasetBase** range_dataset) { 92 std::unique_ptr<OpKernel> range_kernel; 93 TF_RETURN_IF_ERROR(CreateRangeDatasetOpKernel<T>(node_name, &range_kernel)); 94 gtl::InlinedVector<TensorValue, 4> range_inputs; 95 TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>( 96 &range_inputs, range_kernel->input_types(), TensorShape({}), {start})); 97 TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>( 98 &range_inputs, range_kernel->input_types(), TensorShape({}), {end})); 99 TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>( 100 &range_inputs, range_kernel->input_types(), TensorShape({}), {step})); 101 std::unique_ptr<OpKernelContext> range_context; 102 TF_RETURN_IF_ERROR(CreateOpKernelContext(range_kernel.get(), &range_inputs, 103 &range_context)); 104 TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, range_inputs)); 105 TF_RETURN_IF_ERROR(RunOpKernel(range_kernel.get(), range_context.get())); 106 TF_RETURN_IF_ERROR( 107 GetDatasetFromContext(range_context.get(), 0, range_dataset)); 108 return Status::OK(); 109 } 110 111 // Creates a new TensorSliceDataset op kernel. 112 Status CreateTensorSliceDatasetKernel( 113 StringPiece node_name, const DataTypeVector& dtypes, 114 const std::vector<PartialTensorShape>& shapes, 115 std::unique_ptr<OpKernel>* tensor_slice_dataset_kernel); 116 117 // Creates a new TensorSliceDataset. 118 Status CreateTensorSliceDataset(StringPiece node_name, 119 std::vector<Tensor>* const components, 120 DatasetBase** tensor_slice_dataset); 121 122 // Fetches the dataset from the operation context. 123 Status GetDatasetFromContext(OpKernelContext* context, int output_index, 124 DatasetBase** const dataset); 125 126 protected: 127 // Creates a thread pool for parallel tasks. 
128 Status InitThreadPool(int thread_num); 129 130 // Initializes the runtime for computing the dataset operation and registers 131 // the input function definitions. `InitThreadPool()' needs to be called 132 // before this method if we want to run the tasks in parallel. 133 Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib, 134 int cpu_num); 135 136 // Runs an operation producing outputs. 137 Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context); 138 139 // Checks that the size of `inputs` matches the requirement of the op kernel. 140 Status CheckOpKernelInput(const OpKernel& kernel, 141 const gtl::InlinedVector<TensorValue, 4>& inputs); 142 143 // Creates a new context for running the dataset operation. 144 Status CreateOpKernelContext(OpKernel* kernel, 145 gtl::InlinedVector<TensorValue, 4>* inputs, 146 std::unique_ptr<OpKernelContext>* context); 147 148 // Creates a new iterator context for iterating the dataset. 149 Status CreateIteratorContext( 150 OpKernelContext* const op_context, 151 std::unique_ptr<IteratorContext>* iterator_context); 152 153 // Creates a new serialization context for serializing the dataset and 154 // iterator. 155 Status CreateSerializationContext( 156 std::unique_ptr<SerializationContext>* context); 157 158 // Adds an arrayslice of data into the input vector. `input_types` describes 159 // the required data type for each input tensor. `shape` and `data` describes 160 // the shape and values of the current input tensor. `T` specifies the dtype 161 // of the input data. 
162 template <typename T> AddDatasetInputFromArray(gtl::InlinedVector<TensorValue,4> * inputs,DataTypeVector input_types,const TensorShape & shape,const gtl::ArraySlice<T> & data)163 Status AddDatasetInputFromArray(gtl::InlinedVector<TensorValue, 4>* inputs, 164 DataTypeVector input_types, 165 const TensorShape& shape, 166 const gtl::ArraySlice<T>& data) { 167 TF_RETURN_IF_ERROR( 168 AddDatasetInput(inputs, input_types, DataTypeToEnum<T>::v(), shape)); 169 test::FillValues<T>(inputs->back().tensor, data); 170 return Status::OK(); 171 } 172 173 private: 174 // Adds an empty tensor with the specified dtype and shape to the input 175 // vector. 176 Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs, 177 DataTypeVector input_types, DataType dtype, 178 const TensorShape& shape); 179 180 protected: 181 std::unique_ptr<Device> device_; 182 DeviceType device_type_; 183 Allocator* allocator_; // Owned by `AllocatorFactoryRegistry`. 184 std::vector<AllocatorAttributes> allocator_attrs_; 185 std::unique_ptr<ScopedStepContainer> step_container_; 186 187 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; 188 FunctionLibraryRuntime* flr_; // Owned by `pflr_`. 189 std::unique_ptr<FunctionHandleCache> function_handle_cache_; 190 std::function<void(std::function<void()>)> runner_; 191 std::unique_ptr<DeviceMgr> device_mgr_; 192 std::unique_ptr<FunctionLibraryDefinition> lib_def_; 193 std::unique_ptr<OpKernelContext::Params> params_; 194 std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper> 195 slice_reader_cache_; 196 std::unique_ptr<thread::ThreadPool> thread_pool_; 197 std::vector<std::unique_ptr<Tensor>> tensors_; // Owns tensors. 198 mutex lock_for_refs_; // Used as the Mutex for inputs added as refs. 199 }; 200 201 } // namespace data 202 } // namespace tensorflow 203 204 #endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_ 205