1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/data/dataset_test_base.h"
17 
18 namespace tensorflow {
19 namespace data {
20 
// Asserts (via gtest) that tensors `a` and `b` have the same dtype and equal
// contents. Dispatches to test::ExpectTensorEqual<T> for every numeric and
// string dtype; any other dtype yields an Internal error status.
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
  EXPECT_EQ(a.dtype(), b.dtype());
  switch (a.dtype()) {
// X-macro: expands to one `case` per supported dtype, each comparing the
// tensors element-wise with the statically-typed helper.
#define CASE(type)                       \
  case DataTypeToEnum<type>::value:      \
    test::ExpectTensorEqual<type>(a, b); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_string(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:
      return errors::Internal("Unsupported dtype", a.dtype());
  }
  return Status::OK();
}
37 
CreateTensorSliceDatasetKernel(StringPiece node_name,const DataTypeVector & dtypes,const std::vector<PartialTensorShape> & shapes,std::unique_ptr<OpKernel> * tensor_slice_dataset_kernel)38 Status DatasetOpsTestBase::CreateTensorSliceDatasetKernel(
39     StringPiece node_name, const DataTypeVector& dtypes,
40     const std::vector<PartialTensorShape>& shapes,
41     std::unique_ptr<OpKernel>* tensor_slice_dataset_kernel) {
42   std::vector<string> components;
43   components.reserve(dtypes.size());
44   for (int i = 0; i < dtypes.size(); ++i) {
45     // Create the placeholder names for the input components of
46     // `TensorSliceDataset`.
47     components.emplace_back(strings::StrCat("component_", i));
48   }
49   NodeDef node_def = test::function::NDef(
50       node_name, "TensorSliceDataset", components,
51       {{"Toutput_types", dtypes}, {"output_shapes", shapes}});
52   TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_slice_dataset_kernel));
53   return Status::OK();
54 }
55 
CreateTensorSliceDataset(StringPiece node_name,std::vector<Tensor> * const components,DatasetBase ** tensor_slice_dataset)56 Status DatasetOpsTestBase::CreateTensorSliceDataset(
57     StringPiece node_name, std::vector<Tensor>* const components,
58     DatasetBase** tensor_slice_dataset) {
59   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
60   DataTypeVector dtypes;
61   dtypes.reserve(components->size());
62   std::vector<PartialTensorShape> shapes;
63   shapes.reserve(components->size());
64   for (const auto& t : *components) {
65     dtypes.push_back(t.dtype());
66     gtl::InlinedVector<int64, 4> partial_dim_sizes;
67     for (int i = 1; i < t.dims(); ++i) {
68       partial_dim_sizes.push_back(t.dim_size(i));
69     }
70     shapes.emplace_back(std::move(partial_dim_sizes));
71   }
72   TF_RETURN_IF_ERROR(CreateTensorSliceDatasetKernel(
73       node_name, dtypes, shapes, &tensor_slice_dataset_kernel));
74   gtl::InlinedVector<TensorValue, 4> inputs;
75   for (auto& tensor : *components) {
76     inputs.emplace_back(&tensor);
77   }
78   TF_RETURN_IF_ERROR(CheckOpKernelInput(*tensor_slice_dataset_kernel, inputs));
79   std::unique_ptr<OpKernelContext> context;
80   TF_RETURN_IF_ERROR(CreateOpKernelContext(tensor_slice_dataset_kernel.get(),
81                                            &inputs, &context));
82   TF_RETURN_IF_ERROR(
83       RunOpKernel(tensor_slice_dataset_kernel.get(), context.get()));
84   TF_RETURN_IF_ERROR(
85       GetDatasetFromContext(context.get(), 0, tensor_slice_dataset));
86   return Status::OK();
87 }
88 
CreateOpKernel(const NodeDef & node_def,std::unique_ptr<OpKernel> * op_kernel)89 Status DatasetOpsTestBase::CreateOpKernel(
90     const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
91   OpKernel* kernel;
92   TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(device_type_, device_.get(),
93                                                 allocator_, flr_, node_def,
94                                                 TF_GRAPH_DEF_VERSION, &kernel));
95   op_kernel->reset(kernel);
96   return Status::OK();
97 }
98 
CreateDataset(OpKernel * kernel,OpKernelContext * context,DatasetBase ** const dataset)99 Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
100                                          OpKernelContext* context,
101                                          DatasetBase** const dataset) {
102   TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
103   // Assume that DatasetOp has only one output.
104   DCHECK_EQ(context->num_outputs(), 1);
105   TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
106   return Status::OK();
107 }
108 
CreateIteratorContext(OpKernelContext * const op_context,std::unique_ptr<IteratorContext> * iterator_context)109 Status DatasetOpsTestBase::CreateIteratorContext(
110     OpKernelContext* const op_context,
111     std::unique_ptr<IteratorContext>* iterator_context) {
112   IteratorContext::Params params(op_context);
113   function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
114   params.function_handle_cache = function_handle_cache_.get();
115   *iterator_context = absl::make_unique<IteratorContext>(params);
116   return Status::OK();
117 }
118 
GetDatasetFromContext(OpKernelContext * context,int output_index,DatasetBase ** const dataset)119 Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
120                                                  int output_index,
121                                                  DatasetBase** const dataset) {
122   Tensor* output = context->mutable_output(output_index);
123   Status status = GetDatasetFromVariantTensor(*output, dataset);
124   (*dataset)->Ref();
125   return status;
126 }
127 
InitThreadPool(int thread_num)128 Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
129   if (thread_num < 1) {
130     return errors::InvalidArgument(
131         "The `thread_num` argument should be positive but got: ", thread_num);
132   }
133   thread_pool_ = absl::make_unique<thread::ThreadPool>(
134       Env::Default(), ThreadOptions(), "inter_op", thread_num);
135   return Status::OK();
136 }
137 
// Sets up the device manager, function library definition, and process
// function library runtime for the test fixture. `flib` supplies extra
// FunctionDefs registered on top of the global op registry; `cpu_num` CPU
// devices are created under /job:localhost/replica:0/task:0. Also installs
// `runner_`, which executes closures inline when no thread pool exists
// (see InitThreadPool) and schedules them on the pool otherwise.
Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
    const std::vector<FunctionDef>& flib, int cpu_num) {
  if (cpu_num < 1) {
    return errors::InvalidArgument(
        "The `cpu_num` argument should be positive but got: ", cpu_num);
  }
  SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", cpu_num});
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));

  // Merge the caller-supplied functions into a library definition backed by
  // the global op registry.
  FunctionDefLibrary proto;
  for (const auto& fdef : flib) *(proto.add_function()) = fdef;
  lib_def_ =
      absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);

  OptimizerOptions opts;
  // `pflr_` must outlive `flr_`, which is a view into it for the first CPU.
  pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
      opts, thread_pool_.get(), nullptr /* cluster_flr */);
  flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  if (thread_pool_ == nullptr) {
    // No pool: run closures synchronously on the calling thread.
    runner_ = [](std::function<void()> fn) { fn(); };
  } else {
    runner_ = [this](std::function<void()> fn) {
      thread_pool_->Schedule(std::move(fn));
    };
  }
  return Status::OK();
}
171 
// Executes `op_kernel` synchronously on the test device and returns the
// status the computation recorded in `context`.
Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
                                       OpKernelContext* context) {
  device_->Compute(op_kernel, context);
  return context->status();
}
177 
CreateOpKernelContext(OpKernel * kernel,gtl::InlinedVector<TensorValue,4> * inputs,std::unique_ptr<OpKernelContext> * context)178 Status DatasetOpsTestBase::CreateOpKernelContext(
179     OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
180     std::unique_ptr<OpKernelContext>* context) {
181   params_ = absl::make_unique<OpKernelContext::Params>();
182   params_->device = device_.get();
183   params_->resource_manager = device_->resource_manager();
184   params_->frame_iter = FrameAndIter(0, 0);
185   params_->inputs = inputs;
186   params_->op_kernel = kernel;
187   params_->function_library = flr_;
188   params_->runner = &runner_;
189   step_container_ =
190       absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
191   params_->step_container = step_container_.get();
192   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
193   slice_reader_cache_ =
194       absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
195   params_->slice_reader_cache = slice_reader_cache_.get();
196 
197   // Set the allocator attributes for the outputs.
198   allocator_attrs_.clear();
199   for (int index = 0; index < params_->op_kernel->num_outputs(); index++) {
200     AllocatorAttributes attr;
201     const bool on_host =
202         (params_->op_kernel->output_memory_types()[index] == HOST_MEMORY);
203     attr.set_on_host(on_host);
204     allocator_attrs_.emplace_back(attr);
205   }
206   params_->output_attr_array = gtl::vector_as_array(&allocator_attrs_);
207 
208   *context = absl::make_unique<OpKernelContext>(params_.get());
209   return Status::OK();
210 }
211 
CreateSerializationContext(std::unique_ptr<SerializationContext> * context)212 Status DatasetOpsTestBase::CreateSerializationContext(
213     std::unique_ptr<SerializationContext>* context) {
214   SerializationContext::Params params;
215   params.flib_def = lib_def_.get();
216   *context = absl::make_unique<SerializationContext>(params);
217   return Status::OK();
218 }
219 
CheckOpKernelInput(const OpKernel & kernel,const gtl::InlinedVector<TensorValue,4> & inputs)220 Status DatasetOpsTestBase::CheckOpKernelInput(
221     const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
222   if (kernel.input_types().size() != inputs.size()) {
223     return errors::Internal("The number of input elements should be ",
224                             kernel.input_types().size(),
225                             ", but got: ", inputs.size());
226   }
227   return Status::OK();
228 }
229 
AddDatasetInput(gtl::InlinedVector<TensorValue,4> * inputs,DataTypeVector input_types,DataType dtype,const TensorShape & shape)230 Status DatasetOpsTestBase::AddDatasetInput(
231     gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
232     DataType dtype, const TensorShape& shape) {
233   if (input_types.size() < inputs->size()) {
234     return errors::InvalidArgument("Adding more inputs than types: ",
235                                    inputs->size(), " vs. ", input_types.size());
236   }
237   bool is_ref = IsRefType(input_types[inputs->size()]);
238   std::unique_ptr<Tensor> input =
239       absl::make_unique<Tensor>(allocator_, dtype, shape);
240 
241   if (is_ref) {
242     DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
243     if (expected_dtype != dtype) {
244       return errors::InvalidArgument("The input data type is ", dtype,
245                                      " , but expected: ", expected_dtype);
246     }
247     inputs->push_back({&lock_for_refs_, input.get()});
248   } else {
249     if (input_types[inputs->size()] != dtype) {
250       return errors::InvalidArgument(
251           "The input data type is ", dtype,
252           " , but expected: ", input_types[inputs->size()]);
253     }
254     inputs->push_back({nullptr, input.get()});
255   }
256 
257   // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
258   // collect the inputs.
259   tensors_.push_back(std::move(input));
260 
261   return Status::OK();
262 }
263 
264 }  // namespace data
265 }  // namespace tensorflow
266