/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/standalone.h"

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace data {
namespace standalone {

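// Example usage (a minimal sketch, not part of the library itself;
// `graph_def` is assumed to hold a serialized tf.data dataset whose output is
// marked by a `_Retval` op, as `Dataset::FromGraph` expects below):
//
//   std::unique_ptr<Dataset> dataset;
//   TF_CHECK_OK(Dataset::FromGraph(Params(), graph_def, &dataset));
//   std::unique_ptr<Iterator> iterator;
//   TF_CHECK_OK(dataset->MakeIterator(&iterator));
//   bool end_of_input = false;
//   while (!end_of_input) {
//     std::vector<Tensor> outputs;
//     TF_CHECK_OK(iterator->GetNext(&outputs, &end_of_input));
//     if (!end_of_input) { /* consume `outputs` */ }
//   }
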
Status Iterator::GetNext(std::vector<Tensor>* outputs, bool* end_of_input) {
  return iterator_->GetNext(ctx_.get(), outputs, end_of_input);
}

Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx)
    : iterator_(iterator), ctx_(ctx) {}

Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
                          std::unique_ptr<Dataset>* result) {
  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));

  // Instantiate enough of the TF runtime to run `graph` on a single CPU
  // device.
  auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
      "CPU", params.session_options, "/job:localhost/replica:0/task:0"));
  Device* device = device_mgr->ListDevices()[0];
  // Clone the `FunctionLibraryDefinition` so that its lifetime extends beyond
  // the lifetime of `graph`.
  auto flib_def =
      absl::make_unique<FunctionLibraryDefinition>(graph.flib_def());
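  // The `ProcessFunctionLibraryRuntime` provides the per-device function
  // runtime needed to instantiate and run functions from the dataset's
  // function library; everything executes in-process, so an
  // `IntraProcessRendezvous` suffices.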
  auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, flib_def.get(), OptimizerOptions{},
      /*thread_pool=*/nullptr, /*parent=*/nullptr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return Status::OK();
          }});

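  // The graph's `_Retval` op consumes the tensor that produces the
  // DT_VARIANT-wrapped dataset; its input names the node to fetch.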
  string fetch_node = "";
  for (const auto& node : graph_def.node()) {
    if (node.op() == "_Retval") {
      fetch_node = node.input(0);
    }
  }
  if (fetch_node.empty()) {
    return errors::NotFound("Failed to find a _Retval op in the given dataset");
  }

  // Run the graph up to `fetch_node` and extract the `DatasetBase` stored in
  // the DT_VARIANT output tensor.
  data::DatasetBase* dataset;
  {
    std::vector<Tensor> outputs;
    GraphRunner graph_runner(device);
    TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"),
                                        {}, {fetch_node}, &outputs));
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
    // NOTE(mrry): The dataset is currently owned by `outputs[0]`, so acquire
    // an additional reference.
    dataset->Ref();
  }

  std::unique_ptr<thread::ThreadPool> pool(
      NewThreadPoolFromSessionOptions(params.session_options));
  *result =
      WrapUnique(new Dataset(dataset, device_mgr.release(), pflr.release(),
                             flib_def.release(), pool.release()));
  return Status::OK();
}  // static

Status Dataset::MakeIterator(std::unique_ptr<SplitProvider> split_provider,
                             std::unique_ptr<Iterator>* result) {
  // Create an `IteratorContext`, which bundles together the necessary runtime
  // support to create and get elements from an iterator.
  std::unique_ptr<IteratorContext> ctx;
  {
    // NOTE(mrry): In the current API, an `IteratorContext` is always initially
    // created from an `OpKernelContext*`, so we need to create a fake
    // `OpKernelContext` with the appropriate subset of parameters.
    OpKernelContext::Params op_params;
    op_params.function_library = pflr_->GetFLR("/device:CPU:0");
    op_params.device = device_mgr_->ListDevices()[0];
    op_params.runner = &runner_;
    OpKernelContext op_ctx(&op_params, 0);
    IteratorContext::Params params(&op_ctx);
    params.function_handle_cache = function_handle_cache_.get();
    params.resource_mgr = &resource_mgr_;
    params.cancellation_manager = &cancellation_manager_;
    params.split_provider = std::move(split_provider);

    ctx = absl::make_unique<IteratorContext>(std::move(params));
  }

  // Create the iterator from the dataset.
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(dataset_->MakeIterator(ctx.get(), /*parent=*/nullptr,
                                            "Iterator", &iterator));

  *result = WrapUnique(new Iterator(iterator.release(), ctx.release()));

  return Status::OK();
}

Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
  return MakeIterator(/*split_provider=*/nullptr, result);
}

Status Dataset::MakeSplitProvider(std::unique_ptr<SplitProvider>* result) {
  return dataset_->MakeSplitProvider(result);
}

const DatasetBase* Dataset::Get() const { return dataset_; }

Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
                 ProcessFunctionLibraryRuntime* pflr,
                 FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool)
    : dataset_(dataset),
      device_mgr_(device_mgr),
      flib_def_(flib_def),
      pflr_(pflr),
      pool_(pool) {
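  // Run closures scheduled by iterators on the dataset's own thread pool.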
  runner_ = [this](std::function<void()> c) { pool_->Schedule(std::move(c)); };
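  // Share instantiated function handles across all iterators created from
  // this dataset.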
  function_handle_cache_ =
      absl::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
}

Dataset::~Dataset() { dataset_->Unref(); }

}  // namespace standalone
}  // namespace data
}  // namespace tensorflow