1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/data/dataset_test_base.h"
17 
18 #include <algorithm>
19 #include <complex>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
29 #include "tensorflow/core/common_runtime/device.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/device_mgr.h"
32 #include "tensorflow/core/common_runtime/executor.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
35 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
36 #include "tensorflow/core/framework/allocator.h"
37 #include "tensorflow/core/framework/cancellation.h"
38 #include "tensorflow/core/framework/control_flow.h"
39 #include "tensorflow/core/framework/dataset.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/function.pb.h"
42 #include "tensorflow/core/framework/function_handle_cache.h"
43 #include "tensorflow/core/framework/function_testlib.h"
44 #include "tensorflow/core/framework/node_def.pb.h"
45 #include "tensorflow/core/framework/numeric_types.h"
46 #include "tensorflow/core/framework/op.h"
47 #include "tensorflow/core/framework/op_def.pb.h"
48 #include "tensorflow/core/framework/op_kernel.h"
49 #include "tensorflow/core/framework/register_types.h"
50 #include "tensorflow/core/framework/rendezvous.h"
51 #include "tensorflow/core/framework/resource_mgr.h"
52 #include "tensorflow/core/framework/tensor.h"
53 #include "tensorflow/core/framework/tensor_shape.h"
54 #include "tensorflow/core/framework/types.h"
55 #include "tensorflow/core/framework/types.pb.h"
56 #include "tensorflow/core/framework/variant_tensor_data.h"
57 #include "tensorflow/core/framework/versions.pb.h"
58 #include "tensorflow/core/graph/graph.h"
59 #include "tensorflow/core/kernels/data/batch_dataset_op.h"
60 #include "tensorflow/core/kernels/data/concatenate_dataset_op.h"
61 #include "tensorflow/core/kernels/data/dataset_utils.h"
62 #include "tensorflow/core/kernels/data/map_dataset_op.h"
63 #include "tensorflow/core/kernels/data/name_utils.h"
64 #include "tensorflow/core/kernels/data/range_dataset_op.h"
65 #include "tensorflow/core/kernels/data/split_utils.h"
66 #include "tensorflow/core/kernels/data/take_dataset_op.h"
67 #include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
68 #include "tensorflow/core/lib/core/status_test_util.h"
69 #include "tensorflow/core/lib/gtl/inlined_vector.h"
70 #include "tensorflow/core/lib/io/record_writer.h"
71 #include "tensorflow/core/lib/io/zlib_compression_options.h"
72 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
73 #include "tensorflow/core/platform/bfloat16.h"
74 #include "tensorflow/core/platform/env.h"
75 #include "tensorflow/core/platform/errors.h"
76 #include "tensorflow/core/platform/file_system.h"
77 #include "tensorflow/core/platform/logging.h"
78 #include "tensorflow/core/platform/status.h"
79 #include "tensorflow/core/platform/test.h"
80 #include "tensorflow/core/platform/threadpool.h"
81 #include "tensorflow/core/platform/tstring.h"
82 #include "tensorflow/core/platform/types.h"
83 #include "tensorflow/core/protobuf/config.pb.h"
84 #include "tensorflow/core/public/session_options.h"
85 #include "tensorflow/core/public/version.h"
86 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
87 
88 namespace tensorflow {
89 namespace data {
90 
ToString(CompressionType compression_type)91 string ToString(CompressionType compression_type) {
92   switch (compression_type) {
93     case CompressionType::ZLIB:
94       return "ZLIB";
95     case CompressionType::GZIP:
96       return "GZIP";
97     case CompressionType::RAW:
98       return "RAW";
99     case CompressionType::UNCOMPRESSED:
100       return "";
101   }
102 }
103 
GetZlibCompressionOptions(CompressionType compression_type)104 io::ZlibCompressionOptions GetZlibCompressionOptions(
105     CompressionType compression_type) {
106   switch (compression_type) {
107     case CompressionType::ZLIB:
108       return io::ZlibCompressionOptions::DEFAULT();
109     case CompressionType::GZIP:
110       return io::ZlibCompressionOptions::GZIP();
111     case CompressionType::RAW:
112       return io::ZlibCompressionOptions::RAW();
113     case CompressionType::UNCOMPRESSED:
114       LOG(WARNING) << "ZlibCompressionOptions does not have an option for "
115                    << ToString(compression_type);
116       return io::ZlibCompressionOptions::DEFAULT();
117   }
118 }
119 
WriteDataToFile(const string & filename,const char * data)120 Status WriteDataToFile(const string& filename, const char* data) {
121   return WriteDataToFile(filename, data, CompressionParams());
122 }
123 
WriteDataToFile(const string & filename,const char * data,const CompressionParams & params)124 Status WriteDataToFile(const string& filename, const char* data,
125                        const CompressionParams& params) {
126   Env* env = Env::Default();
127   std::unique_ptr<WritableFile> file_writer;
128   TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
129   if (params.compression_type == CompressionType::UNCOMPRESSED) {
130     TF_RETURN_IF_ERROR(file_writer->Append(data));
131   } else if (params.compression_type == CompressionType::ZLIB ||
132              params.compression_type == CompressionType::GZIP ||
133              params.compression_type == CompressionType::RAW) {
134     auto zlib_compression_options =
135         GetZlibCompressionOptions(params.compression_type);
136     io::ZlibOutputBuffer out(file_writer.get(), params.input_buffer_size,
137                              params.output_buffer_size,
138                              zlib_compression_options);
139     TF_RETURN_IF_ERROR(out.Init());
140     TF_RETURN_IF_ERROR(out.Append(data));
141     TF_RETURN_IF_ERROR(out.Flush());
142     TF_RETURN_IF_ERROR(out.Close());
143   } else {
144     return tensorflow::errors::InvalidArgument(
145         "Unsupported compression_type: ", ToString(params.compression_type));
146   }
147 
148   TF_RETURN_IF_ERROR(file_writer->Flush());
149   TF_RETURN_IF_ERROR(file_writer->Close());
150 
151   return Status::OK();
152 }
153 
WriteDataToTFRecordFile(const string & filename,const std::vector<absl::string_view> & records,const CompressionParams & params)154 Status WriteDataToTFRecordFile(const string& filename,
155                                const std::vector<absl::string_view>& records,
156                                const CompressionParams& params) {
157   Env* env = Env::Default();
158   std::unique_ptr<WritableFile> file_writer;
159   TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
160   auto options = io::RecordWriterOptions::CreateRecordWriterOptions(
161       ToString(params.compression_type));
162   options.zlib_options.input_buffer_size = params.input_buffer_size;
163   io::RecordWriter record_writer(file_writer.get(), options);
164   for (const auto& record : records) {
165     TF_RETURN_IF_ERROR(record_writer.WriteRecord(record));
166   }
167   TF_RETURN_IF_ERROR(record_writer.Flush());
168   TF_RETURN_IF_ERROR(record_writer.Close());
169   TF_RETURN_IF_ERROR(file_writer->Flush());
170   TF_RETURN_IF_ERROR(file_writer->Close());
171   return Status::OK();
172 }
173 
174 template <typename T>
IsEqual(const Tensor & t1,const Tensor & t2)175 Status IsEqual(const Tensor& t1, const Tensor& t2) {
176   if (t1.dtype() != t2.dtype()) {
177     return tensorflow::errors::Internal(
178         "Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
179         " vs. ", DataTypeString(t2.dtype()));
180   }
181   if (!t1.IsSameSize(t2)) {
182     return tensorflow::errors::Internal(
183         "Two tensors have different shapes: ", t1.shape().DebugString(),
184         " vs. ", t2.shape().DebugString());
185   }
186 
187   auto flat_t1 = t1.flat<T>();
188   auto flat_t2 = t2.flat<T>();
189   auto length = flat_t1.size();
190 
191   for (int i = 0; i < length; ++i) {
192     if (flat_t1(i) != flat_t2(i)) {
193       return tensorflow::errors::Internal(
194           "Two tensors have different values "
195           "at [",
196           i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
197     }
198   }
199   return Status::OK();
200 }
201 
// Constructs the test base: creates a "CPU" device named
// "/job:a/replica:0/task:0", records the CPU device type together with the
// default CPU and thread counts, and caches the allocator that the device
// returns for default `AllocatorAttributes`.
DatasetOpsTestBase::DatasetOpsTestBase()
    : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
      device_type_(DEVICE_CPU),
      cpu_num_(kDefaultCPUNum),
      thread_num_(kDefaultThreadNum) {
  // The allocator is obtained in the body because it depends on `device_`
  // having been created by the initializer list above.
  allocator_ = device_->GetAllocator(AllocatorAttributes());
}
209 
~DatasetOpsTestBase()210 DatasetOpsTestBase::~DatasetOpsTestBase() {
211   if (dataset_) {
212     dataset_->Unref();
213   }
214 }
215 
// Compares tensors `a` and `b` by dispatching on `a.dtype()` to the matching
// `IsEqual<T>` instantiation.
//
// Returns OK when `IsEqual` succeeds, the error produced by `IsEqual` when it
// reports a difference, or an Internal error when `a`'s dtype is not one of
// the types expanded by TF_CALL_NUMBER_TYPES or TF_CALL_tstring.
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
  switch (a.dtype()) {
// Expands to one `case` per supported C++ type `DT`; each case returns early
// if the typed comparison fails.
#define CASE(DT)                           \
  case DataTypeToEnum<DT>::value:          \
    TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_tstring(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:
      return errors::Internal("Unsupported dtype: ", a.dtype());
  }
  return Status::OK();
}
231 
232 template <typename T>
compare(const Tensor & t1,const Tensor & t2)233 bool compare(const Tensor& t1, const Tensor& t2) {
234   auto flat_t1 = t1.flat<T>();
235   auto flat_t2 = t2.flat<T>();
236   auto length = std::min(flat_t1.size(), flat_t2.size());
237   for (int i = 0; i < length; ++i) {
238     if (flat_t1(i) < flat_t2(i)) return true;
239     if (flat_t1(i) > flat_t2(i)) return false;
240   }
241   return flat_t1.size() < length;
242 }
243 
ExpectEqual(std::vector<Tensor> produced_tensors,std::vector<Tensor> expected_tensors,bool compare_order)244 Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
245                                        std::vector<Tensor> expected_tensors,
246                                        bool compare_order) {
247   if (produced_tensors.size() != expected_tensors.size()) {
248     return Status(tensorflow::errors::Internal(
249         "The two tensor vectors have different size (", produced_tensors.size(),
250         " v.s. ", expected_tensors.size(), ")"));
251   }
252 
253   if (produced_tensors.empty()) return Status::OK();
254   if (produced_tensors[0].dtype() != expected_tensors[0].dtype()) {
255     return Status(tensorflow::errors::Internal(
256         "The two tensor vectors have different dtypes (",
257         produced_tensors[0].dtype(), " v.s. ", expected_tensors[0].dtype(),
258         ")"));
259   }
260 
261   if (!compare_order) {
262     const DataType& dtype = produced_tensors[0].dtype();
263     switch (dtype) {
264 #define CASE(DT)                                                \
265   case DT:                                                      \
266     std::sort(produced_tensors.begin(), produced_tensors.end(), \
267               compare<EnumToDataType<DT>::Type>);               \
268     std::sort(expected_tensors.begin(), expected_tensors.end(), \
269               compare<EnumToDataType<DT>::Type>);               \
270     break;
271       CASE(DT_FLOAT);
272       CASE(DT_DOUBLE);
273       CASE(DT_INT32);
274       CASE(DT_UINT8);
275       CASE(DT_INT16);
276       CASE(DT_INT8);
277       CASE(DT_STRING);
278       CASE(DT_INT64);
279       CASE(DT_BOOL);
280       CASE(DT_QINT8);
281       CASE(DT_QUINT8);
282       CASE(DT_QINT32);
283       CASE(DT_QINT16);
284       CASE(DT_QUINT16);
285       CASE(DT_UINT16);
286       CASE(DT_HALF);
287       CASE(DT_UINT32);
288       CASE(DT_UINT64);
289       // TODO(feihugis): support other dtypes.
290 #undef CASE
291       default:
292         return errors::Internal("Unsupported dtype: ", dtype);
293     }
294   }
295 
296   for (int i = 0; i < produced_tensors.size(); ++i) {
297     TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(produced_tensors[i],
298                                                        expected_tensors[i]));
299   }
300   return Status::OK();
301 }
302 
CreateOpKernel(const NodeDef & node_def,std::unique_ptr<OpKernel> * op_kernel)303 Status DatasetOpsTestBase::CreateOpKernel(
304     const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
305   OpKernel* kernel;
306   Status s;
307 
308   std::shared_ptr<const NodeProperties> props;
309   TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
310       node_def, flr_->GetFunctionLibraryDefinition(), &props));
311   TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(
312       device_type_, device_.get(), allocator_, flr_,
313       device_->resource_manager(), props, TF_GRAPH_DEF_VERSION, &kernel));
314   op_kernel->reset(kernel);
315   return Status::OK();
316 }
317 
CreateDatasetContext(OpKernel * const dateset_kernel,gtl::InlinedVector<TensorValue,4> * const inputs,std::unique_ptr<OpKernelContext::Params> * dataset_context_params,std::unique_ptr<OpKernelContext> * dataset_context)318 Status DatasetOpsTestBase::CreateDatasetContext(
319     OpKernel* const dateset_kernel,
320     gtl::InlinedVector<TensorValue, 4>* const inputs,
321     std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
322     std::unique_ptr<OpKernelContext>* dataset_context) {
323   Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
324   if (!status.ok()) {
325     VLOG(0) << "WARNING: " << status.ToString();
326   }
327   TF_RETURN_IF_ERROR(CreateOpKernelContext(
328       dateset_kernel, inputs, dataset_context_params, dataset_context));
329   return Status::OK();
330 }
331 
CreateDataset(OpKernel * kernel,OpKernelContext * context,DatasetBase ** const dataset)332 Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
333                                          OpKernelContext* context,
334                                          DatasetBase** const dataset) {
335   TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
336   // Assume that DatasetOp has only one output.
337   DCHECK_EQ(context->num_outputs(), 1);
338   TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
339   return Status::OK();
340 }
341 
RestoreIterator(IteratorContext * ctx,IteratorStateReader * reader,const string & output_prefix,const DatasetBase & dataset,std::unique_ptr<IteratorBase> * iterator)342 Status DatasetOpsTestBase::RestoreIterator(
343     IteratorContext* ctx, IteratorStateReader* reader,
344     const string& output_prefix, const DatasetBase& dataset,
345     std::unique_ptr<IteratorBase>* iterator) {
346   return dataset.MakeIteratorFromCheckpoint(ctx, output_prefix, reader,
347                                             iterator);
348 }
349 
CreateIteratorContext(OpKernelContext * const op_context,std::unique_ptr<IteratorContext> * iterator_context)350 Status DatasetOpsTestBase::CreateIteratorContext(
351     OpKernelContext* const op_context,
352     std::unique_ptr<IteratorContext>* iterator_context) {
353   IteratorContext::Params params(op_context);
354   params.resource_mgr = op_context->resource_manager();
355   function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
356   params.function_handle_cache = function_handle_cache_.get();
357   params.cancellation_manager = cancellation_manager_.get();
358   *iterator_context = absl::make_unique<IteratorContext>(params);
359   return Status::OK();
360 }
361 
GetDatasetFromContext(OpKernelContext * context,int output_index,DatasetBase ** const dataset)362 Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
363                                                  int output_index,
364                                                  DatasetBase** const dataset) {
365   Tensor* output = context->mutable_output(output_index);
366   Status status = GetDatasetFromVariantTensor(*output, dataset);
367   (*dataset)->Ref();
368   return status;
369 }
370 
InitThreadPool(int thread_num)371 Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
372   if (thread_num < 1) {
373     return errors::InvalidArgument(
374         "The `thread_num` argument should be positive but got: ", thread_num);
375   }
376   thread_pool_ = absl::make_unique<thread::ThreadPool>(
377       Env::Default(), ThreadOptions(), "test_thread_pool", thread_num);
378   return Status::OK();
379 }
380 
// Sets up the function library runtime used by the tests.
//
// Creates `cpu_num` CPU devices (requested through the session config's
// device count) under "/job:localhost/replica:0/task:0", a resource manager,
// a function library definition containing the functions in `flib`, and a
// `ProcessFunctionLibraryRuntime` over them. `flr_` is set to the runtime for
// the first CPU device and `runner_` to a closure runner.
//
// Returns InvalidArgument when `cpu_num` is less than 1, or the error from
// `DeviceFactory::AddDevices`.
Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
    const std::vector<FunctionDef>& flib, int cpu_num) {
  if (cpu_num < 1) {
    return errors::InvalidArgument(
        "The `cpu_num` argument should be positive but got: ", cpu_num);
  }
  SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", cpu_num});
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  // The device manager takes ownership of the created devices.
  device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");

  // Copy the given function definitions into a library proto.
  FunctionDefLibrary proto;
  for (const auto& fdef : flib) *(proto.add_function()) = fdef;
  lib_def_ =
      absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);

  OptimizerOptions opts;
  // `thread_pool_.get()` is whatever pool currently exists; it is null if no
  // pool has been created yet (see the null check below).
  pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
      /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return Status::OK();
          }});
  flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  if (thread_pool_ == nullptr) {
    // No pool: run each closure inline on the calling thread.
    runner_ = [](const std::function<void()>& fn) { fn(); };
  } else {
    // Otherwise hand closures to the pool. The lambda captures `this`, so it
    // must not be invoked after this object is destroyed.
    runner_ = [this](std::function<void()> fn) {
      thread_pool_->Schedule(std::move(fn));
    };
  }
  return Status::OK();
}
422 
RunOpKernel(OpKernel * op_kernel,OpKernelContext * context)423 Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
424                                        OpKernelContext* context) {
425   device_->Compute(op_kernel, context);
426   return context->status();
427 }
428 
RunFunction(const FunctionDef & fdef,test::function::Attrs attrs,const std::vector<Tensor> & args,const GraphConstructorOptions & graph_options,std::vector<Tensor * > rets)429 Status DatasetOpsTestBase::RunFunction(
430     const FunctionDef& fdef, test::function::Attrs attrs,
431     const std::vector<Tensor>& args,
432     const GraphConstructorOptions& graph_options, std::vector<Tensor*> rets) {
433   std::unique_ptr<Executor> exec;
434   InstantiationResult result;
435   auto GetOpSig = [](const string& op, const OpDef** sig) {
436     return OpRegistry::Global()->LookUpOpDef(op, sig);
437   };
438   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, GetOpSig, &result));
439 
440   DataTypeVector arg_types = result.arg_types;
441   DataTypeVector ret_types = result.ret_types;
442 
443   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
444   TF_RETURN_IF_ERROR(
445       ConvertNodeDefsToGraph(graph_options, result.nodes, g.get()));
446 
447   const int version = g->versions().producer();
448   LocalExecutorParams params;
449   params.function_library = flr_;
450   params.device = device_.get();
451   params.create_kernel = [this, version](
452                              const std::shared_ptr<const NodeProperties>& props,
453                              OpKernel** kernel) {
454     return CreateNonCachedKernel(device_.get(), this->flr_, props, version,
455                                  kernel);
456   };
457   params.delete_kernel = [](OpKernel* kernel) {
458     DeleteNonCachedKernel(kernel);
459   };
460 
461   Executor* cur_exec;
462   TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
463   exec.reset(cur_exec);
464   FunctionCallFrame frame(arg_types, ret_types);
465   TF_RETURN_IF_ERROR(frame.SetArgs(args));
466   Executor::Args exec_args;
467   exec_args.call_frame = &frame;
468   exec_args.runner = runner_;
469   TF_RETURN_IF_ERROR(exec->Run(exec_args));
470   std::vector<Tensor> computed;
471   TF_RETURN_IF_ERROR(frame.GetRetvals(&computed));
472   if (computed.size() != rets.size()) {
473     return errors::InvalidArgument(
474         "The result does not match the expected number of return outpus",
475         ". Expected: ", rets.size(), ". Actual: ", computed.size());
476   }
477   for (int i = 0; i < rets.size(); ++i) {
478     *(rets[i]) = computed[i];
479   }
480   return Status::OK();
481 }
482 
CreateOpKernelContext(OpKernel * kernel,gtl::InlinedVector<TensorValue,4> * inputs,std::unique_ptr<OpKernelContext> * context)483 Status DatasetOpsTestBase::CreateOpKernelContext(
484     OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
485     std::unique_ptr<OpKernelContext>* context) {
486   return CreateOpKernelContext(kernel, inputs, &params_, context);
487 }
488 
CreateOpKernelContext(OpKernel * kernel,gtl::InlinedVector<TensorValue,4> * inputs,std::unique_ptr<OpKernelContext::Params> * context_params,std::unique_ptr<OpKernelContext> * context)489 Status DatasetOpsTestBase::CreateOpKernelContext(
490     OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
491     std::unique_ptr<OpKernelContext::Params>* context_params,
492     std::unique_ptr<OpKernelContext>* context) {
493   auto params = absl::make_unique<OpKernelContext::Params>();
494   cancellation_manager_ = absl::make_unique<CancellationManager>();
495   params->cancellation_manager = cancellation_manager_.get();
496   params->device = device_.get();
497   params->frame_iter = FrameAndIter(0, 0);
498   params->function_library = flr_;
499   params->inputs = inputs;
500   params->op_kernel = kernel;
501   params->resource_manager = resource_mgr_.get();
502   params->runner = &runner_;
503   slice_reader_cache_ =
504       absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
505   params->slice_reader_cache = slice_reader_cache_.get();
506   step_container_ =
507       absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
508   params->step_container = step_container_.get();
509 
510   // Set the allocator attributes for the outputs.
511   allocator_attrs_.clear();
512   for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
513     AllocatorAttributes attr;
514     const bool on_host =
515         (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
516     attr.set_on_host(on_host);
517     allocator_attrs_.emplace_back(attr);
518   }
519   params->output_attr_array = allocator_attrs_.data();
520 
521   *context = absl::make_unique<OpKernelContext>(params.get());
522   *context_params = std::move(params);
523   return Status::OK();
524 }
525 
CreateSerializationContext(std::unique_ptr<SerializationContext> * context)526 Status DatasetOpsTestBase::CreateSerializationContext(
527     std::unique_ptr<SerializationContext>* context) {
528   *context =
529       absl::make_unique<SerializationContext>(SerializationContext::Params{});
530   return Status::OK();
531 }
532 
CheckOpKernelInput(const OpKernel & kernel,const gtl::InlinedVector<TensorValue,4> & inputs)533 Status DatasetOpsTestBase::CheckOpKernelInput(
534     const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
535   if (kernel.num_inputs() != inputs.size()) {
536     return errors::InvalidArgument("The number of input elements should be ",
537                                    kernel.num_inputs(),
538                                    ", but got: ", inputs.size());
539   }
540   return Status::OK();
541 }
542 
AddDatasetInput(gtl::InlinedVector<TensorValue,4> * inputs,DataTypeVector input_types,DataType dtype,const TensorShape & shape)543 Status DatasetOpsTestBase::AddDatasetInput(
544     gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
545     DataType dtype, const TensorShape& shape) {
546   if (input_types.size() < inputs->size()) {
547     return errors::InvalidArgument("Adding more inputs than types: ",
548                                    inputs->size(), " vs. ", input_types.size());
549   }
550   bool is_ref = IsRefType(input_types[inputs->size()]);
551   auto input = absl::make_unique<Tensor>(allocator_, dtype, shape);
552 
553   if (is_ref) {
554     DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
555     if (expected_dtype != dtype) {
556       return errors::InvalidArgument("The input data type is ", dtype,
557                                      " , but expected: ", expected_dtype);
558     }
559     inputs->push_back({&lock_for_refs_, input.get()});
560   } else {
561     if (input_types[inputs->size()] != dtype) {
562       return errors::InvalidArgument(
563           "The input data type is ", dtype,
564           " , but expected: ", input_types[inputs->size()]);
565     }
566     inputs->push_back({nullptr, input.get()});
567   }
568 
569   // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
570   // collect the inputs.
571   tensors_.push_back(std::move(input));
572 
573   return Status::OK();
574 }
575 
CheckIteratorGetNext(const std::vector<Tensor> & expected_outputs,bool compare_order)576 Status DatasetOpsTestBase::CheckIteratorGetNext(
577     const std::vector<Tensor>& expected_outputs, bool compare_order) {
578   return CheckIteratorGetNext(iterator_.get(), iterator_ctx_.get(),
579                               expected_outputs, compare_order);
580 }
581 
CheckIteratorGetNext(TestIterator * iterator,const std::vector<Tensor> & expected_outputs,bool compare_order)582 Status DatasetOpsTestBase::CheckIteratorGetNext(
583     TestIterator* iterator, const std::vector<Tensor>& expected_outputs,
584     bool compare_order) {
585   return CheckIteratorGetNext(iterator->iterator(), iterator->ctx(),
586                               expected_outputs, compare_order);
587 }
588 
CheckIteratorGetNext(IteratorBase * iterator,IteratorContext * ctx,const std::vector<Tensor> & expected_outputs,bool compare_order)589 Status DatasetOpsTestBase::CheckIteratorGetNext(
590     IteratorBase* iterator, IteratorContext* ctx,
591     const std::vector<Tensor>& expected_outputs, bool compare_order) {
592   bool end_of_sequence = false;
593   std::vector<Tensor> out_tensors;
594   while (!end_of_sequence) {
595     std::vector<Tensor> next;
596     TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &next, &end_of_sequence));
597     out_tensors.insert(out_tensors.end(), next.begin(), next.end());
598   }
599   // Call GetNext one more time to make sure it still reports
600   // end_of_sequence = True.
601   std::vector<Tensor> unused;
602   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &unused, &end_of_sequence));
603   EXPECT_TRUE(end_of_sequence);
604 
605   TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
606                            /*compare_order=*/compare_order));
607   return Status::OK();
608 }
609 
CheckIteratorSkip(int num_to_skip,int expected_num_skipped,bool get_next,const std::vector<Tensor> & expected_outputs,bool compare_order)610 Status DatasetOpsTestBase::CheckIteratorSkip(
611     int num_to_skip, int expected_num_skipped, bool get_next,
612     const std::vector<Tensor>& expected_outputs, bool compare_order) {
613   IteratorBase* iterator = iterator_.get();
614   IteratorContext* ctx = iterator_ctx_.get();
615 
616   bool end_of_sequence = false;
617   int num_skipped = 0;
618   TF_RETURN_IF_ERROR(
619       iterator->Skip(ctx, num_to_skip, &end_of_sequence, &num_skipped));
620   EXPECT_TRUE(num_skipped == expected_num_skipped);
621   if (get_next) {
622     EXPECT_TRUE(!end_of_sequence);
623     std::vector<Tensor> out_tensors;
624     TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &out_tensors, &end_of_sequence));
625     TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
626                              /*compare_order=*/compare_order));
627   }
628   return Status::OK();
629 }
630 
CheckSplitProviderFullIteration(const DatasetParams & params,const std::vector<Tensor> & expected_outputs)631 Status DatasetOpsTestBase::CheckSplitProviderFullIteration(
632     const DatasetParams& params, const std::vector<Tensor>& expected_outputs) {
633   std::unique_ptr<TestDataset> dataset;
634   TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
635   std::unique_ptr<SplitProvider> split_provider;
636   TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProvider(&split_provider));
637   std::unique_ptr<TestIterator> iterator;
638   TF_RETURN_IF_ERROR(
639       MakeIterator(params, *dataset, std::move(split_provider), &iterator));
640   TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs,
641                                           /*compare_order=*/true));
642   return Status::OK();
643 }
644 
CheckSplitProviderShardedIteration(const DatasetParams & params,int64 num_shards,int64 shard_index,const std::vector<Tensor> & expected_outputs)645 Status DatasetOpsTestBase::CheckSplitProviderShardedIteration(
646     const DatasetParams& params, int64 num_shards, int64 shard_index,
647     const std::vector<Tensor>& expected_outputs) {
648   std::unique_ptr<TestDataset> dataset;
649   TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
650   std::unique_ptr<SplitProvider> split_provider;
651   TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProvider(&split_provider));
652   split_provider = absl::make_unique<ShardingSplitProvider>(
653       num_shards, shard_index, std::move(split_provider));
654   std::unique_ptr<IteratorContext> iterator_ctx;
655   TF_RETURN_IF_ERROR(
656       CreateIteratorContext(dataset->op_kernel_context(), &iterator_ctx));
657   IteratorContext::Params iterator_params(iterator_ctx.get());
658   iterator_params.split_provider = std::move(split_provider);
659   iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
660   int mid_breakpoint = expected_outputs.size() / 2;
661   int near_end_breakpoint = expected_outputs.size() - 1;
662   int end_breakpoint = expected_outputs.size();
663   TF_RETURN_IF_ERROR(CheckIteratorSaveAndRestore(
664       dataset->dataset(), iterator_ctx.get(), params.iterator_prefix(),
665       expected_outputs,
666       /*breakpoints=*/
667       {0, mid_breakpoint, near_end_breakpoint, end_breakpoint},
668       /*compare_order=*/true));
669   return Status::OK();
670 }
671 
// Expects (via EXPECT_EQ) that `dataset_->node_name()` equals
// `expected_dataset_node_name`. A mismatch is recorded as a test expectation
// failure; the returned status is always OK.
Status DatasetOpsTestBase::CheckDatasetNodeName(
    const string& expected_dataset_node_name) {
  EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name);
  return Status::OK();
}
677 
// Expects (via EXPECT_EQ) that the type string of `dataset_` equals
// `expected_type_str`. The outcome of the comparison is not reflected in the
// returned status, which is always OK.
Status DatasetOpsTestBase::CheckDatasetTypeString(
    const string& expected_type_str) {
  EXPECT_EQ(dataset_->type_string(), expected_type_str);
  return Status::OK();
}
683 
// Expects (via TF_EXPECT_OK) that VerifyTypesMatch succeeds for the output
// dtypes of `dataset_` and `expected_output_dtypes`. A verification failure is
// reported as an expectation failure only; the returned status is always OK.
Status DatasetOpsTestBase::CheckDatasetOutputDtypes(
    const DataTypeVector& expected_output_dtypes) {
  TF_EXPECT_OK(
      VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
  return Status::OK();
}
690 
// Expects (via TF_EXPECT_OK) that VerifyShapesCompatible succeeds for the
// output shapes of `dataset_` and `expected_output_shapes`. A verification
// failure is reported as an expectation failure only; the returned status is
// always OK.
Status DatasetOpsTestBase::CheckDatasetOutputShapes(
    const std::vector<PartialTensorShape>& expected_output_shapes) {
  TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
                                      expected_output_shapes));
  return Status::OK();
}
697 
// Expects (via EXPECT_EQ) that `dataset_->Cardinality()` equals
// `expected_cardinality`. The outcome of the comparison is not reflected in
// the returned status, which is always OK.
Status DatasetOpsTestBase::CheckDatasetCardinality(int expected_cardinality) {
  EXPECT_EQ(dataset_->Cardinality(), expected_cardinality);
  return Status::OK();
}
702 
// Expects (via TF_EXPECT_OK) that VerifyTypesMatch succeeds for the output
// dtypes of `iterator_` and `expected_output_dtypes`. A verification failure
// is reported as an expectation failure only; the returned status is always
// OK.
Status DatasetOpsTestBase::CheckIteratorOutputDtypes(
    const DataTypeVector& expected_output_dtypes) {
  TF_EXPECT_OK(
      VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
  return Status::OK();
}
709 
// Expects (via TF_EXPECT_OK) that VerifyShapesCompatible succeeds for the
// output shapes of `iterator_` and `expected_output_shapes`. A verification
// failure is reported as an expectation failure only; the returned status is
// always OK.
Status DatasetOpsTestBase::CheckIteratorOutputShapes(
    const std::vector<PartialTensorShape>& expected_output_shapes) {
  TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
                                      expected_output_shapes));
  return Status::OK();
}
716 
// Expects (via EXPECT_EQ) that the prefix of `iterator_` equals
// `expected_iterator_prefix`. The outcome of the comparison is not reflected
// in the returned status, which is always OK.
Status DatasetOpsTestBase::CheckIteratorPrefix(
    const string& expected_iterator_prefix) {
  EXPECT_EQ(iterator_->prefix(), expected_iterator_prefix);
  return Status::OK();
}
722 
// Checks that an iterator over `dataset` yields `expected_outputs` when its
// state is saved and restored before each entry of `breakpoints`.
//
// For every breakpoint the iterator is serialized and restored from that
// serialized state, and GetNext() is then called until the running iteration
// count has passed the breakpoint. Save, restore and final comparison
// problems are reported through TF_EXPECT_OK; only errors from MakeIterator,
// CreateSerializationContext and GetNext are returned to the caller.
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    DatasetBase* dataset, IteratorContext* iterator_ctx,
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr,
                                           iterator_prefix, &iterator));
  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
  // `end_of_sequence` is written by GetNext() but never read in this function.
  bool end_of_sequence = false;
  // Counts GetNext() calls across all breakpoints; it is never reset.
  int cur_iteration = 0;
  std::vector<Tensor> out_tensors;
  for (int breakpoint : breakpoints) {
    // Round-trip the iterator state: save it, then restore `iterator` from
    // the saved data.
    VariantTensorDataWriter writer;
    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
    std::vector<const VariantTensorData*> data;
    writer.GetData(&data);
    VariantTensorDataReader reader(data);
    TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix,
                                 *dataset, &iterator));

    // Advance until the iteration count exceeds this breakpoint. A call that
    // yields no tensors simply contributes nothing to `out_tensors`.
    while (cur_iteration <= breakpoint) {
      std::vector<Tensor> next;
      TF_RETURN_IF_ERROR(
          iterator->GetNext(iterator_ctx, &next, &end_of_sequence));
      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
      cur_iteration++;
    }
  }
  TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
                           /*compare_order=*/compare_order));
  return Status::OK();
}
757 
CheckIteratorSaveAndRestore(const std::string & iterator_prefix,const std::vector<Tensor> & expected_outputs,const std::vector<int> & breakpoints,bool compare_order)758 Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
759     const std::string& iterator_prefix,
760     const std::vector<Tensor>& expected_outputs,
761     const std::vector<int>& breakpoints, bool compare_order) {
762   return CheckIteratorSaveAndRestore(dataset_, iterator_ctx_.get(),
763                                      iterator_prefix, expected_outputs,
764                                      breakpoints, compare_order);
765 }
766 
Initialize(const DatasetParams & dataset_params)767 Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
768   if (initialized_) {
769     return errors::Internal(
770         "The fields (e.g. dataset_kernel_, dataset_ctx_, dataset_, "
771         "iterator_ctx_, iterator_) have already been initialized.");
772   }
773   TF_RETURN_IF_ERROR(InitializeRuntime(dataset_params));
774   TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, &params_,
775                                  &dataset_ctx_, &tensors_, &dataset_));
776   TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
777   TF_RETURN_IF_ERROR(
778       dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
779                              dataset_params.iterator_prefix(), &iterator_));
780   initialized_ = true;
781   return Status::OK();
782 }
783 
InitializeRuntime(const DatasetParams & dataset_params)784 Status DatasetOpsTestBase::InitializeRuntime(
785     const DatasetParams& dataset_params) {
786   TF_RETURN_IF_ERROR(InitThreadPool(thread_num_));
787   TF_RETURN_IF_ERROR(
788       InitFunctionLibraryRuntime(dataset_params.func_lib(), cpu_num_));
789   return Status::OK();
790 }
791 
MakeDataset(const DatasetParams & dataset_params,std::unique_ptr<TestDataset> * dataset)792 Status DatasetOpsTestBase::MakeDataset(const DatasetParams& dataset_params,
793                                        std::unique_ptr<TestDataset>* dataset) {
794   DatasetBase* dataset_base;
795   std::unique_ptr<OpKernel> dataset_kernel;
796   std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
797   std::unique_ptr<OpKernelContext> dataset_ctx;
798   std::vector<std::unique_ptr<Tensor>> created_tensors;
799   TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel,
800                                  &dataset_ctx_params, &dataset_ctx,
801                                  &created_tensors, &dataset_base));
802   *dataset = std::make_unique<TestDataset>(
803       std::move(dataset_kernel), std::move(dataset_ctx_params),
804       std::move(dataset_ctx), std::move(created_tensors), dataset_base);
805   return Status::OK();
806 }
807 
RunDatasetOp(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * dataset_kernel,std::unique_ptr<OpKernelContext::Params> * dataset_ctx_params,std::vector<std::unique_ptr<Tensor>> * created_tensors,std::unique_ptr<OpKernelContext> * dataset_ctx)808 Status DatasetOpsTestBase::RunDatasetOp(
809     const DatasetParams& dataset_params,
810     std::unique_ptr<OpKernel>* dataset_kernel,
811     std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
812     std::vector<std::unique_ptr<Tensor>>* created_tensors,
813     std::unique_ptr<OpKernelContext>* dataset_ctx) {
814   std::vector<Tensor*> input_datasets;
815   for (auto& input : dataset_params.input_dataset_params()) {
816     std::unique_ptr<Tensor> t;
817     TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
818     input_datasets.push_back(t.get());
819     created_tensors->push_back(std::move(t));
820   }
821   gtl::InlinedVector<TensorValue, 4> inputs;
822   for (auto input_dataset : input_datasets) {
823     inputs.emplace_back(TensorValue(input_dataset));
824   }
825 
826   // Copy the input tensors, storing them in the `inputs` vectors, and storing
827   // owned references to the copies in `created_tensors`.
828   for (auto& input : dataset_params.GetInputTensors()) {
829     auto copy = absl::make_unique<Tensor>(input);
830     inputs.push_back(TensorValue(copy.get()));
831     created_tensors->push_back(std::move(copy));
832   }
833 
834   TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, dataset_kernel));
835   TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel->get(), &inputs,
836                                           dataset_ctx_params, dataset_ctx));
837   TF_RETURN_IF_ERROR(RunOpKernel(dataset_kernel->get(), dataset_ctx->get()));
838   return Status::OK();
839 }
840 
MakeDataset(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * dataset_kernel,std::unique_ptr<OpKernelContext::Params> * dataset_ctx_params,std::unique_ptr<OpKernelContext> * dataset_ctx,std::vector<std::unique_ptr<Tensor>> * created_tensors,DatasetBase ** dataset)841 Status DatasetOpsTestBase::MakeDataset(
842     const DatasetParams& dataset_params,
843     std::unique_ptr<OpKernel>* dataset_kernel,
844     std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
845     std::unique_ptr<OpKernelContext>* dataset_ctx,
846     std::vector<std::unique_ptr<Tensor>>* created_tensors,
847     DatasetBase** dataset) {
848   TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, dataset_kernel,
849                                   dataset_ctx_params, created_tensors,
850                                   dataset_ctx));
851   // Assume that DatasetOp has only one output.
852   DCHECK_EQ((*dataset_ctx)->num_outputs(), 1);
853   TF_RETURN_IF_ERROR(GetDatasetFromContext(dataset_ctx->get(), 0, dataset));
854   return Status::OK();
855 }
856 
MakeIterator(const DatasetParams & dataset_params,const TestDataset & dataset,std::unique_ptr<SplitProvider> split_provider,std::unique_ptr<TestIterator> * iterator)857 Status DatasetOpsTestBase::MakeIterator(
858     const DatasetParams& dataset_params, const TestDataset& dataset,
859     std::unique_ptr<SplitProvider> split_provider,
860     std::unique_ptr<TestIterator>* iterator) {
861   std::unique_ptr<IteratorContext> iterator_ctx;
862   TF_RETURN_IF_ERROR(
863       CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
864   IteratorContext::Params iterator_params(iterator_ctx.get());
865   iterator_params.split_provider = std::move(split_provider);
866   iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
867   std::unique_ptr<IteratorBase> iterator_base;
868   TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
869       iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
870       &iterator_base));
871   *iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
872                                              std::move(iterator_base));
873   return Status::OK();
874 }
875 
MakeIterator(const DatasetParams & dataset_params,const TestDataset & dataset,std::unique_ptr<TestIterator> * iterator)876 Status DatasetOpsTestBase::MakeIterator(
877     const DatasetParams& dataset_params, const TestDataset& dataset,
878     std::unique_ptr<TestIterator>* iterator) {
879   return MakeIterator(dataset_params, dataset, /*split_provider=*/nullptr,
880                       iterator);
881 }
882 
RunDatasetOp(const DatasetParams & dataset_params,std::vector<Tensor> * outputs)883 Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params,
884                                         std::vector<Tensor>* outputs) {
885   TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, &params_,
886                                   &tensors_, &dataset_ctx_));
887   for (int i = 0; i < dataset_ctx_->num_outputs(); ++i) {
888     outputs->emplace_back(*dataset_ctx_->mutable_output(i));
889   }
890   return Status::OK();
891 }
892 
MakeDatasetOpKernel(const DatasetParams & dataset_params,std::unique_ptr<OpKernel> * dataset_kernel)893 Status DatasetOpsTestBase::MakeDatasetOpKernel(
894     const DatasetParams& dataset_params,
895     std::unique_ptr<OpKernel>* dataset_kernel) {
896   name_utils::OpNameParams params;
897   params.op_version = dataset_params.op_version();
898   std::vector<string> input_names;
899   TF_RETURN_IF_ERROR(dataset_params.GetInputNames(&input_names));
900   AttributeVector attributes;
901   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
902   NodeDef node_def = test::function::NDef(
903       dataset_params.node_name(),
904       name_utils::OpName(dataset_params.dataset_type(), params), input_names,
905       attributes);
906   TF_RETURN_IF_ERROR(CreateOpKernel(node_def, dataset_kernel));
907   return Status::OK();
908 }
909 
MakeDatasetTensor(const DatasetParams & dataset_params,std::vector<std::unique_ptr<Tensor>> * created_tensors,std::unique_ptr<Tensor> * dataset)910 Status DatasetOpsTestBase::MakeDatasetTensor(
911     const DatasetParams& dataset_params,
912     std::vector<std::unique_ptr<Tensor>>* created_tensors,
913     std::unique_ptr<Tensor>* dataset) {
914   // Make sure all the input dataset tensors have been populated.
915   std::vector<Tensor*> input_datasets;
916   for (auto& input : dataset_params.input_dataset_params()) {
917     std::unique_ptr<Tensor> t;
918     TF_RETURN_IF_ERROR(MakeDatasetTensor(*input, created_tensors, &t));
919     input_datasets.push_back(t.get());
920     created_tensors->push_back(std::move(t));
921   }
922 
923   AttributeVector attributes;
924   TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
925 
926   gtl::InlinedVector<TensorValue, 4> inputs;
927   for (auto input_dataset : input_datasets) {
928     inputs.emplace_back(TensorValue(input_dataset));
929   }
930   auto input_tensors = dataset_params.GetInputTensors();
931   for (auto& input_tensor : input_tensors) {
932     inputs.emplace_back(TensorValue(&input_tensor));
933   }
934 
935   DatasetBase* dataset_base;
936   std::unique_ptr<OpKernel> dataset_kernel;
937   std::unique_ptr<OpKernelContext::Params> dataset_ctx_params;
938   std::unique_ptr<OpKernelContext> dataset_ctx;
939   TF_RETURN_IF_ERROR(MakeDatasetOpKernel(dataset_params, &dataset_kernel));
940   TF_RETURN_IF_ERROR(CreateDatasetContext(dataset_kernel.get(), &inputs,
941                                           &dataset_ctx_params, &dataset_ctx));
942   TF_RETURN_IF_ERROR(
943       CreateDataset(dataset_kernel.get(), dataset_ctx.get(), &dataset_base));
944   Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
945   TF_RETURN_IF_ERROR(
946       StoreDatasetInVariantTensor(dataset_base, &dataset_tensor));
947   *dataset = absl::make_unique<Tensor>(dataset_tensor);
948   return Status::OK();
949 }
950 
// Stores the output dtypes, output shapes and node name of a dataset. All
// three arguments are taken by value and moved into the members.
DatasetParams::DatasetParams(DataTypeVector output_dtypes,
                             std::vector<PartialTensorShape> output_shapes,
                             string node_name)
    : output_dtypes_(std::move(output_dtypes)),
      output_shapes_(std::move(output_shapes)),
      node_name_(std::move(node_name)) {}
