/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"

#include <map>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
    const OpDef& sig, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    const FunctionLibraryDefinition& flib_def, GraphDef* gdef,
    std::vector<string>* send_keys, std::vector<string>* recv_keys) {
  const string& target = options.target;
  const string& func_name = sig.name();
  const FunctionDef* func_def = flib_def.Find(sig.name());
  if (func_def == nullptr) {
    return errors::InvalidArgument("Function ", func_name,
                                   " not found in flib_def.");
  }

  // Build a smaller flib_def containing only the functions used by the given
  // function, plus that function itself.
  FunctionLibraryDefinition pruned_flib_def =
      flib_def.ReachableDefinitions(*func_def);
  TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def));

  Graph g(pruned_flib_def);

  std::vector<Node*> input_nodes;
  input_nodes.reserve(sig.input_arg_size());

  // Construct recv nodes for each input argument.
  int i = 0;
  for (const auto& in : sig.input_arg()) {
    // Resolve the input type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Input arg: ", in.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto input_node_builder =
        NodeDefBuilder(strings::StrCat("_recv_", in.name(), "_", i), "_Recv")
            .Attr("tensor_type", dtypes[0])
            .Attr("tensor_name", in.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* input_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(input_node_builder).Finalize(&g, &input_node));
    input_nodes.push_back(input_node);

    // src_incarnation = 1 works because the transfer is across the same
    // device.
    // TODO(rohanj): Find the src_incarnation for the remote device and set it.
    const string& key = Rendezvous::CreateKey(
        target, 1 /* src_incarnation */, target, in.name(), FrameAndIter(0, 0));
    send_keys->push_back(key);
    ++i;
  }

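  // Add a node that calls the function itself. It is placed on the target
  // device and carries the instantiation attrs, so the inliner below can
  // expand its body in place.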
  NodeDef function_node_def;
  function_node_def.set_name(func_name);
  function_node_def.set_op(func_name);
  i = 0;
  function_node_def.set_device(target);
  for (const auto& p : attrs) {
    (*function_node_def.mutable_attr())[p.first] = p.second;
  }
  Status status;
  Node* function_node = g.AddNode(std::move(function_node_def), &status);
  TF_RETURN_IF_ERROR(status);
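  // Wire each _Recv node into the corresponding input of the function node.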
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    g.AddEdge(input_nodes[i], 0, function_node, i);
  }

  // Construct output nodes for each output.
  i = 0;
  for (const auto& out : sig.output_arg()) {
    // Resolve the output type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Output arg: ", out.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto output_node_builder =
        NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
            .Input(func_name, i, dtypes[0])
            .Attr("tensor_name", out.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* output_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(output_node_builder).Finalize(&g, &output_node));

    g.AddEdge(function_node, i, output_node, 0);

    const string& key =
        Rendezvous::CreateKey(target, 1 /* src_incarnation */, target,
                              out.name(), FrameAndIter(0, 0));
    recv_keys->push_back(key);
    ++i;
  }

  // Inline the function node into the graph.
  InlineFunctionBodyOptions inline_options;
  inline_options.inlined_function_body_placer =
      InlinedFunctionBodyPlacer::SingleDevice();
  // When the remote call is a partition of a multi-device function, and the
  // Send/Recv nodes depend on the frame names in the original graph, we must
  // retain the original frame names. Since the graph contains a single
  // function call, we do not need to add a unique prefix to frame names
  // inside the inlined graph.
  inline_options.uniquify_frame_names = false;
  std::unique_ptr<FunctionBody> function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def,
                                             &function_body));
  TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node,
                                        function_body.get(), inline_options));

  g.ToGraphDef(gdef);

  // Since we have inlined `function_node`, we can prune its function
  // definition from the library.
  *(gdef->mutable_library()) = flib_def.ReachableDefinitions(*gdef).ToProto();

  return Status::OK();
}

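// The destructor returns every worker borrowed from the worker cache during
// Instantiate() back to the cache.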
ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
  for (auto& function_data : function_data_) {
    worker_session_->worker_cache()->ReleaseWorker(function_data.target,
                                                   function_data.wi);
  }
}

void ClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle,
    FunctionLibraryRuntime::DoneCallback done) {
  auto target = options.target;
  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
          << " (this: " << this << ")";
  std::shared_ptr<WorkerCacheInterface> worker_cache =
      worker_session_->GetSharedWorkerCache();
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(target);

  if (wi == nullptr) {
    std::vector<string> workers;
    worker_session_->worker_cache()->ListWorkers(&workers);
    done(errors::InvalidArgument(
        "Could not find worker with target: ", target,
        " Available workers: ", absl::StrJoin(workers, ", ")));
    return;
  }

  // Make an RPC and obtain a graph handle.
  GraphDef gdef;
  auto* send_keys = new std::vector<string>;
  auto* recv_keys = new std::vector<string>;
  auto construct_graph_fn = [&](const FunctionLibraryDefinition* lib_def) {
    const FunctionDef* fdef = lib_def->Find(function_name);
    const OpDef& sig = fdef->signature();
    TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, *lib_def,
                                              &gdef, send_keys, recv_keys));
    return Status::OK();
  };
  Status s;
  if (options.lib_def) {
    s = construct_graph_fn(options.lib_def);
  } else {
    s = construct_graph_fn(&lib_def);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

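  // Register the constructed graph with the remote worker. The graph handle
  // returned in the response is what Run() and CleanUp() later refer to.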
  auto* req = new RegisterGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  *req->mutable_graph_def() = std::move(gdef);
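  // Drop attrs that still hold their default values so the serialized graph
  // stays compatible with workers whose op registry predates those attrs.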
  StripDefaultAttributes(*OpRegistry::Global(),
                         req->mutable_graph_def()->mutable_node());
  req->mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  auto* resp = new RegisterGraphResponse;

  wi->RegisterGraphAsync(
      req, resp,
      [this, handle, req, resp, worker_cache, wi, function_name, target,
       send_keys, recv_keys, done](const Status& status) {
        if (status.ok()) {
          mutex_lock l(mu_);
          *handle = function_data_.size();
          function_data_.push_back(FunctionData(resp->graph_handle(), target,
                                                worker_cache, wi, *send_keys,
                                                *recv_keys));
          VLOG(1) << "CFLR::Instantiate: [Success] " << function_name << " on "
                  << target << " (this: " << this << ")"
                  << " with handle: " << *handle;
        }
        done(status);
        delete recv_keys;
        delete send_keys;
        delete req;
        delete resp;
      });
}

void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    CHECK_LE(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }

  RunGraphRequest* req = new RunGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  req->set_graph_handle(function_data->graph_handle);
  req->set_step_id(opts.step_id);
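  // Feed each argument tensor under the rendezvous send key recorded at
  // instantiation time, and request every recv key back.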
  int i = 0;
  for (const auto& send_key : function_data->send_keys) {
    NamedTensorProto* send = req->add_send();
    send->set_name(send_key);
    args[i].AsProtoTensorContent(send->mutable_tensor());
    i++;
  }
  const std::vector<string>& recv_keys = function_data->recv_keys;
  for (const auto& recv_key : recv_keys) {
    req->add_recv_key(recv_key);
  }

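  // The callback below takes ownership of req, resp, and call_options; the
  // gtl::MakeCleanup object guarantees done() runs and everything is freed on
  // every exit path, including early returns on error.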
  RunGraphResponse* resp = new RunGraphResponse();
  CallOptions* call_options = new CallOptions();
  wi->RunGraphAsync(
      call_options, req, resp,
      [call_options, req, resp, rets, recv_keys, done](const Status& status) {
        Status* local_status = new Status(status);
        auto cleanup =
            gtl::MakeCleanup([call_options, req, resp, local_status, done] {
              done(*local_status);
              delete call_options;
              delete req;
              delete resp;
              delete local_status;
            });
        if (!local_status->ok()) {
          return;
        }
        std::map<string, TensorProto*> mapped_recvs;
        for (auto& recv : *resp->mutable_recv()) {
          mapped_recvs[recv.name()] = recv.mutable_tensor();
        }

        for (const auto& recv_key : recv_keys) {
          TensorProto* tp = mapped_recvs[recv_key];
          if (tp == nullptr) {
            local_status->Update(
                errors::Internal("Could not find key: ", recv_key));
            return;
          }
          Tensor t;
          if (t.FromProto(*tp)) {
            rets->push_back(t);
          } else {
            local_status->Update(errors::Internal(
                "Could not convert tensor proto: ", tp->DebugString()));
            return;
          }
        }
      });
}

void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle,
    gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
    FunctionLibraryRuntime::DoneCallback done) {
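  // FunctionArg is a variant: index 0 holds a plain Tensor, while any other
  // index carries an eager::RemoteTensorHandle, which this runtime cannot
  // forward, so reject it up front.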
  std::vector<Tensor> tensors;
  for (const auto& arg : args) {
    if (arg.index() == 0) {
      tensors.push_back(absl::get<Tensor>(arg));
    } else {
      done(
          errors::Internal("ClusterFunctionLibraryRuntime doesn't support "
                           "eager::RemoteTensorHandle."));
      return;
    }
  }
  std::vector<Tensor>* ret_tensors = new std::vector<Tensor>;
  return Run(opts, handle, tensors, ret_tensors,
             [rets, ret_tensors, done = std::move(done)](const Status& s) {
               if (s.ok()) {
                 for (const auto& t : *ret_tensors) {
                   rets->push_back(t);
                 }
               }
               delete ret_tensors;
               done(s);
             });
}

void ClusterFunctionLibraryRuntime::CleanUp(
    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
    FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    DCHECK_LE(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }
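  // Ask the worker to release any per-step state (e.g. rendezvous entries)
  // associated with this step id.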
  CleanupGraphRequest* cleanup_req = new CleanupGraphRequest;
  cleanup_req->set_step_id(step_id);
  CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse;
  wi->CleanupGraphAsync(
      cleanup_req, cleanup_resp,
      [cleanup_req, cleanup_resp, done](const Status& cleanup_status) {
        done(cleanup_status);
        delete cleanup_req;
        delete cleanup_resp;
      });
}

}  // namespace tensorflow