1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16 
17 #include <utility>
18 
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/core/common_runtime/device_set.h"
21 #include "tensorflow/core/common_runtime/function.h"
22 #include "tensorflow/core/common_runtime/optimization_registry.h"
23 #include "tensorflow/core/common_runtime/partitioning_utils.h"
24 #include "tensorflow/core/common_runtime/placer.h"
25 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
26 #include "tensorflow/core/common_runtime/rendezvous_util.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph_to_functiondef.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/graph/graph.h"
33 #include "tensorflow/core/graph/graph_constructor.h"
34 #include "tensorflow/core/graph/graph_partition.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/gtl/map_util.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 #include "tensorflow/core/util/ptr_util.h"
39 #include "tensorflow/core/util/reffed_status_callback.h"
40 
41 namespace tensorflow {
42 
43 const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
44 
DistributedInit(DistributedFunctionLibraryRuntime * parent,const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)45 Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
46     DistributedFunctionLibraryRuntime* parent, const string& function_name,
47     const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
48     const FunctionLibraryRuntime::InstantiateOptions& options) {
49   mutex_lock l(mu_);
50   if (!init_started_) {
51     init_started_ = true;
52     init_result_ = parent->Instantiate(function_name, lib_def, attrs, options,
53                                        &local_handle_);
54   }
55   return init_result_;
56 }
57 
ProcessFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * default_thread_pool,DistributedFunctionLibraryRuntime * parent)58 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
59     const DeviceMgr* device_mgr, Env* env, int graph_def_version,
60     const FunctionLibraryDefinition* lib_def,
61     const OptimizerOptions& optimizer_options,
62     thread::ThreadPool* default_thread_pool,
63     DistributedFunctionLibraryRuntime* parent)
64     : env_(env),
65       device_mgr_(device_mgr),
66       lib_def_(lib_def),
67       default_thread_pool_(default_thread_pool),
68       next_handle_(0),
69       parent_(parent) {
70   if (device_mgr == nullptr) {
71     flr_map_[nullptr] = NewFunctionLibraryRuntime(
72         nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool,
73         optimizer_options, this);
74     return;
75   }
76   for (Device* d : device_mgr->ListDevices()) {
77     flr_map_[d] = NewFunctionLibraryRuntime(
78         device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool,
79         optimizer_options, this);
80   }
81 }
82 
ProcessFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,thread::ThreadPool * default_thread_pool,DistributedFunctionLibraryRuntime * parent)83 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
84     const DeviceMgr* device_mgr, Env* env, int graph_def_version,
85     const FunctionLibraryDefinition* lib_def,
86     const OptimizerOptions& optimizer_options,
87     CustomKernelCreator custom_kernel_creator,
88     thread::ThreadPool* default_thread_pool,
89     DistributedFunctionLibraryRuntime* parent)
90     : env_(env),
91       device_mgr_(device_mgr),
92       lib_def_(lib_def),
93       default_thread_pool_(default_thread_pool),
94       next_handle_(0),
95       parent_(parent) {
96   if (device_mgr == nullptr) {
97     flr_map_[nullptr] = NewFunctionLibraryRuntime(
98         nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool,
99         optimizer_options, std::move(custom_kernel_creator), this);
100     return;
101   }
102   for (Device* d : device_mgr->ListDevices()) {
103     flr_map_[d] = NewFunctionLibraryRuntime(
104         device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool,
105         optimizer_options, custom_kernel_creator, this);
106   }
107 }
108 
109 /* static */
SendTensors(const string & source_device,const string & target_device,const string & key_prefix,int64 src_incarnation,gtl::ArraySlice<Tensor> tensors_to_send,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,Rendezvous * rendezvous)110 Status ProcessFunctionLibraryRuntime::SendTensors(
111     const string& source_device, const string& target_device,
112     const string& key_prefix, int64 src_incarnation,
113     gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
114     const std::vector<AllocatorAttributes>& alloc_attrs,
115     Rendezvous* rendezvous) {
116   std::vector<string> keys;
117   for (int i = 0; i < tensors_to_send.size(); ++i) {
118     string name = strings::StrCat(key_prefix, i);
119     string key = Rendezvous::CreateKey(source_device, src_incarnation,
120                                        target_device, name, FrameAndIter(0, 0));
121     keys.push_back(key);
122   }
123   TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
124       rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
125   return Status::OK();
126 }
127 
128 /* static */
ReceiveTensorsAsync(const string & source_device,const string & target_device,const string & key_prefix,int64 src_incarnation,int64 num_tensors,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,Rendezvous * rendezvous,std::vector<Tensor> * received_tensors,StatusCallback done)129 void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
130     const string& source_device, const string& target_device,
131     const string& key_prefix, int64 src_incarnation, int64 num_tensors,
132     DeviceContext* device_context,
133     const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
134     std::vector<Tensor>* received_tensors, StatusCallback done) {
135   std::vector<string> keys;
136   for (int64 i = 0; i < num_tensors; ++i) {
137     string name = strings::StrCat(key_prefix, i);
138     string key = Rendezvous::CreateKey(source_device, src_incarnation,
139                                        target_device, name, FrameAndIter(0, 0));
140     keys.push_back(key);
141   }
142   RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
143                                  received_tensors, std::move(done));
144 }
145 
GetDeviceIncarnation(const string & device_name,int64 * incarnation) const146 Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
147     const string& device_name, int64* incarnation) const {
148   FunctionLibraryRuntime* flr = GetFLR(device_name);
149   if (flr == nullptr) {
150     return errors::InvalidArgument("Device name: ", device_name, " not found");
151   }
152   *incarnation = flr->device()->attributes().incarnation();
153   return Status::OK();
154 }
155 
GetDeviceContext(const string & device_name,DeviceContext ** device_context) const156 Status ProcessFunctionLibraryRuntime::GetDeviceContext(
157     const string& device_name, DeviceContext** device_context) const {
158   *device_context = nullptr;
159   FunctionLibraryRuntime* flr = GetFLR(device_name);
160   if (flr == nullptr) {
161     return errors::InvalidArgument("Device name: ", device_name, " not found.");
162   }
163   Device* device = flr->device();
164   string device_type = device->parsed_name().type;
165   if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
166     // "TPU_SYSTEM" indicates that `device` is a CPU.
167     return Status::OK();
168   }
169   if (device_type == "GPU" || device_type == "TPU") {
170     auto* dev_info = flr->device()->tensorflow_gpu_device_info();
171     if (dev_info) {
172       *device_context = dev_info->default_context;
173       return Status::OK();
174     }
175   }
176   return errors::Internal("Device type: ", device_type,
177                           " is currently unsupported for remote ",
178                           "function executions");
179 }
180 
GetFLR(const string & device_name) const181 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
182     const string& device_name) const {
183   Device* device = nullptr;
184   if (device_name != kDefaultFLRDevice) {
185     if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
186       VLOG(1) << "Could not find device: " << device_name;
187       return nullptr;
188     }
189   }
190   const auto& iter = flr_map_.find(device);
191   if (iter == flr_map_.end()) {
192     LOG(ERROR) << "Could not find device: " << device_name;
193     return nullptr;
194   }
195   return iter->second.get();
196 }
197 
AddHandle(const string & function_key,const string & device_name,FunctionLibraryRuntime::LocalHandle local_handle)198 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
199     const string& function_key, const string& device_name,
200     FunctionLibraryRuntime::LocalHandle local_handle) {
201   mutex_lock l(mu_);
202   return AddHandleLocked(function_key, device_name, local_handle);
203 }
204 
AddHandleLocked(const string & function_key,const string & device_name,FunctionLibraryRuntime::LocalHandle local_handle)205 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
206     const string& function_key, const string& device_name,
207     FunctionLibraryRuntime::LocalHandle local_handle) {
208   auto h = next_handle_;
209   function_data_[h] =
210       MakeUnique<FunctionData>(device_name, local_handle, function_key);
211   table_[function_key] = h;
212   next_handle_++;
213   return h;
214 }
215 
216 FunctionLibraryRuntime::Handle
AddMultiDeviceHandle(std::unique_ptr<MultiDeviceFunctionData> data,const string & function_key)217 ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
218     std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
219   mutex_lock l(mu_);
220   auto h = next_handle_;
221   mdevice_data_[h] = std::move(data);
222   table_[function_key] = h;
223   next_handle_++;
224   return h;
225 }
226 
GetHandle(const string & function_key) const227 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
228     const string& function_key) const {
229   tf_shared_lock l(mu_);
230   return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
231 }
232 
IsInstantiatedOnDevice(const string & device_name,FunctionLibraryRuntime::Handle handle) const233 bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
234     const string& device_name, FunctionLibraryRuntime::Handle handle) const {
235   return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
236 }
237 
238 FunctionLibraryRuntime::LocalHandle
GetHandleOnDevice(const string & device_name,FunctionLibraryRuntime::Handle handle) const239 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
240     const string& device_name, FunctionLibraryRuntime::Handle handle) const {
241   tf_shared_lock l(mu_);
242 
243   auto miter = mdevice_data_.find(handle);
244   if (miter != mdevice_data_.end()) {
245     return kInvalidLocalHandle;
246   }
247 
248   auto iter = function_data_.find(handle);
249   if (iter == function_data_.end()) {
250     return kInvalidLocalHandle;
251   }
252   FunctionData* function_data = iter->second.get();
253   if (function_data->target_device() != device_name) {
254     return kInvalidLocalHandle;
255   }
256   return function_data->local_handle();
257 }
258 
GetDeviceName(FunctionLibraryRuntime::Handle handle) const259 string ProcessFunctionLibraryRuntime::GetDeviceName(
260     FunctionLibraryRuntime::Handle handle) const {
261   tf_shared_lock l(mu_);
262   auto iter = function_data_.find(handle);
263   CHECK(iter != function_data_.end());
264   FunctionData* function_data = iter->second.get();
265   return function_data->target_device();
266 }
267 
268 ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
IsMultiDevice(FunctionLibraryRuntime::Handle handle) const269 ProcessFunctionLibraryRuntime::IsMultiDevice(
270     FunctionLibraryRuntime::Handle handle) const {
271   tf_shared_lock l(mu_);
272   const auto& it = mdevice_data_.find(handle);
273   if (it != mdevice_data_.end()) {
274     return it->second.get();
275   }
276   return nullptr;
277 }
278 
279 namespace {
280 // Sets `group` to the first colocation group specified in `node`. If no
281 // group is specified, does not touch `group`.
GetColocationGroup(const Node * node,string * group)282 void GetColocationGroup(const Node* node, string* group) {
283   // We hoist the conversion from C-style string literal to string here,
284   // so that we can avoid the many repeated calls to strlen().
285   static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
286   const AttrValue* attr_value =
287       node->attrs().Find(kColocationAttrNameStringPiece);
288   if (attr_value != nullptr && attr_value->has_list() &&
289       attr_value->list().s_size() > 0) {
290     *group = attr_value->list().s(0);
291   }
292 }
293 
AssignedOrRequestedDeviceName(const Node & node)294 const string* AssignedOrRequestedDeviceName(const Node& node) {
295   if (node.has_assigned_device_name()) {
296     return &node.assigned_device_name();
297   }
298   return &node.requested_device();
299 }
300 
301 }  // anonymous namespace
302 
PinArgsAndRets(const std::vector<string> & input_devices,const std::vector<string> & output_devices,const DeviceSet & device_set,Graph * graph) const303 Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
304     const std::vector<string>& input_devices,
305     const std::vector<string>& output_devices, const DeviceSet& device_set,
306     Graph* graph) const {
307   // If output_devices are not specified, we want to set the output device
308   // based on the device of the output producing node. The output producing
309   // node can be an arg node because functions can simply return their
310   // arguments. To make sure that the output producing nodes have assigned
311   // devices, we assign them to arguments first.
312   for (Node* node : graph->op_nodes()) {
313     if (node->IsArg()) {
314       const AttrValue* attr_value;
315       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
316       int64 index = attr_value->i();
317       node->set_assigned_device_name(input_devices[index]);
318     }
319   }
320 
321   for (Node* node : graph->op_nodes()) {
322     if (node->IsRetval()) {
323       if (output_devices.empty()) {
324         VLOG(3) << "Trying to determine device for node " << node->name();
325         // If output_devices are empty, the node producing retval
326         // must have explicitly assigned device or a colocation constraint
327         // to a node with explicitly assigned device.
328         for (const auto& it : node->in_edges()) {
329           if (!it->IsControlEdge()) {
330             Node* src_node = it->src();
331             const string* src_device = AssignedOrRequestedDeviceName(*src_node);
332             string colocation_group = "";
333             GetColocationGroup(src_node, &colocation_group);
334             VLOG(3) << "Considering src: " << src_node->name()
335                     << " src_device: " << *src_device
336                     << " colo group: " << colocation_group;
337             while (src_device->empty() && colocation_group.empty() &&
338                    src_node->IsIdentity()) {
339               src_node = *src_node->in_nodes().begin();
340               src_device = AssignedOrRequestedDeviceName(*src_node);
341               GetColocationGroup(src_node, &colocation_group);
342               VLOG(3) << "Considering src: " << src_node->name()
343                       << " src_device: " << *src_device
344                       << " colo group: " << colocation_group;
345             }
346 
347             if (!colocation_group.empty()) {
348               AttrValue::ListValue colo_attr;
349               colo_attr.add_s(colocation_group);
350               std::vector<string> colo_slice = {colocation_group};
351               node->AddAttr(kColocationAttrName, colo_slice);
352             } else if (!src_device->empty()) {
353               // src_device can be a partially specified device. Find the
354               // matching device in the device_set.
355               DeviceNameUtils::ParsedName parsed;
356               if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
357                 return errors::InvalidArgument(
358                     "Failed to parse explicit device specification ",
359                     *src_device);
360               }
361               std::vector<Device*> matching_devices;
362               device_set.FindMatchingDevices(parsed, &matching_devices);
363               if (matching_devices.empty()) {
364                 return errors::InvalidArgument(
365                     "Unable to find any devices for spec ", *src_device);
366               } else if (matching_devices.size() != 1) {
367                 // Convert a vector of devices to a string.
368                 // Using absl::StrJoin did not work in Android builds.
369                 string devices = "[";
370                 for (Device* device : matching_devices) {
371                   devices.append(device->name());
372                   devices.append(", ");
373                 }
374                 if (devices.size() > 2) {
375                   devices.resize(devices.size() - 2);
376                 }
377                 devices.append("]");
378 
379                 return errors::InvalidArgument(
380                     "When FunctionLibraryRuntime::Options.output_devices are "
381                     "not specified for a multi-device function, the device "
382                     "specification on the output node must match exactly one "
383                     "device. Matched devices are ",
384                     devices);
385               }
386               VLOG(3) << "Setting output device to "
387                       << matching_devices[0]->name() << " for node "
388                       << node->DebugString();
389               node->set_assigned_device_name(matching_devices[0]->name());
390             }
391           }
392         }
393       } else {
394         const AttrValue* attr_value;
395         TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
396         int64 index = attr_value->i();
397         // output_devices size is checked in InstantiateMultiDevice
398         DCHECK_GT(output_devices.size(), index);
399         VLOG(3) << "Setting output device to " << output_devices[index]
400                 << " for return at index " << index;
401         node->set_assigned_device_name(output_devices[index]);
402       }
403     }
404   }
405   return Status::OK();
406 }
407 
408 namespace {
409 
ValidateNoListArguments(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const char * arg_type,const string & function_name)410 Status ValidateNoListArguments(
411     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
412     const string& function_name) {
413   for (const OpDef::ArgDef& arg : args) {
414     if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
415       return errors::InvalidArgument(
416           "Function ", function_name, " has an ", arg_type, " named \"",
417           arg.name(),
418           "\" that is a list of tensors."
419           " Multi-device functions support only single-tensor inputs "
420           " and outputs");
421     }
422   }
423   return Status::OK();
424 }
425 
ValidateMultiDeviceOptions(const FunctionDef & fdef,const FunctionLibraryRuntime::InstantiateOptions & options)426 Status ValidateMultiDeviceOptions(
427     const FunctionDef& fdef,
428     const FunctionLibraryRuntime::InstantiateOptions& options) {
429   const OpDef& signature = fdef.signature();
430   // Multi-device functions don't currently support list inputs or outputs
431   TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
432                                              signature.name()));
433   TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
434                                              signature.name()));
435 
436   if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
437       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
438     return errors::Unimplemented(
439         "Function '", signature.name(), "' has `",
440         FunctionLibraryDefinition::kIntsOnDeviceAttr,
441         "` attribute set. This attribute is not currently supported by "
442         "multi-device functions.");
443   }
444 
445   if (options.input_devices.size() != signature.input_arg_size()) {
446     return errors::InvalidArgument(
447         "InstantiateOptions.input_devices must have the same length "
448         "as the number of arguments: input_devices length = ",
449         options.input_devices.size(),
450         " number of arguments = ", signature.input_arg_size());
451   }
452   if (!options.output_devices.empty() &&
453       options.output_devices.size() != signature.output_arg_size()) {
454     return errors::InvalidArgument(
455         "InstantiateOptions.output_devices must either be empty or have "
456         "the same length as the number of arguments: output_devices length "
457         "= ",
458         options.output_devices.size(),
459         " number of arguments = ", signature.output_arg_size());
460   }
461 
462   if (!options.state_handle.empty()) {
463     return errors::Unimplemented(
464         "InstantiateOptions.state_handle is not supported for multi-device "
465         "functions. Function: ",
466         signature.name());
467   }
468   if (options.create_kernels_eagerly) {
469     return errors::Unimplemented(
470         "InstantiateOptions.create_kernels_eagerly is not supported for "
471         "multi-device functions. Function: ",
472         signature.name());
473   }
474 
475   return Status::OK();
476 }
477 
GetGraphAndRets(const string & function_name,AttrSlice attrs,const FunctionDef * fdef,const FunctionLibraryDefinition * lib_def,std::unique_ptr<Graph> * graph,std::vector<string> * ret_node_names,std::vector<string> * control_ret_node_names)478 Status GetGraphAndRets(const string& function_name, AttrSlice attrs,
479                        const FunctionDef* fdef,
480                        const FunctionLibraryDefinition* lib_def,
481                        std::unique_ptr<Graph>* graph,
482                        std::vector<string>* ret_node_names,
483                        std::vector<string>* control_ret_node_names) {
484   auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
485     return lib_def->LookUpOpDef(op, sig);
486   };
487   FunctionBody* tmp_fbody;
488   // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
489   TF_RETURN_IF_ERROR(
490       FunctionDefToBodyHelper(*fdef, attrs, lib_def, get_func_sig, &tmp_fbody));
491   if (tmp_fbody == nullptr) {
492     LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
493     return errors::Internal("Failed to construct FunctionBody for ",
494                             function_name);
495   }
496   std::unique_ptr<FunctionBody> fbody(tmp_fbody);
497   *graph = std::unique_ptr<Graph>(fbody->graph);
498   fbody->graph = nullptr;
499   ret_node_names->reserve(fbody->ret_nodes.size());
500   for (const Node* node : fbody->ret_nodes) {
501     ret_node_names->push_back(node->name());
502   }
503   control_ret_node_names->reserve(fbody->control_ret_nodes.size());
504   for (const Node* node : fbody->control_ret_nodes) {
505     control_ret_node_names->push_back(node->name());
506   }
507   return Status::OK();
508 }
509 
510 }  // anonymous namespace
511 
InstantiateMultiDevice(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle)512 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
513     const string& function_name, AttrSlice attrs,
514     const FunctionLibraryRuntime::InstantiateOptions& options,
515     FunctionLibraryRuntime::Handle* handle) {
516   // Check if this function has already been instantiated.
517   const string& function_key = Canonicalize(function_name, attrs, options);
518 
519   {
520     mutex_lock l(mu_);
521     const auto& it = table_.find(function_key);
522     if (it != table_.end()) {
523       *handle = it->second;
524       ++mdevice_data_[*handle]->instantiation_counter_;
525       return Status::OK();
526     }
527   }
528 
529   VLOG(1) << "Instantiating MultiDevice function \"" << function_name
530           << "\" on default device \"" << options.target << "\"";
531   if (VLOG_IS_ON(3)) {
532     VLOG(3) << "Requested input devices:";
533     for (const string& device : options.input_devices) {
534       VLOG(3) << "    " << device;
535     }
536     VLOG(3) << "Requested output devices:";
537     for (const string& device : options.output_devices) {
538       VLOG(3) << "    " << device;
539     }
540   }
541 
542   const FunctionLibraryDefinition* lib_def =
543       options.overlay_lib == nullptr ? lib_def_ : options.overlay_lib;
544 
545   const FunctionDef* fdef = lib_def->Find(function_name);
546   if (fdef == nullptr) {
547     return errors::InvalidArgument("Failed to find function \"", function_name,
548                                    "\" in function library: ", lib_def);
549   }
550 
551   TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));
552 
553   std::unique_ptr<Graph> graph;
554   std::vector<string> ret_node_names;
555   std::vector<string> control_ret_node_names;
556 
557   TF_RETURN_IF_ERROR(GetGraphAndRets(function_name, attrs, fdef, lib_def,
558                                      &graph, &ret_node_names,
559                                      &control_ret_node_names));
560 
561   if (options.graph_collector != nullptr) {
562     GraphDef def;
563     graph->ToGraphDef(&def);
564     *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
565     options.graph_collector->CollectRawGraph(def);
566   }
567 
568   DeviceSet device_set;
569   for (auto d : device_mgr_->ListDevices()) {
570     device_set.AddDevice(d);
571   }
572 
573   TF_RETURN_IF_ERROR(PinArgsAndRets(
574       options.input_devices, options.output_devices, device_set, graph.get()));
575 
576   std::unique_ptr<MultiDeviceFunctionData> data =
577       MakeUnique<MultiDeviceFunctionData>(function_name, function_key,
578                                           ret_node_names.size(),
579                                           lib_def->ReachableDefinitions(*fdef));
580 
581   GraphOptimizationPassOptions optimization_options;
582   // TODO(iga): Thread other relevant options from SessionOptions.
583   SessionOptions session_options;
584   session_options.env = env_;
585   session_options.config = options.config_proto;
586   optimization_options.session_options = &session_options;
587   optimization_options.graph = &graph;
588   optimization_options.flib_def = &data->overlay_lib_;
589   optimization_options.device_set = &device_set;
590 
591   DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
592   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
593       OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
594 
595   DumpGraph("Before calling Placer", graph.get());
596   // Make the FunctionLibraryRuntime's device the default device if
597   // nothing else is hard coded. This allows the same function definition
598   // to be specialized to different devices depending on the
599   // PartitionedCallOp's device.
600   Device* default_device = nullptr;
601   if (!options.target.empty()) {
602     FunctionLibraryRuntime* flr = GetFLR(options.target);
603     if (flr == nullptr) {
604       return errors::InvalidArgument(
605           "Cannot instantiate multi-device function with target device ",
606           options.target);
607     }
608     default_device = flr->device();
609   }
610 
611   // TODO(b/124993244): Smartly merge options in nested defuns, and raise
612   // exceptions/warnings in case where nested function call options are ignored.
613   Placer placer(graph.get(), &device_set, default_device,
614                 options.config_proto.allow_soft_placement(),
615                 options.config_proto.log_device_placement());
616   TF_RETURN_IF_ERROR(placer.Run());
617 
618   DumpGraph("Before running POST_PLACEMENT passes", graph.get());
619   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
620       OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
621 
622   Device* cpu_device;
623   TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));
624 
625   if (options.optimize_graph_fn) {
626     DumpGraph("Before running graph optimization fn", graph.get());
627     Status status = options.optimize_graph_fn(
628         std::move(ret_node_names), std::move(control_ret_node_names),
629         &data->overlay_lib_, device_set, cpu_device, &graph);
630     if (!status.ok()) {
631       LOG(WARNING) << "Ignoring multi-device function optimization failure: "
632                    << status.ToString();
633     }
634     DumpGraph("After optimization", graph.get());
635   }
636 
637   DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
638   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
639       OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
640   DumpGraph("After all optimization passes", graph.get());
641 
642   if (options.graph_collector != nullptr) {
643     GraphDef def;
644     graph->ToGraphDef(&def);
645     *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
646     options.graph_collector->CollectOptimizedGraph(def);
647   }
648 
649   std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
650   TF_RETURN_IF_ERROR(
651       PartitionFunctionGraph(device_set, std::move(graph), &subgraphs));
652 
653   if (options.graph_collector != nullptr) {
654     for (const auto& pair : subgraphs) {
655       GraphDef def;
656       pair.second->ToGraphDef(&def);
657       *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
658       options.graph_collector->CollectPartitionedGraph(def);
659     }
660   }
661 
662   int i = 0;
663   FunctionNameGenerator name_generator(&data->overlay_lib_, function_name);
664   for (const auto& pair : subgraphs) {
665     i += 1;
666     // TODO(iga): Fail gracefully if the set of devices corresponds
667     // to more than one address space.
668     const string& target = pair.first;
669     Graph* subgraph = pair.second.get();
670 
671     ComponentFunctionData* comp_data = &data->glue_[target];
672     TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
673         subgraph, &comp_data->arg_indices_, &comp_data->ret_indices_,
674         &comp_data->arg_alloc_attrs_, &comp_data->ret_alloc_attrs_));
675     FunctionDef shard;
676     string unique_name = name_generator.GetName();
677     TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph, unique_name, &shard));
678     FunctionLibraryRuntime* target_flr = GetFLR(target);
679     TF_RETURN_IF_ERROR(data->overlay_lib_.AddFunctionDef(shard));
680     FunctionLibraryRuntime::InstantiateOptions opts;
681     opts.executor_type = options.executor_type;
682     opts.target = target;
683     opts.overlay_lib = &data->overlay_lib_;
684     FunctionLibraryRuntime::Handle component_handle;
685 
686     TF_RETURN_IF_ERROR(target_flr->Instantiate(
687         unique_name, AttrSlice(&shard.attr()), opts, &component_handle));
688     VLOG(1) << "Instantiated component function " << unique_name
689             << " on device " << target << " with component handle "
690             << component_handle;
691     VLOG(2) << DebugString(shard);
692     comp_data->handle_ = component_handle;
693   }
694 
695   *handle = AddMultiDeviceHandle(std::move(data), function_key);
696   VLOG(2) << "Instantiated MultiDevice function \"" << function_name
697           << "\" with handle " << *handle;
698   return Status::OK();
699 }
700 
GetOutputDevices(FunctionLibraryRuntime::Handle handle,std::vector<Device * > * output_devices) const701 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
702     FunctionLibraryRuntime::Handle handle,
703     std::vector<Device*>* output_devices) const {
704   const MultiDeviceFunctionData* data = IsMultiDevice(handle);
705   if (data == nullptr) {
706     return errors::InvalidArgument(
707         "Failed for find multi-device function handle ", handle);
708   }
709 
710   for (const auto& pair : data->glue_) {
711     const ComponentFunctionData& comp_data = pair.second;
712     DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size());
713 
714     const string& target = pair.first;
715     FunctionLibraryRuntime* target_flr = GetFLR(target);
716     Device* target_device = target_flr->device();
717     const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_);
718     DCHECK(fbody != nullptr);
719 
720     output_devices->resize(data->num_outputs_);
721     for (int j = 0; j < comp_data.ret_indices_.size(); ++j) {
722       int ret_index = comp_data.ret_indices_[j];
723       if (fbody->ret_types[j] == DT_RESOURCE) {
724         (*output_devices)[ret_index] = target_device;
725       } else {
726         (*output_devices)[ret_index] =
727             comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device;
728       }
729     }
730   }
731 
732   return Status::OK();
733 }
734 
RunMultiDevice(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done) const735 void ProcessFunctionLibraryRuntime::RunMultiDevice(
736     const FunctionLibraryRuntime::Options& opts,
737     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
738     std::vector<Tensor>* rets,
739     FunctionLibraryRuntime::DoneCallback done) const {
740   if (opts.create_rendezvous) {
741     // FLR->Run() is the default entry point. It checks for cancellation,
742     // creates rendezvous, etc.
743     // Letting create_rendezvous through will do the wrong thing - each
744     // component function will get a separate rendezvous created by its FLR.
745     done(
746         errors::Internal("Cannot call ProcessFunctionLibraryRuntime::Run with "
747                          "create_rendezvous=true. Please run the function "
748                          "using FunctionLibraryRuntime::Run"));
749     return;
750   }
751 
752   const MultiDeviceFunctionData* data = IsMultiDevice(handle);
753   if (data == nullptr) {
754     done(
755         errors::InvalidArgument("Failed for find multi-device function handle ",
756                                 handle, ". Was the function instantiated?"));
757     return;
758   }
759 
760   if (data->glue_.empty()) {
761     // Trivial case where the function body is empty.
762     done(Status::OK());
763     return;
764   }
765 
766   auto* refcounted_done = new ReffedStatusCallback(std::move(done));
767   for (int i = 0; i < data->glue_.size(); ++i) {
768     refcounted_done->Ref();
769   }
770 
771   FunctionLibraryRuntime::Options opts_copy = opts;
772   for (const auto& pair : data->glue_) {
773     const string& target = pair.first;
774     const ComponentFunctionData& comp_data = pair.second;
775     FunctionLibraryRuntime::Handle handle = pair.second.handle_;
776     VLOG(1) << "Running function shard on device " << target << " with handle "
777             << handle;
778 
779     opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
780     opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
781     opts_copy.remote_execution = false;
782     std::vector<Tensor> comp_args =
783         GetArgsForIndices(comp_data.arg_indices_, args);
784     std::vector<Tensor>* comp_rets = new std::vector<Tensor>;
785     rets->resize(data->num_outputs_);
786     GetFLR(target)->Run(
787         opts_copy, handle, comp_args, comp_rets,
788         [comp_rets, rets, comp_data, refcounted_done](const Status& status) {
789           if (!status.ok()) {
790             LOG(ERROR) << "Component function execution failed: " << status;
791             refcounted_done->UpdateStatus(status);
792           } else {
793             for (int i = 0; i < comp_rets->size(); ++i) {
794               (*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
795             }
796           }
797           delete comp_rets;
798           // refcounted_done is thread-safe
799           refcounted_done->Unref();
800         });
801   }
802   refcounted_done->Unref();
803 }
804 
Instantiate(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle)805 Status ProcessFunctionLibraryRuntime::Instantiate(
806     const string& function_name, AttrSlice attrs,
807     const FunctionLibraryRuntime::InstantiateOptions& options,
808     FunctionLibraryRuntime::Handle* handle) {
809   if (options.is_multi_device_function) {
810     return InstantiateMultiDevice(function_name, attrs, options, handle);
811   }
812 
813   *handle = kInvalidHandle;
814   FunctionLibraryRuntime* flr = GetFLR(options.target);
815   if (flr != nullptr) {
816     return flr->Instantiate(function_name, attrs, options, handle);
817   }
818   if (parent_ == nullptr) {
819     return errors::Internal(
820         "Currently don't support instantiating functions on device: ",
821         options.target);
822   }
823   VLOG(1) << "ProcessFLR Instantiate: " << function_name
824           << " on: " << options.target;
825   string function_key = Canonicalize(function_name, attrs, options);
826   FunctionData* f;
827   {
828     mutex_lock l(mu_);
829     FunctionLibraryRuntime::Handle h =
830         gtl::FindWithDefault(table_, function_key, kInvalidHandle);
831     if (h == kInvalidHandle || function_data_.count(h) == 0) {
832       h = AddHandleLocked(function_key, options.target, kInvalidHandle);
833     }
834     f = function_data_[h].get();
835     *handle = h;
836   }
837   TF_RETURN_IF_ERROR(
838       f->DistributedInit(parent_, function_name, *lib_def_, attrs, options));
839   VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
840           << " on: " << options.target << " with handle: " << *handle
841           << " (this: " << this << ")";
842   return Status::OK();
843 }
844 
RemoveHandle(FunctionLibraryRuntime::Handle handle)845 Status ProcessFunctionLibraryRuntime::RemoveHandle(
846     FunctionLibraryRuntime::Handle handle) {
847   mutex_lock l(mu_);
848   table_.erase(function_data_[handle]->function_key());
849   function_data_.erase(handle);
850   return Status::OK();
851 }
852 
ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle)853 Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
854     FunctionLibraryRuntime::Handle handle) {
855   std::unique_ptr<MultiDeviceFunctionData> mdata;
856   {
857     mutex_lock l(mu_);
858     auto it = mdevice_data_.find(handle);
859     --it->second->instantiation_counter_;
860     if (it->second->instantiation_counter_ != 0) {
861       return Status::OK();
862     }
863     mdata = std::move(it->second);
864     table_.erase(mdata->function_key_);
865     mdevice_data_.erase(it);
866   }
867 
868   // If we are here we are releasing the last instantiation of `handle`.
869   // Release all component function handles.
870   Status overall_status;
871   for (const auto& it : mdata->glue_) {
872     const string& device = it.first;
873     FunctionLibraryRuntime::Handle flr_handle = it.second.handle_;
874     FunctionLibraryRuntime* flr = GetFLR(device);
875     if (flr == nullptr) {
876       return errors::InvalidArgument(
877           "Failed to find FunctionLibraryRuntime for device ", device,
878           " when releasing multi-device function handle ", handle);
879     }
880     Status status = flr->ReleaseHandle(flr_handle);
881     if (!status.ok()) {
882       overall_status = status;
883     }
884   }
885 
886   return overall_status;
887 }
888 
ReleaseHandle(FunctionLibraryRuntime::Handle handle)889 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
890     FunctionLibraryRuntime::Handle handle) {
891   if (IsMultiDevice(handle)) {
892     return ReleaseMultiDeviceHandle(handle);
893   }
894 
895   FunctionLibraryRuntime* flr = nullptr;
896   string target_device;
897   {
898     mutex_lock l(mu_);
899     CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
900     target_device = function_data_[handle]->target_device();
901   }
902   flr = GetFLR(target_device);
903   if (flr != nullptr) {
904     return flr->ReleaseHandle(handle);
905   }
906   return errors::InvalidArgument("Handle not found: ", handle);
907 }
908 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done) const909 void ProcessFunctionLibraryRuntime::Run(
910     const FunctionLibraryRuntime::Options& opts,
911     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
912     std::vector<Tensor>* rets,
913     FunctionLibraryRuntime::DoneCallback done) const {
914   bool multi_device;
915   {
916     tf_shared_lock l(mu_);
917     multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
918   }
919   if (multi_device) {
920     return RunMultiDevice(opts, handle, args, rets, done);
921   }
922 
923   FunctionLibraryRuntime* flr = nullptr;
924   string target_device;
925   FunctionLibraryRuntime::LocalHandle local_handle;
926   {
927     tf_shared_lock l(mu_);
928     auto iter = function_data_.find(handle);
929     if (iter == function_data_.end()) {
930       done(errors::NotFound("Handle: ", handle, " not found."));
931       return;
932     }
933     FunctionData* function_data = iter->second.get();
934     target_device = function_data->target_device();
935     local_handle = function_data->local_handle();
936   }
937 
938   if (!opts.remote_execution) {
939     done(
940         errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
941                                 "only be called for multi-device functions or "
942                                 "for remote execution."));
943     return;
944   }
945 
946   flr = GetFLR(target_device);
947   if (flr != nullptr) {
948     auto rendezvous = opts.rendezvous;
949     string source_device = opts.source_device;
950     DeviceContext* device_context;
951     Status s = GetDeviceContext(source_device, &device_context);
952     if (!s.ok()) {
953       done(s);
954       return;
955     }
956     int64 src_incarnation, target_incarnation;
957     s = GetDeviceIncarnation(source_device, &src_incarnation);
958     s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
959     if (!s.ok()) {
960       done(s);
961       return;
962     }
963 
964     // Send the args over to the target device.
965     s = SendTensors(source_device, target_device, "arg_", src_incarnation, args,
966                     device_context, opts.args_alloc_attrs, rendezvous);
967     if (!s.ok()) {
968       done(s);
969       return;
970     }
971     const std::vector<AllocatorAttributes>& rets_alloc_attrs =
972         opts.rets_alloc_attrs;
973     std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
974     flr->Run(opts, handle, args, remote_rets,
975              std::bind(
976                  [source_device, target_device, target_incarnation, rendezvous,
977                   device_context, rets_alloc_attrs, remote_rets,
978                   rets](const Status& status,
979                         FunctionLibraryRuntime::DoneCallback& done) {
980                    if (!status.ok()) {
981                      delete remote_rets;
982                      done(status);
983                      return;
984                    }
985                    int64 num_returns = remote_rets->size();
986                    delete remote_rets;
987                    // Now receive the return values from the target.
988                    ReceiveTensorsAsync(target_device, source_device, "ret_",
989                                        target_incarnation, num_returns,
990                                        device_context, rets_alloc_attrs,
991                                        rendezvous, rets, std::move(done));
992                  },
993                  std::placeholders::_1, std::move(done)));
994     return;
995   }
996   if (parent_ != nullptr) {
997     parent_->Run(opts, local_handle, args, rets, std::move(done));
998     return;
999   }
1000   done(errors::Internal("Could not find device"));
1001 }
1002 
Clone(Env * env,int graph_def_version,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr) const1003 Status ProcessFunctionLibraryRuntime::Clone(
1004     Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
1005     CustomKernelCreator custom_kernel_creator,
1006     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1007     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) const {
1008   out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
1009   out_pflr->reset(new ProcessFunctionLibraryRuntime(
1010       device_mgr_, env, graph_def_version, out_lib_def->get(),
1011       optimizer_options, std::move(custom_kernel_creator), default_thread_pool_,
1012       parent_));
1013   return Status::OK();
1014 }
1015 
1016 }  // namespace tensorflow
1017