957 
IsDatasetTensor(const Tensor & tensor)958 bool DatasetParams::IsDatasetTensor(const Tensor& tensor) {
959   return tensor.dtype() == DT_VARIANT &&
960          TensorShapeUtils::IsScalar(tensor.shape());
961 }
962 
// Fully specified constructor: forwards the output dtypes, output shapes and
// node name to DatasetParams and records the `start`, `stop` and `step`
// values.
RangeDatasetParams::RangeDatasetParams(
    int64 start, int64 stop, int64 step, DataTypeVector output_dtypes,
    std::vector<PartialTensorShape> output_shapes, string node_name)
    : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
                    std::move(node_name)),
      start_(start),
      stop_(stop),
      step_(step) {}
971 
RangeDatasetParams(int64 start,int64 stop,int64 step)972 RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step)
973     : DatasetParams({DT_INT64}, {PartialTensorShape({})}, "range_dataset"),
974       start_(start),
975       stop_(stop),
976       step_(step) {}
977 
RangeDatasetParams(int64 start,int64 stop,int64 step,DataTypeVector output_dtypes)978 RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step,
979                                        DataTypeVector output_dtypes)
980     : DatasetParams(std::move(output_dtypes), {PartialTensorShape({})},
981                     "range_dataset"),
982       start_(start),
983       stop_(stop),
984       step_(step) {}
985 
GetInputTensors() const986 std::vector<Tensor> RangeDatasetParams::GetInputTensors() const {
987   Tensor start_tensor = CreateTensor<int64>(TensorShape({}), {start_});
988   Tensor stop_tensor = CreateTensor<int64>(TensorShape({}), {stop_});
989   Tensor step_tensor = CreateTensor<int64>(TensorShape({}), {step_});
990   return {start_tensor, stop_tensor, step_tensor};
991 }
992 
GetInputNames(std::vector<string> * input_names) const993 Status RangeDatasetParams::GetInputNames(
994     std::vector<string>* input_names) const {
995   *input_names = {RangeDatasetOp::kStart, RangeDatasetOp::kStop,
996                   RangeDatasetOp::kStep};
997   return Status::OK();
998 }
999 
GetAttributes(AttributeVector * attr_vector) const1000 Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1001   *attr_vector = {{RangeDatasetOp::kOutputTypes, output_dtypes_},
1002                   {RangeDatasetOp::kOutputShapes, output_shapes_}};
1003   return Status::OK();
1004 }
1005 
dataset_type() const1006 string RangeDatasetParams::dataset_type() const {
1007   return RangeDatasetOp::kDatasetType;
1008 }
1009 
GetInputTensors() const1010 std::vector<Tensor> BatchDatasetParams::GetInputTensors() const {
1011   Tensor batch_size = CreateTensor<int64>(TensorShape({}), {batch_size_});
1012   Tensor drop_remainder =
1013       CreateTensor<bool>(TensorShape({}), {drop_remainder_});
1014   return {batch_size, drop_remainder};
1015 }
1016 
GetInputNames(std::vector<string> * input_names) const1017 Status BatchDatasetParams::GetInputNames(
1018     std::vector<string>* input_names) const {
1019   *input_names = {BatchDatasetOp::kInputDataset, BatchDatasetOp::kBatchSize,
1020                   BatchDatasetOp::kDropRemainder};
1021   return Status::OK();
1022 }
1023 
GetAttributes(AttributeVector * attr_vector) const1024 Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1025   *attr_vector = {{BatchDatasetOp::kParallelCopy, parallel_copy_},
1026                   {BatchDatasetOp::kOutputTypes, output_dtypes_},
1027                   {BatchDatasetOp::kOutputShapes, output_shapes_}};
1028   return Status::OK();
1029 }
1030 
dataset_type() const1031 string BatchDatasetParams::dataset_type() const {
1032   return BatchDatasetOp::kDatasetType;
1033 }
1034 
GetInputTensors() const1035 std::vector<Tensor> MapDatasetParams::GetInputTensors() const {
1036   return other_arguments_;
1037 }
1038 
GetInputNames(std::vector<string> * input_names) const1039 Status MapDatasetParams::GetInputNames(std::vector<string>* input_names) const {
1040   input_names->emplace_back(MapDatasetOp::kInputDataset);
1041   for (int i = 0; i < other_arguments_.size(); ++i) {
1042     input_names->emplace_back(
1043         absl::StrCat(MapDatasetOp::kOtherArguments, "_", i));
1044   }
1045   return Status::OK();
1046 }
1047 
GetAttributes(AttributeVector * attr_vector) const1048 Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1049   *attr_vector = {
1050       {MapDatasetOp::kFunc, func_},
1051       {MapDatasetOp::kTarguments, type_arguments_},
1052       {MapDatasetOp::kOutputShapes, output_shapes_},
1053       {MapDatasetOp::kOutputTypes, output_dtypes_},
1054       {MapDatasetOp::kUseInterOpParallelism, use_inter_op_parallelism_},
1055       {MapDatasetOp::kPreserveCardinality, preserve_cardinality_}};
1056   return Status::OK();
1057 }
1058 
dataset_type() const1059 string MapDatasetParams::dataset_type() const {
1060   return MapDatasetOp::kDatasetType;
1061 }
1062 
func_lib() const1063 std::vector<FunctionDef> MapDatasetParams::func_lib() const {
1064   return func_lib_;
1065 }
1066 
// Builds the params from the component tensors. The output dtypes and shapes
// passed to DatasetParams are derived from `components` via TensorSliceDtypes
// and TensorSliceShapes.
TensorSliceDatasetParams::TensorSliceDatasetParams(
    std::vector<Tensor> components, string node_name)
    // The base class is initialized before the members, so `components` is
    // read here before it is moved into `components_` below.
    : DatasetParams(TensorSliceDtypes(components),
                    TensorSliceShapes(components), std::move(node_name)),
      components_(std::move(components)) {}
1072 
GetInputTensors() const1073 std::vector<Tensor> TensorSliceDatasetParams::GetInputTensors() const {
1074   return components_;
1075 }
1076 
GetInputNames(std::vector<string> * input_names) const1077 Status TensorSliceDatasetParams::GetInputNames(
1078     std::vector<string>* input_names) const {
1079   input_names->reserve(components_.size());
1080   for (int i = 0; i < components_.size(); ++i) {
1081     input_names->emplace_back(
1082         absl::StrCat(TensorSliceDatasetOp::kComponents, "_", i));
1083   }
1084   return Status::OK();
1085 }
1086 
GetAttributes(AttributeVector * attr_vector) const1087 Status TensorSliceDatasetParams::GetAttributes(
1088     AttributeVector* attr_vector) const {
1089   *attr_vector = {{TensorSliceDatasetOp::kToutputTypes, output_dtypes_},
1090                   {TensorSliceDatasetOp::kOutputShapes, output_shapes_}};
1091   return Status::OK();
1092 }
1093 
TensorSliceDtypes(const std::vector<Tensor> & input_components)1094 DataTypeVector TensorSliceDatasetParams::TensorSliceDtypes(
1095     const std::vector<Tensor>& input_components) {
1096   DataTypeVector dtypes;
1097   for (const auto& component : input_components) {
1098     dtypes.emplace_back(component.dtype());
1099   }
1100   return dtypes;
1101 }
1102 
TensorSliceShapes(const std::vector<Tensor> & input_components)1103 std::vector<PartialTensorShape> TensorSliceDatasetParams::TensorSliceShapes(
1104     const std::vector<Tensor>& input_components) {
1105   std::vector<PartialTensorShape> shapes;
1106   for (const auto& component : input_components) {
1107     gtl::InlinedVector<int64, 4> partial_dim_sizes;
1108     for (int i = 1; i < component.dims(); ++i) {
1109       partial_dim_sizes.push_back(component.dim_size(i));
1110     }
1111     shapes.emplace_back(std::move(partial_dim_sizes));
1112   }
1113   return shapes;
1114 }
1115 
dataset_type() const1116 string TensorSliceDatasetParams::dataset_type() const {
1117   return TensorSliceDatasetOp::kDatasetType;
1118 }
1119 
GetInputTensors() const1120 std::vector<Tensor> TakeDatasetParams::GetInputTensors() const {
1121   return {CreateTensor<int64>(TensorShape({}), {count_})};
1122 }
1123 
GetInputNames(std::vector<string> * input_names) const1124 Status TakeDatasetParams::GetInputNames(
1125     std::vector<string>* input_names) const {
1126   *input_names = {TakeDatasetOp::kInputDataset, TakeDatasetOp::kCount};
1127   return Status::OK();
1128 }
1129 
GetAttributes(AttributeVector * attr_vector) const1130 Status TakeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
1131   *attr_vector = {{TakeDatasetOp::kOutputShapes, output_shapes_},
1132                   {TakeDatasetOp::kOutputTypes, output_dtypes_}};
1133   return Status::OK();
1134 }
1135 
dataset_type() const1136 string TakeDatasetParams::dataset_type() const {
1137   return TakeDatasetOp::kDatasetType;
1138 }
1139 
GetInputTensors() const1140 std::vector<Tensor> ConcatenateDatasetParams::GetInputTensors() const {
1141   return {};
1142 }
1143 
GetInputNames(std::vector<string> * input_names) const1144 Status ConcatenateDatasetParams::GetInputNames(
1145     std::vector<string>* input_names) const {
1146   *input_names = {ConcatenateDatasetOp::kInputDataset,
1147                   ConcatenateDatasetOp::kAnotherDataset};
1148   return Status::OK();
1149 }
1150 
GetAttributes(AttributeVector * attr_vector) const1151 Status ConcatenateDatasetParams::GetAttributes(
1152     AttributeVector* attr_vector) const {
1153   *attr_vector = {{ConcatenateDatasetOp::kOutputTypes, output_dtypes_},
1154                   {ConcatenateDatasetOp::kOutputShapes, output_shapes_}};
1155   return Status::OK();
1156 }
1157 
dataset_type() const1158 string ConcatenateDatasetParams::dataset_type() const {
1159   return ConcatenateDatasetOp::kDatasetType;
1160 }
1161 
1162 }  // namespace data
1163 }  // namespace tensorflow
1164