/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"

#include <map>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
    const OpDef& sig, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    const FunctionLibraryDefinition& flib_def, GraphDef* gdef,
    std::vector<string>* send_keys, std::vector<string>* recv_keys) {
  const string& target = options.target;
  const string& func_name = sig.name();
  const FunctionDef* func_def = flib_def.Find(sig.name());
  if (func_def == nullptr) {
    return errors::InvalidArgument("Function ", func_name,
                                   " not found in flib_def.");
  }

  // Build a smaller flib_def containing only the functions used by the given
  // function, plus that function itself.
  FunctionLibraryDefinition pruned_flib_def =
      flib_def.ReachableDefinitions(*func_def);
  TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def));

  Graph g(pruned_flib_def);

  std::vector<Node*> input_nodes;
  input_nodes.reserve(sig.input_arg_size());

  // Construct recv nodes for each input argument.
  int i = 0;
  for (const auto& in : sig.input_arg()) {
    // Resolve the input type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Input arg: ", in.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto input_node_builder =
        NodeDefBuilder(strings::StrCat("_recv_", in.name(), "_", i), "_Recv")
            .Attr("tensor_type", dtypes[0])
            .Attr("tensor_name", in.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* input_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(input_node_builder).Finalize(&g, &input_node));
    input_nodes.push_back(input_node);

    // src_incarnation = 1 works because the transfer is across the same device.
    // TODO(rohanj): Find the src_incarnation for the remote device and set it.
    const string& key = Rendezvous::CreateKey(
        target, 1 /* src_incarnation */, target, in.name(), FrameAndIter(0, 0));
    send_keys->push_back(key);
    ++i;
  }

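  // Add a single node that invokes the function itself; the _Recv nodes above
  // feed its inputs, and the _Send nodes constructed below return its outputs.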
  NodeDef function_node_def;
  function_node_def.set_name(func_name);
  function_node_def.set_op(func_name);
  i = 0;
  function_node_def.set_device(target);
  for (const auto& p : attrs) {
    (*function_node_def.mutable_attr())[p.first] = p.second;
  }
  Status status;
  Node* function_node = g.AddNode(std::move(function_node_def), &status);
  TF_RETURN_IF_ERROR(status);
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    g.AddEdge(input_nodes[i], 0, function_node, i);
  }

  // Construct output nodes for each output.
  i = 0;
  for (const auto& out : sig.output_arg()) {
    // Resolve the output type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Output arg: ", out.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto output_node_builder =
        NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
            .Input(func_name, i, dtypes[0])
            .Attr("tensor_name", out.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* output_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(output_node_builder).Finalize(&g, &output_node));

    g.AddEdge(function_node, i, output_node, 0);

    const string& key =
        Rendezvous::CreateKey(target, 1 /* src_incarnation */, target,
                              out.name(), FrameAndIter(0, 0));
    recv_keys->push_back(key);
    ++i;
  }

  // Inline function node into the graph.
  InlineFunctionBodyOptions inline_options;
  inline_options.inlined_function_body_placer =
      InlinedFunctionBodyPlacer::SingleDevice();
  // When the remote call is a partition of a multi-device function, and the
  // Send/Recv nodes depend on the frame names in the original graph, we must
  // retain the original frame names. Since the graph contains a single function
  // call, we do not need to add a unique prefix to frame names inside the
  // inlined graph.
  inline_options.uniquify_frame_names = false;
  std::unique_ptr<FunctionBody> function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def,
                                             &function_body));
  TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node,
                                        function_body.get(), inline_options));

  g.ToGraphDef(gdef);

  // Since we have inlined `function_node`, we can prune its function definition
  // from the library.
  *(gdef->mutable_library()) = flib_def.ReachableDefinitions(*gdef).ToProto();

  return Status::OK();
}

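// Release the per-function worker references acquired in Instantiate().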
ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
  for (auto& function_data : function_data_) {
    worker_session_->worker_cache()->ReleaseWorker(function_data.target,
                                                   function_data.wi);
  }
}

void ClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle,
    FunctionLibraryRuntime::DoneCallback done) {
  auto target = options.target;
  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
          << " (this: " << this << ")";
  std::shared_ptr<WorkerCacheInterface> worker_cache =
      worker_session_->GetSharedWorkerCache();
  WorkerInterface* wi = worker_cache->GetOrCreateWorker(target);

  if (wi == nullptr) {
    std::vector<string> workers;
    worker_session_->worker_cache()->ListWorkers(&workers);
    done(errors::InvalidArgument(
        "Could not find worker with target: ", target,
        " Available workers: ", absl::StrJoin(workers, ", ")));
    return;
  }

  // Make RPC and obtain a graph handle.
  GraphDef gdef;
  auto* send_keys = new std::vector<string>;
  auto* recv_keys = new std::vector<string>;
  auto construct_graph_fn = [&](const FunctionLibraryDefinition* lib_def) {
    const FunctionDef* fdef = lib_def->Find(function_name);
    const OpDef& sig = fdef->signature();
    TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, *lib_def,
                                              &gdef, send_keys, recv_keys));
    return Status::OK();
  };
  Status s;
  if (options.lib_def) {
    s = construct_graph_fn(options.lib_def);
  } else {
    s = construct_graph_fn(&lib_def);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

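  // Register the constructed graph with the remote worker. The request and
  // response are heap-allocated so they stay alive until the asynchronous
  // callback below deletes them.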
  auto* req = new RegisterGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  *req->mutable_graph_def() = std::move(gdef);
  StripDefaultAttributes(*OpRegistry::Global(),
                         req->mutable_graph_def()->mutable_node());
  req->mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  auto* resp = new RegisterGraphResponse;

  wi->RegisterGraphAsync(
      req, resp,
      [this, handle, req, resp, worker_cache, wi, function_name, target,
       send_keys, recv_keys, done](const Status& status) {
        if (status.ok()) {
          mutex_lock l(mu_);
          *handle = function_data_.size();
          function_data_.push_back(FunctionData(resp->graph_handle(), target,
                                                worker_cache, wi, *send_keys,
                                                *recv_keys));
          VLOG(1) << "CFLR::Instantiate: [Success] " << function_name << " on "
                  << target << " (this: " << this << ")"
                  << " with handle: " << *handle;
        }
        done(status);
        delete recv_keys;
        delete send_keys;
        delete req;
        delete resp;
      });
}

void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    CHECK_LE(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }

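  // Build the RunGraph request: each argument tensor is fed under its send
  // key, and results are fetched via the recv keys recorded at instantiation.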
  RunGraphRequest* req = new RunGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  req->set_graph_handle(function_data->graph_handle);
  req->set_step_id(opts.step_id);
  int i = 0;
  for (const auto& send_key : function_data->send_keys) {
    NamedTensorProto* send = req->add_send();
    send->set_name(send_key);
    args[i].AsProtoTensorContent(send->mutable_tensor());
    i++;
  }
  const std::vector<string>& recv_keys = function_data->recv_keys;
  for (const auto& recv_key : recv_keys) {
    req->add_recv_key(recv_key);
  }

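  // req, resp, and call_options are heap-allocated so they outlive this call;
  // the cleanup closure inside the callback below releases them.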
  RunGraphResponse* resp = new RunGraphResponse();
  CallOptions* call_options = new CallOptions();
  wi->RunGraphAsync(
      call_options, req, resp,
      [call_options, req, resp, rets, recv_keys, done](const Status& status) {
        Status* local_status = new Status(status);
        auto cleanup =
            gtl::MakeCleanup([call_options, req, resp, local_status, done] {
              done(*local_status);
              delete call_options;
              delete req;
              delete resp;
              delete local_status;
            });
        if (!local_status->ok()) {
          return;
        }
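        // Index the returned tensors by rendezvous key so the results can be
        // emitted in recv_keys order.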
        std::map<string, TensorProto*> mapped_recvs;
        for (auto& recv : *resp->mutable_recv()) {
          mapped_recvs[recv.name()] = recv.mutable_tensor();
        }

        for (const auto& recv_key : recv_keys) {
          TensorProto* tp = mapped_recvs[recv_key];
          if (tp == nullptr) {
            local_status->Update(
                errors::Internal("Could not find key: ", recv_key));
            return;
          }
          Tensor t;
          if (t.FromProto(*tp)) {
            rets->push_back(t);
          } else {
            local_status->Update(errors::Internal(
                "Could not convert tensor proto: ", tp->DebugString()));
            return;
          }
        }
      });
}

void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle,
    gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
    FunctionLibraryRuntime::DoneCallback done) {
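  // This overload only supports plain Tensor arguments; remote eager tensor
  // handles are rejected below.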
  std::vector<Tensor> tensors;
  for (const auto& arg : args) {
    if (arg.index() == 0) {
      tensors.push_back(absl::get<Tensor>(arg));
    } else {
      done(
          errors::Internal("ClusterFunctionLibraryRuntime doesn't support "
                           "eager::RemoteTensorHandle."));
      return;
    }
  }
  std::vector<Tensor>* ret_tensors = new std::vector<Tensor>;
  return Run(opts, handle, tensors, ret_tensors,
             [rets, ret_tensors, done = std::move(done)](const Status& s) {
               if (s.ok()) {
                 for (const auto& t : *ret_tensors) {
                   rets->push_back(t);
                 }
               }
               delete ret_tensors;
               done(s);
             });
}

void ClusterFunctionLibraryRuntime::CleanUp(
    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
    FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    DCHECK_LE(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }
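  // Ask the remote worker to release any state associated with this step.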
  CleanupGraphRequest* cleanup_req = new CleanupGraphRequest;
  cleanup_req->set_step_id(step_id);
  CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse;
  wi->CleanupGraphAsync(
      cleanup_req, cleanup_resp,
      [cleanup_req, cleanup_resp, done](const Status& cleanup_status) {
        done(cleanup_status);
        delete cleanup_req;
        delete cleanup_resp;
      });
}

}  // namespace tensorflow