1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/standalone.h"
17 
18 #include <memory>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/core/common_runtime/device_factory.h"
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/common_runtime/graph_constructor.h"
25 #include "tensorflow/core/common_runtime/graph_runner.h"
26 #include "tensorflow/core/common_runtime/process_util.h"
27 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
28 #include "tensorflow/core/framework/dataset.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/public/version.h"
32 #include "tensorflow/core/util/ptr_util.h"
33 
34 namespace tensorflow {
35 namespace data {
36 namespace standalone {
37 
GetNext(std::vector<Tensor> * outputs,bool * end_of_input)38 Status Iterator::GetNext(std::vector<Tensor>* outputs, bool* end_of_input) {
39   return iterator_->GetNext(ctx_.get(), outputs, end_of_input);
40 }
41 
// Wraps `iterator` and `ctx`; both pointers are adopted into the members
// `iterator_` and `ctx_` (presumably smart pointers taking ownership, given
// the `ctx_.get()` usage elsewhere — confirm against standalone.h).
Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx)
    : iterator_(iterator), ctx_(ctx) {}
44 
FromGraph(Params params,const GraphDef & graph_def,std::unique_ptr<Dataset> * result)45 Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
46                           std::unique_ptr<Dataset>* result) {
47   Graph graph(OpRegistry::Global());
48   TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
49 
50   // Instantiate enough of the TF runtime to run `graph` on a single CPU device.
51   auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
52       "CPU", params.session_options, "/job:localhost/replica:0/task:0"));
53   Device* device = device_mgr->ListDevices()[0];
54   // Clone the `FunctionLibraryDefinition` to extend its lifetime extends beyond
55   // the lifetime of `graph`.
56   auto flib_def =
57       absl::make_unique<FunctionLibraryDefinition>(graph.flib_def());
58   auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
59       device_mgr.get(), Env::Default(), /*config=*/nullptr,
60       TF_GRAPH_DEF_VERSION, flib_def.get(), OptimizerOptions{},
61       /*thread_pool=*/nullptr, /*parent=*/nullptr,
62       /*session_metadata=*/nullptr,
63       Rendezvous::Factory{
64           [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
65             *r = new IntraProcessRendezvous(device_mgr);
66             return Status::OK();
67           }});
68 
69   string fetch_node = "";
70   for (const auto& node : graph_def.node()) {
71     if (node.op() == "_Retval") {
72       fetch_node = node.input(0);
73     }
74   }
75   if (fetch_node.empty()) {
76     return errors::NotFound("Failed to find a _Retval op in the given dataset");
77   }
78 
79   // Run graph up to `output_node` and extract the `DatasetBase` stored in the
80   // DT_VARIANT output tensor.
81   data::DatasetBase* dataset;
82   {
83     std::vector<Tensor> outputs;
84     GraphRunner graph_runner(device);
85     TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"),
86                                         {}, {fetch_node}, &outputs));
87     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
88     // NOTE(mrry): The dataset is currently owned by `outputs[0]`, so acquire an
89     // additional reference.
90     dataset->Ref();
91   }
92 
93   std::unique_ptr<thread::ThreadPool> pool(
94       NewThreadPoolFromSessionOptions(params.session_options));
95   *result =
96       WrapUnique(new Dataset(dataset, device_mgr.release(), pflr.release(),
97                              flib_def.release(), pool.release()));
98   return Status::OK();
99 }  // static
100 
MakeIterator(std::unique_ptr<SplitProvider> split_provider,std::unique_ptr<Iterator> * result)101 Status Dataset::MakeIterator(std::unique_ptr<SplitProvider> split_provider,
102                              std::unique_ptr<Iterator>* result) {
103   // Create an `IteratorContext`, which bundles together the necessary runtime
104   // support to create and get elements from an iterator.
105   std::unique_ptr<IteratorContext> ctx;
106   {
107     // NOTE(mrry): In the current API, an `IteratorContext` is always initially
108     // created from an `OpKernelContext*`, so we need to create a fake
109     // `OpKernelContext` with the appropriate subset of parameters.
110     OpKernelContext::Params op_params;
111     op_params.function_library = pflr_->GetFLR("/device:CPU:0");
112     op_params.device = device_mgr_->ListDevices()[0];
113     op_params.runner = &runner_;
114     OpKernelContext op_ctx(&op_params, 0);
115     IteratorContext::Params params(&op_ctx);
116     params.function_handle_cache = function_handle_cache_.get();
117     params.resource_mgr = &resource_mgr_;
118     params.cancellation_manager = &cancellation_manager_;
119     params.split_provider = std::move(split_provider);
120 
121     ctx = absl::make_unique<IteratorContext>(std::move(params));
122   }
123 
124   // Create the iterator from the dataset.
125   std::unique_ptr<IteratorBase> iterator;
126   TF_RETURN_IF_ERROR(dataset_->MakeIterator(ctx.get(), /*parent=*/nullptr,
127                                             "Iterator", &iterator));
128 
129   *result = WrapUnique(new Iterator(iterator.release(), ctx.release()));
130 
131   return Status::OK();
132 }
133 
MakeIterator(std::unique_ptr<Iterator> * result)134 Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
135   return MakeIterator(/*split_provider=*/nullptr, result);
136 }
137 
MakeSplitProvider(std::unique_ptr<SplitProvider> * result)138 Status Dataset::MakeSplitProvider(std::unique_ptr<SplitProvider>* result) {
139   return dataset_->MakeSplitProvider(result);
140 }
141 
Get() const142 const DatasetBase* Dataset::Get() const { return dataset_; }
143 
Dataset(DatasetBase * dataset,DeviceMgr * device_mgr,ProcessFunctionLibraryRuntime * pflr,FunctionLibraryDefinition * flib_def,thread::ThreadPool * pool)144 Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
145                  ProcessFunctionLibraryRuntime* pflr,
146                  FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool)
147     : dataset_(dataset),
148       device_mgr_(device_mgr),
149       flib_def_(flib_def),
150       pflr_(pflr),
151       pool_(pool) {
152   runner_ = [this](std::function<void()> c) { pool_->Schedule(std::move(c)); };
153   function_handle_cache_ =
154       absl::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
155 }
156 
~Dataset()157 Dataset::~Dataset() { dataset_->Unref(); }
158 
159 }  // namespace standalone
160 }  // namespace data
161 }  // namespace tensorflow
162