/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/dataset_test_base.h"

namespace tensorflow {
namespace data {

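// Checks that two tensors have the same dtype and equal contents. Numeric and
// string dtypes are supported; other dtypes (e.g. variant) return an error.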
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
  EXPECT_EQ(a.dtype(), b.dtype());
  switch (a.dtype()) {
#define CASE(type)                       \
  case DataTypeToEnum<type>::value:      \
    test::ExpectTensorEqual<type>(a, b); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_string(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:
      return errors::Internal("Unsupported dtype: ", a.dtype());
  }
  return Status::OK();
}

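// Builds a `TensorSliceDataset` op kernel whose inputs are placeholders named
// `component_0`, `component_1`, ... matching the given dtypes and shapes.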
Status DatasetOpsTestBase::CreateTensorSliceDatasetKernel(
    StringPiece node_name, const DataTypeVector& dtypes,
    const std::vector<PartialTensorShape>& shapes,
    std::unique_ptr<OpKernel>* tensor_slice_dataset_kernel) {
  std::vector<string> components;
  components.reserve(dtypes.size());
  for (int i = 0; i < dtypes.size(); ++i) {
    // Create the placeholder names for the input components of
    // `TensorSliceDataset`.
    components.emplace_back(strings::StrCat("component_", i));
  }
  NodeDef node_def = test::function::NDef(
      node_name, "TensorSliceDataset", components,
      {{"Toutput_types", dtypes}, {"output_shapes", shapes}});
  TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_slice_dataset_kernel));
  return Status::OK();
}

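// Creates a `TensorSliceDataset` from the given component tensors. The
// inferred output shapes drop the leading dimension because the dataset slices
// each component along its first dimension. The returned dataset carries an
// extra reference taken on behalf of the caller.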
Status DatasetOpsTestBase::CreateTensorSliceDataset(
    StringPiece node_name, std::vector<Tensor>* const components,
    DatasetBase** tensor_slice_dataset) {
  std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
  DataTypeVector dtypes;
  dtypes.reserve(components->size());
  std::vector<PartialTensorShape> shapes;
  shapes.reserve(components->size());
  for (const auto& t : *components) {
    dtypes.push_back(t.dtype());
    gtl::InlinedVector<int64, 4> partial_dim_sizes;
    for (int i = 1; i < t.dims(); ++i) {
      partial_dim_sizes.push_back(t.dim_size(i));
    }
    shapes.emplace_back(std::move(partial_dim_sizes));
  }
  TF_RETURN_IF_ERROR(CreateTensorSliceDatasetKernel(
      node_name, dtypes, shapes, &tensor_slice_dataset_kernel));
  gtl::InlinedVector<TensorValue, 4> inputs;
  for (auto& tensor : *components) {
    inputs.emplace_back(&tensor);
  }
  TF_RETURN_IF_ERROR(CheckOpKernelInput(*tensor_slice_dataset_kernel, inputs));
  std::unique_ptr<OpKernelContext> context;
  TF_RETURN_IF_ERROR(CreateOpKernelContext(tensor_slice_dataset_kernel.get(),
                                           &inputs, &context));
  TF_RETURN_IF_ERROR(
      RunOpKernel(tensor_slice_dataset_kernel.get(), context.get()));
  TF_RETURN_IF_ERROR(
      GetDatasetFromContext(context.get(), 0, tensor_slice_dataset));
  return Status::OK();
}

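// Creates an op kernel for `node_def` on the test device and wraps it in
// `op_kernel`.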
Status DatasetOpsTestBase::CreateOpKernel(
    const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
  OpKernel* kernel;
  TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(device_type_, device_.get(),
                                                allocator_, flr_, node_def,
                                                TF_GRAPH_DEF_VERSION, &kernel));
  op_kernel->reset(kernel);
  return Status::OK();
}

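// Runs the dataset op kernel and fetches its single output as a `DatasetBase`.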
Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
                                         OpKernelContext* context,
                                         DatasetBase** const dataset) {
  TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
  // Assume that DatasetOp has only one output.
  DCHECK_EQ(context->num_outputs(), 1);
  TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
  return Status::OK();
}

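// Creates an `IteratorContext` that shares the resources of `op_context` and
// uses a freshly created function handle cache.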
Status DatasetOpsTestBase::CreateIteratorContext(
    OpKernelContext* const op_context,
    std::unique_ptr<IteratorContext>* iterator_context) {
  IteratorContext::Params params(op_context);
  function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
  params.function_handle_cache = function_handle_cache_.get();
  *iterator_context = absl::make_unique<IteratorContext>(params);
  return Status::OK();
}

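// Extracts the `DatasetBase` stored in the variant tensor at `output_index`
// and takes a reference on it on behalf of the caller.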
Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
                                                 int output_index,
                                                 DatasetBase** const dataset) {
  Tensor* output = context->mutable_output(output_index);
  // Only take a reference if the dataset was successfully retrieved; otherwise
  // `*dataset` may be left unset.
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(*output, dataset));
  (*dataset)->Ref();
  return Status::OK();
}

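// Initializes the inter-op thread pool with `thread_num` threads.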
Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
  if (thread_num < 1) {
    return errors::InvalidArgument(
        "The `thread_num` argument should be positive but got: ", thread_num);
  }
  thread_pool_ = absl::make_unique<thread::ThreadPool>(
      Env::Default(), ThreadOptions(), "inter_op", thread_num);
  return Status::OK();
}

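// Initializes the device manager, function library definition and function
// library runtime used by the test kernels, registering the functions in
// `flib` and creating `cpu_num` CPU devices.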
Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
    const std::vector<FunctionDef>& flib, int cpu_num) {
  if (cpu_num < 1) {
    return errors::InvalidArgument(
        "The `cpu_num` argument should be positive but got: ", cpu_num);
  }
  SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", cpu_num});
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));

  FunctionDefLibrary proto;
  for (const auto& fdef : flib) *(proto.add_function()) = fdef;
  lib_def_ =
      absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);

  OptimizerOptions opts;
  pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
      opts, thread_pool_.get(), nullptr /* cluster_flr */);
  flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  if (thread_pool_ == nullptr) {
    runner_ = [](std::function<void()> fn) { fn(); };
  } else {
    runner_ = [this](std::function<void()> fn) {
      thread_pool_->Schedule(std::move(fn));
    };
  }
  return Status::OK();
}

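// Executes `op_kernel` on the test device and returns the resulting status.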
Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
                                       OpKernelContext* context) {
  device_->Compute(op_kernel, context);
  return context->status();
}

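// Creates an `OpKernelContext` wired up with the test device, function library
// runtime, runner, step container and the given inputs. Each output is marked
// as on-host whenever the kernel declares `HOST_MEMORY` for it.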
Status DatasetOpsTestBase::CreateOpKernelContext(
    OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
    std::unique_ptr<OpKernelContext>* context) {
  params_ = absl::make_unique<OpKernelContext::Params>();
  params_->device = device_.get();
  params_->resource_manager = device_->resource_manager();
  params_->frame_iter = FrameAndIter(0, 0);
  params_->inputs = inputs;
  params_->op_kernel = kernel;
  params_->function_library = flr_;
  params_->runner = &runner_;
  step_container_ =
      absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
  params_->step_container = step_container_.get();
  slice_reader_cache_ =
      absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
  params_->slice_reader_cache = slice_reader_cache_.get();

  // Set the allocator attributes for the outputs.
  allocator_attrs_.clear();
  for (int index = 0; index < params_->op_kernel->num_outputs(); index++) {
    AllocatorAttributes attr;
    const bool on_host =
        (params_->op_kernel->output_memory_types()[index] == HOST_MEMORY);
    attr.set_on_host(on_host);
    allocator_attrs_.emplace_back(attr);
  }
  params_->output_attr_array = gtl::vector_as_array(&allocator_attrs_);

  *context = absl::make_unique<OpKernelContext>(params_.get());
  return Status::OK();
}

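// Creates a `SerializationContext` backed by the test function library
// definition.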
Status DatasetOpsTestBase::CreateSerializationContext(
    std::unique_ptr<SerializationContext>* context) {
  SerializationContext::Params params;
  params.flib_def = lib_def_.get();
  *context = absl::make_unique<SerializationContext>(params);
  return Status::OK();
}

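// Verifies that the number of provided inputs matches the number of declared
// input types of `kernel`.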
Status DatasetOpsTestBase::CheckOpKernelInput(
    const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
  if (kernel.input_types().size() != inputs.size()) {
    return errors::Internal("The number of input elements should be ",
                            kernel.input_types().size(),
                            ", but got: ", inputs.size());
  }
  return Status::OK();
}

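// Allocates a tensor of the given dtype and shape and appends it to `inputs`,
// checking the dtype against the declared input types. The tensor is kept
// alive by the test fixture so that it outlives the op kernel invocation.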
Status DatasetOpsTestBase::AddDatasetInput(
    gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
    DataType dtype, const TensorShape& shape) {
  // Reject the input if there is no declared input type left for it.
  if (input_types.size() <= inputs->size()) {
    return errors::InvalidArgument("Adding more inputs than types: ",
                                   inputs->size(), " vs. ", input_types.size());
  }
  bool is_ref = IsRefType(input_types[inputs->size()]);
  std::unique_ptr<Tensor> input =
      absl::make_unique<Tensor>(allocator_, dtype, shape);

  if (is_ref) {
    DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
    if (expected_dtype != dtype) {
      return errors::InvalidArgument("The input data type is ", dtype,
                                     ", but expected: ", expected_dtype);
    }
    inputs->push_back({&lock_for_refs_, input.get()});
  } else {
    if (input_types[inputs->size()] != dtype) {
      return errors::InvalidArgument(
          "The input data type is ", dtype,
          ", but expected: ", input_types[inputs->size()]);
    }
    inputs->push_back({nullptr, input.get()});
  }

  // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
  // collect the inputs.
  tensors_.push_back(std::move(input));

  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow