1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16 
17 #include <iterator>
18 #include <utility>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/common_runtime/device_set.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/optimization_registry.h"
28 #include "tensorflow/core/common_runtime/partitioning_utils.h"
29 #include "tensorflow/core/common_runtime/placer.h"
30 #include "tensorflow/core/common_runtime/process_util.h"
31 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
32 #include "tensorflow/core/common_runtime/rendezvous_util.h"
33 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/graph_to_functiondef.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/types.h"
40 #include "tensorflow/core/framework/types.pb.h"
41 #include "tensorflow/core/graph/graph.h"
42 #include "tensorflow/core/graph/graph_node_util.h"
43 #include "tensorflow/core/graph/graph_partition.h"
44 #include "tensorflow/core/lib/core/blocking_counter.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/random/random.h"
50 #include "tensorflow/core/platform/notification.h"
51 #include "tensorflow/core/util/device_name_utils.h"
52 #include "tensorflow/core/util/dump_graph.h"
53 #include "tensorflow/core/util/ptr_util.h"
54 #include "tensorflow/core/util/reffed_status_callback.h"
55 #if !defined(IS_MOBILE_PLATFORM)
56 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
57 #endif  // IS_MOBILE_PLATFORM
58 
59 namespace tensorflow {
60 
// Sentinel device name. GetFLR() skips the DeviceMgr lookup for this name and
// looks up the FunctionLibraryRuntime registered under the null device, which
// the constructor creates when it is given no DeviceMgr.
const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
62 
DistributedInit(DistributedFunctionLibraryRuntime * parent,const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::DoneCallback done)63 void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
64     DistributedFunctionLibraryRuntime* parent, const string& function_name,
65     const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
66     const FunctionLibraryRuntime::InstantiateOptions& options,
67     FunctionLibraryRuntime::DoneCallback done) {
68   {
69     mutex_lock l(mu_);
70     is_cross_process_ = true;
71     if (init_started_) {
72       init_done_.WaitForNotification();
73       done(init_result_);
74       return;
75     }
76     init_started_ = true;
77   }
78   parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
79                       [this, done](const Status& s) {
80                         init_done_.Notify();
81                         done(s);
82                       });
83 }
84 
ProcessFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * default_thread_pool,DistributedFunctionLibraryRuntime * parent,const SessionMetadata * session_metadata,Rendezvous::Factory rendezvous_factory)85 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
86     const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
87     int graph_def_version, const FunctionLibraryDefinition* lib_def,
88     const OptimizerOptions& optimizer_options,
89     thread::ThreadPool* default_thread_pool,
90     DistributedFunctionLibraryRuntime* parent,
91     const SessionMetadata* session_metadata,
92     Rendezvous::Factory rendezvous_factory)
93     : parent_(parent),
94       env_(env),
95       config_(config ? absl::make_optional(*config) : absl::nullopt),
96       device_mgr_(device_mgr),
97       lib_def_(lib_def),
98       default_thread_pool_(default_thread_pool),
99       flr_map_(new std::unordered_map<Device*,
100                                       std::unique_ptr<FunctionLibraryRuntime>>),
101       next_handle_(0),
102       session_metadata_(session_metadata),
103       rendezvous_factory_(std::move(rendezvous_factory)),
104       optimizer_options_(optimizer_options),
105       graph_def_version_(graph_def_version) {
106   if (device_mgr == nullptr) {
107     (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
108         nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
109         graph_def_version, lib_def_, default_thread_pool, optimizer_options,
110         session_metadata_, this);
111     return;
112   }
113   InitializeDeviceAndFlr();
114 }
115 
116 /* static */
SendTensors(const string & source_device,const string & target_device,const string & key_prefix,int64 src_incarnation,gtl::ArraySlice<Tensor> tensors_to_send,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,RendezvousInterface * rendezvous)117 Status ProcessFunctionLibraryRuntime::SendTensors(
118     const string& source_device, const string& target_device,
119     const string& key_prefix, int64 src_incarnation,
120     gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
121     const std::vector<AllocatorAttributes>& alloc_attrs,
122     RendezvousInterface* rendezvous) {
123   std::vector<string> keys;
124   for (int i = 0; i < tensors_to_send.size(); ++i) {
125     string name = strings::StrCat(key_prefix, i);
126     string key = Rendezvous::CreateKey(source_device, src_incarnation,
127                                        target_device, name, FrameAndIter(0, 0));
128     keys.push_back(key);
129   }
130   TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
131       rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
132   return Status::OK();
133 }
134 
135 /* static */
ReceiveTensorsAsync(const string & source_device,const string & target_device,const string & key_prefix,int64 src_incarnation,int64 num_tensors,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,RendezvousInterface * rendezvous,std::vector<Tensor> * received_tensors,StatusCallback done)136 void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
137     const string& source_device, const string& target_device,
138     const string& key_prefix, int64 src_incarnation, int64 num_tensors,
139     DeviceContext* device_context,
140     const std::vector<AllocatorAttributes>& alloc_attrs,
141     RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
142     StatusCallback done) {
143   std::vector<string> keys;
144   for (int64 i = 0; i < num_tensors; ++i) {
145     string name = strings::StrCat(key_prefix, i);
146     string key = Rendezvous::CreateKey(source_device, src_incarnation,
147                                        target_device, name, FrameAndIter(0, 0));
148     keys.push_back(key);
149   }
150   RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
151                                  received_tensors, std::move(done));
152 }
153 
GetRetTypes(FunctionLibraryRuntime::Handle h,DataTypeVector * ret_types)154 Status ProcessFunctionLibraryRuntime::GetRetTypes(
155     FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) {
156   FunctionLibraryRuntime* flr = nullptr;
157   {
158     tf_shared_lock l(mu_);
159     auto miter = mdevice_data_.find(h);
160     if (miter != mdevice_data_.end()) {
161       *ret_types = miter->second->ret_types_;
162       return Status::OK();
163     }
164     auto fiter = function_data_.find(h);
165     if (fiter != function_data_.end()) {
166       flr = GetFLR(fiter->second->target_device());
167     }
168   }
169   if (flr != nullptr) {
170     return flr->GetRetTypes(h, ret_types);
171   }
172   return errors::InvalidArgument("Handle ", h, " not found.");
173 }
174 
GetDeviceIncarnation(const string & device_name,int64 * incarnation) const175 Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
176     const string& device_name, int64* incarnation) const {
177   FunctionLibraryRuntime* flr = GetFLR(device_name);
178   if (flr == nullptr) {
179     return errors::InvalidArgument("Device name: ", device_name, " not found.");
180   }
181   *incarnation = flr->device()->attributes().incarnation();
182   return Status::OK();
183 }
184 
GetDeviceContext(const string & device_name,DeviceContext ** device_context) const185 Status ProcessFunctionLibraryRuntime::GetDeviceContext(
186     const string& device_name, DeviceContext** device_context) const {
187   *device_context = nullptr;
188   FunctionLibraryRuntime* flr = GetFLR(device_name);
189   if (flr == nullptr) {
190     return errors::InvalidArgument("Device name: ", device_name, " not found.");
191   }
192   Device* device = flr->device();
193   string device_type = device->parsed_name().type;
194   if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
195     // "TPU_SYSTEM" indicates that `device` is a CPU.
196     return Status::OK();
197   }
198 
199   if (device->IsRemoteCallAllowed()) {
200     auto* dev_info = flr->device()->tensorflow_gpu_device_info();
201     if (dev_info) {
202       *device_context = dev_info->default_context;
203       return Status::OK();
204     }
205   }
206 
207   return errors::Internal("Device type: ", device_type,
208                           " is currently unsupported for remote ",
209                           "function executions");
210 }
211 
InitializeDeviceAndFlr()212 void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
213   DeviceMgr const* all_devices = device_mgr_;
214   if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
215     all_devices = parent_->remote_device_mgr();
216   }
217 
218   mutex_lock l(mu_);
219   device_set_ = std::make_shared<DeviceSet>();
220   for (auto d : all_devices->ListDevices()) {
221     device_set_->AddDevice(d);
222   }
223   for (Device* d : device_mgr_->ListDevices()) {
224     if ((*flr_map_)[d] == nullptr) {
225       (*flr_map_)[d] = NewFunctionLibraryRuntime(
226           device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
227           graph_def_version_, lib_def_, default_thread_pool_,
228           optimizer_options_, session_metadata_, this);
229     }
230   }
231 }
232 
GetFLR(const string & device_name) const233 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
234     const string& device_name) const {
235   Device* device = nullptr;
236   if (device_name != kDefaultFLRDevice) {
237     if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
238       VLOG(4) << "Could not find device: " << device_name;
239       return nullptr;
240     }
241   }
242   const auto& iter = flr_map_->find(device);
243   if (iter == flr_map_->end()) {
244     VLOG(1) << "Could not find device: " << device_name
245             << "in the local process.";
246     return nullptr;
247   }
248   return iter->second.get();
249 }
250 
AddHandle(const string & function_key,const string & device_name,FunctionLibraryRuntime::LocalHandle local_handle)251 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
252     const string& function_key, const string& device_name,
253     FunctionLibraryRuntime::LocalHandle local_handle) {
254   mutex_lock l(mu_);
255   return AddHandleLocked(function_key, device_name, local_handle);
256 }
257 
AddHandleLocked(const string & function_key,const string & device_name,FunctionLibraryRuntime::LocalHandle local_handle)258 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
259     const string& function_key, const string& device_name,
260     FunctionLibraryRuntime::LocalHandle local_handle) {
261   auto h = next_handle_;
262   function_data_[h] =
263       absl::make_unique<FunctionData>(device_name, local_handle, function_key);
264   table_[function_key] = h;
265   next_handle_++;
266   return h;
267 }
268 
269 FunctionLibraryRuntime::Handle
AddMultiDeviceHandle(std::unique_ptr<MultiDeviceFunctionData> data,const string & function_key)270 ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
271     std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
272   mutex_lock l(mu_);
273   auto h = next_handle_;
274   mdevice_data_[h] = std::move(data);
275   table_[function_key] = h;
276   next_handle_++;
277   return h;
278 }
279 
GetHandle(const string & function_key) const280 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
281     const string& function_key) const {
282   tf_shared_lock l(mu_);
283   return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
284 }
285 
IsInstantiatedOnDevice(const string & device_name,FunctionLibraryRuntime::Handle handle) const286 bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
287     const string& device_name, FunctionLibraryRuntime::Handle handle) const {
288   return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
289 }
290 
291 FunctionLibraryRuntime::LocalHandle
GetHandleOnDevice(const string & device_name,FunctionLibraryRuntime::Handle handle,bool include_multi_device) const292 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
293     const string& device_name, FunctionLibraryRuntime::Handle handle,
294     bool include_multi_device) const {
295   tf_shared_lock l(mu_);
296 
297   auto miter = mdevice_data_.find(handle);
298   if (miter != mdevice_data_.end()) {
299     if (!include_multi_device) return kInvalidLocalHandle;
300 
301     const MultiDeviceFunctionData& data = *miter->second;
302     if (data.glue_.size() != 1) return kInvalidLocalHandle;
303 
304     const auto& pair = *data.glue_.begin();
305     const string& func_device_name = pair.first;
306     const ComponentFunctionData& component_data = pair.second;
307     if (func_device_name != device_name) return kInvalidLocalHandle;
308 
309     // Replace the given handle with the handle for the single component
310     // function.
311     handle = component_data.handle;
312   }
313 
314   auto iter = function_data_.find(handle);
315   if (iter == function_data_.end()) {
316     return kInvalidLocalHandle;
317   }
318   FunctionData* function_data = iter->second.get();
319   if (function_data->target_device() != device_name) {
320     return kInvalidLocalHandle;
321   }
322   return function_data->local_handle();
323 }
324 
GetDeviceName(FunctionLibraryRuntime::Handle handle) const325 string ProcessFunctionLibraryRuntime::GetDeviceName(
326     FunctionLibraryRuntime::Handle handle) const {
327   tf_shared_lock l(mu_);
328   auto iter = function_data_.find(handle);
329   CHECK(iter != function_data_.end());
330   FunctionData* function_data = iter->second.get();
331   return function_data->target_device();
332 }
333 
334 ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
IsMultiDevice(FunctionLibraryRuntime::Handle handle) const335 ProcessFunctionLibraryRuntime::IsMultiDevice(
336     FunctionLibraryRuntime::Handle handle) const {
337   tf_shared_lock l(mu_);
338   const auto& it = mdevice_data_.find(handle);
339   if (it != mdevice_data_.end()) {
340     return it->second.get();
341   }
342   return nullptr;
343 }
344 
345 namespace {
346 // Sets `group` to the first colocation group specified in `node`. If no
347 // group is specified, does not touch `group`.
GetColocationGroup(const Node * node,string * group)348 void GetColocationGroup(const Node* node, string* group) {
349   // We hoist the conversion from C-style string literal to string here,
350   // so that we can avoid the many repeated calls to strlen().
351   static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
352   const AttrValue* attr_value =
353       node->attrs().Find(kColocationAttrNameStringPiece);
354   if (attr_value != nullptr && attr_value->has_list() &&
355       attr_value->list().s_size() > 0) {
356     *group = attr_value->list().s(0);
357   }
358 }
359 
AssignedOrRequestedDeviceName(const Node & node)360 const string* AssignedOrRequestedDeviceName(const Node& node) {
361   if (node.has_assigned_device_name()) {
362     return &node.assigned_device_name();
363   }
364   return &node.requested_device();
365 }
366 
SetArgShape(const std::unordered_map<int,DtypeAndPartialTensorShape> & input_resource_dtypes_and_shapes,const std::vector<Node * > & arg_nodes)367 Status SetArgShape(
368     const std::unordered_map<int, DtypeAndPartialTensorShape>&
369         input_resource_dtypes_and_shapes,
370     const std::vector<Node*>& arg_nodes) {
371   for (Node* n : arg_nodes) {
372     int index;
373     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
374     DataType dtype;
375     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
376     if (dtype == DT_RESOURCE) {
377       auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
378       if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
379         AttrValue dtype_attr_value;
380         dtype_attr_value.mutable_list()->add_type(
381             dtype_and_shape_iter->second.dtype);
382         n->AddAttr("_handle_dtypes", dtype_attr_value);
383         TensorShapeProto shape_proto;
384         dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
385         AttrValue shape_attr_value;
386         *shape_attr_value.mutable_list()->add_shape() = shape_proto;
387         n->AddAttr("_handle_shapes", shape_attr_value);
388       }
389     }
390   }
391   return Status::OK();
392 }
393 
394 // Returns the local tensors referred by `args`.
GetLocalArgs(gtl::ArraySlice<FunctionArg> args)395 std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
396   std::vector<Tensor> tensors;
397   for (const auto& arg : args) {
398     if (arg.index() == 0) {
399       tensors.push_back(absl::get<Tensor>(arg));
400     }
401   }
402   return tensors;
403 }
404 
405 // Update the done callback to push Tensors in `tensors` into `rets`.
TensorsToFunctionRetsDoneCallback(std::vector<FunctionRet> * rets,std::vector<Tensor> * tensors,FunctionLibraryRuntime::DoneCallback done)406 FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
407     std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
408     FunctionLibraryRuntime::DoneCallback done) {
409   return [rets, tensors, done = std::move(done)](const Status& s) {
410     if (s.ok()) {
411       for (const auto& t : *tensors) {
412         rets->push_back(t);
413       }
414     }
415     delete tensors;
416     done(s);
417   };
418 }
419 
420 }  // anonymous namespace
421 
PinArgsAndRets(const std::vector<string> & input_devices,const std::vector<string> & output_devices,const DeviceSet & device_set,const std::vector<Node * > & arg_nodes,const std::vector<Node * > & ret_nodes,Device * default_device) const422 Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
423     const std::vector<string>& input_devices,
424     const std::vector<string>& output_devices, const DeviceSet& device_set,
425     const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
426     Device* default_device) const {
427   // If output_devices are not specified, we want to set the output device
428   // based on the device of the output producing node. The output producing
429   // node can be an arg node because functions can simply return their
430   // arguments. To make sure that the output producing nodes have assigned
431   // devices, we assign them to arguments first.
432   for (Node* node : arg_nodes) {
433     const AttrValue* attr_value;
434     TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
435     int64 index = attr_value->i();
436     node->set_assigned_device_name(input_devices[index]);
437   }
438 
439   for (Node* node : ret_nodes) {
440     if (output_devices.empty()) {
441       DataType dtype;
442       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
443 
444       VLOG(3) << "Trying to determine device for node " << node->name()
445               << "[T=" << DataTypeString(dtype) << "]";
446 
447       // If output_devices are empty, the node producing retval
448       // must have explicitly assigned device or a colocation constraint
449       // to a node with explicitly assigned device.
450       for (const auto& it : node->in_edges()) {
451         if (it->IsControlEdge()) continue;
452 
453         Node* src_node = it->src();
454         const string* src_device = AssignedOrRequestedDeviceName(*src_node);
455         string colocation_group = "";
456         GetColocationGroup(src_node, &colocation_group);
457         VLOG(3) << "Considering src: " << src_node->name()
458                 << " src_device: " << *src_device
459                 << " colo group: " << colocation_group;
460         while (src_device->empty() && colocation_group.empty() &&
461                src_node->IsIdentity()) {
462           // Only follows the real data input of Identity, not control edges.
463           Node* input_node;
464           TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
465           src_node = input_node;
466 
467           src_device = AssignedOrRequestedDeviceName(*src_node);
468           GetColocationGroup(src_node, &colocation_group);
469           VLOG(3) << "Considering src: " << src_node->name()
470                   << " src_device: " << *src_device
471                   << " colo group: " << colocation_group;
472         }
473 
474         // If resource is produced by a function call node, we can't trust
475         // source node device assignment, because multi-device functions can
476         // return resource placed on multiple devices. In such case we leave
477         // retval device assignment empty, and rely on placer to infer correct
478         // assignment based on actual output device.
479         const bool can_use_src_node_device =
480             !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def_, *src_node));
481 
482         if (!colocation_group.empty()) {
483           AttrValue::ListValue colo_attr;
484           colo_attr.add_s(colocation_group);
485           std::vector<string> colo_slice = {colocation_group};
486           node->AddAttr(kColocationAttrName, colo_slice);
487         } else if (!src_device->empty() && can_use_src_node_device) {
488           // src_device can be a partially specified device. Find the
489           // matching device in the device_set.
490           DeviceNameUtils::ParsedName parsed;
491           if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
492             return errors::InvalidArgument(
493                 "Failed to parse explicit device specification ", *src_device);
494           }
495           std::vector<Device*> matching_devices;
496           device_set.FindMatchingDevices(parsed, &matching_devices);
497           if (matching_devices.empty()) {
498             if (default_device != nullptr) {
499               matching_devices.push_back(default_device);
500             } else {
501               return errors::InvalidArgument(
502                   "Unable to find any devices for spec ", *src_device);
503             }
504           } else if (matching_devices.size() != 1) {
505             bool on_same_task = true;
506             for (int i = 1; i < matching_devices.size(); ++i) {
507               if (!DeviceNameUtils::IsSameAddressSpace(
508                       matching_devices.at(0)->parsed_name(),
509                       matching_devices.at(i)->parsed_name())) {
510                 on_same_task = false;
511                 break;
512               }
513             }
514             // If the src node of an output is assigned to a address space (e.g.
515             // py_func), rely on placer to assign a device to the output.
516             if (on_same_task) {
517               continue;
518             }
519             // Compare with default_device if it has a narrower scope matching
520             // requested device.
521             int colocated_on_default_device = 0;
522             for (int i = 0; i < matching_devices.size(); ++i) {
523               if (DeviceNameUtils::IsSameAddressSpace(
524                       default_device->parsed_name(),
525                       matching_devices.at(i)->parsed_name())) {
526                 colocated_on_default_device++;
527               }
528             }
529             // Continue to raise error if multiple colocated devices are
530             // found.
531             if (colocated_on_default_device == 1) {
532               continue;
533             }
534 
535             // Convert a vector of devices to a string.
536             // Using absl::StrJoin did not work in Android builds.
537             string devices = "[";
538             for (Device* device : matching_devices) {
539               devices.append(device->name());
540               devices.append(", ");
541             }
542             if (devices.size() > 2) {
543               devices.resize(devices.size() - 2);
544             }
545             devices.append("]");
546 
547             return errors::InvalidArgument(
548                 *src_device,
549                 "When FunctionLibraryRuntime::Options.output_devices are "
550                 "not specified for a multi-device function, the device "
551                 "specification on the output node must match exactly one "
552                 "device. Matched devices are ",
553                 devices);
554           }
555           VLOG(3) << "Setting output device to " << matching_devices[0]->name()
556                   << " for node " << SummarizeNode(*node);
557           node->set_assigned_device_name(matching_devices[0]->name());
558         } else if (!src_device->empty() && !can_use_src_node_device) {
559           VLOG(3) << "Did not set device for a resource output node "
560                   << SummarizeNode(*node);
561         }
562       }
563     } else {
564       const AttrValue* attr_value;
565       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
566       int64 index = attr_value->i();
567       // output_devices size is checked in InstantiateMultiDevice
568       DCHECK_GT(output_devices.size(), index);
569       VLOG(3) << "Setting output device to " << output_devices[index]
570               << " for return at index " << index;
571       node->set_assigned_device_name(output_devices[index]);
572     }
573   }
574   return Status::OK();
575 }
576 
577 namespace {
578 
ValidateNoListArguments(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const char * arg_type,const string & function_name)579 Status ValidateNoListArguments(
580     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
581     const string& function_name) {
582   for (const OpDef::ArgDef& arg : args) {
583     if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
584       return errors::InvalidArgument(
585           "Function ", function_name, " has an ", arg_type, " named \"",
586           arg.name(),
587           "\" that is a list of tensors."
588           " Multi-device functions support only single-tensor inputs "
589           " and outputs");
590     }
591   }
592   return Status::OK();
593 }
594 
ValidateMultiDeviceOptions(const FunctionDef & fdef,const FunctionLibraryRuntime::InstantiateOptions & options)595 Status ValidateMultiDeviceOptions(
596     const FunctionDef& fdef,
597     const FunctionLibraryRuntime::InstantiateOptions& options) {
598   const OpDef& signature = fdef.signature();
599   // Multi-device functions currently do not support list inputs or outputs.
600   TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
601                                              signature.name()));
602   TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
603                                              signature.name()));
604   if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
605       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
606     return errors::Unimplemented(
607         "Function '", signature.name(), "' has `",
608         FunctionLibraryDefinition::kIntsOnDeviceAttr,
609         "` attribute set. This attribute is not currently supported by "
610         "multi-device functions.");
611   }
612   if (options.input_devices.size() != signature.input_arg_size()) {
613     return errors::InvalidArgument(
614         "InstantiateOptions.input_devices must have the same length "
615         "as the number of arguments: input_devices length = ",
616         options.input_devices.size(),
617         " number of arguments = ", signature.input_arg_size());
618   }
619   if (!options.output_devices.empty() &&
620       options.output_devices.size() != signature.output_arg_size()) {
621     return errors::InvalidArgument(
622         "InstantiateOptions.output_devices must either be empty or have the "
623         "same length as the number of arguments: output_devices length = ",
624         options.output_devices.size(),
625         " number of arguments = ", signature.output_arg_size());
626   }
627   return Status::OK();
628 }
629 
630 }  // anonymous namespace
631 
GetGraphAndArgRets(const string & function_name,AttrSlice attrs,const FunctionDef * fdef,const FunctionLibraryDefinition * lib_def,std::unique_ptr<Graph> * graph,std::vector<Node * > * arg_nodes,std::vector<Node * > * ret_nodes,std::vector<string> * ret_node_names,DataTypeVector * ret_types,std::vector<string> * control_ret_node_names)632 Status GetGraphAndArgRets(
633     const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
634     const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
635     std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
636     std::vector<string>* ret_node_names, DataTypeVector* ret_types,
637     std::vector<string>* control_ret_node_names) {
638   std::unique_ptr<FunctionBody> fbody;
639   // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
640   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
641   if (!fbody) {
642     LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
643     return errors::Internal("Failed to construct FunctionBody for ",
644                             function_name);
645   }
646   *graph = std::unique_ptr<Graph>(fbody->graph);
647   arg_nodes->reserve(fbody->arg_nodes.size());
648   std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(),
649             std::back_inserter(*arg_nodes));
650   ret_nodes->reserve(fbody->ret_nodes.size());
651   std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(),
652             std::back_inserter(*ret_nodes));
653   fbody->graph = nullptr;
654   ret_node_names->reserve(fbody->ret_nodes.size());
655   for (const Node* node : fbody->ret_nodes) {
656     ret_node_names->push_back(node->name());
657   }
658   for (const auto& ret_type : fbody->ret_types) {
659     ret_types->push_back(ret_type);
660   }
661   control_ret_node_names->reserve(fbody->control_ret_nodes.size());
662   for (const Node* node : fbody->control_ret_nodes) {
663     control_ret_node_names->push_back(node->name());
664   }
665   return Status::OK();
666 }
667 
// Instantiates `function_name` as a multi-device function and stores the
// resulting handle in `*handle`.
//
// If a function with the same canonical key (name + attrs + options) is
// already in `table_`, its handle is returned and its instantiation counter
// is incremented. Otherwise the function body is converted to a graph, its
// args/rets are pinned to the requested devices, the graph is run through the
// optimization passes and the Placer, partitioned into one subgraph per
// device, and every subgraph is turned into a component function that is
// instantiated either on a local FunctionLibraryRuntime or remotely.
//
// Returns an error if the function cannot be found in the library, if the
// options are invalid, or if any stage of the pipeline (other than
// `options.optimize_graph_fn`, whose failure is only logged) or any component
// instantiation fails.
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::Handle* handle) {
  // Check if this function has already been instantiated.
  const string& function_key = Canonicalize(function_name, attrs, options);

  {
    mutex_lock l(mu_);
    const auto& it = table_.find(function_key);
    if (it != table_.end()) {
      // Cache hit: reuse the handle and record one more instantiation.
      *handle = it->second;
      ++mdevice_data_[*handle]->instantiation_counter_;
      return Status::OK();
    }
  }

  VLOG(1) << "Instantiating MultiDevice function \"" << function_name
          << "\" on default device \"" << options.target << "\"";
  if (VLOG_IS_ON(3)) {
    int index = 0;
    VLOG(3) << "Requested input devices:";
    for (const string& device : options.input_devices) {
      VLOG(3) << "    [input " << index++ << "] " << device;
    }
    index = 0;
    VLOG(3) << "Requested output devices:";
    for (const string& device : options.output_devices) {
      VLOG(3) << "    [output " << index++ << "] " << device;
    }
  }

  // A library supplied through the options takes precedence over the
  // process-wide one.
  const FunctionLibraryDefinition* lib_def =
      options.lib_def == nullptr ? lib_def_ : options.lib_def;

  const FunctionDef* fdef = lib_def->Find(function_name);
  if (fdef == nullptr) {
    return errors::InvalidArgument("Failed to find function \"", function_name,
                                   "\" in function library: ", lib_def);
  }

  TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));

  std::unique_ptr<Graph> graph;
  std::vector<Node*> arg_nodes, ret_nodes;
  std::vector<string> ret_node_names;
  DataTypeVector ret_types;
  std::vector<string> control_ret_node_names;

  TF_RETURN_IF_ERROR(GetGraphAndArgRets(
      function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
      &ret_node_names, &ret_types, &control_ret_node_names));

  // Report the graph as built from the FunctionDef, before any rewriting.
  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectRawGraph(def);
  }

  Device* default_device = nullptr;
  if (options.default_device_to_target && !options.target.empty()) {
    // Make the `target` device the default device if nothing else is hard
    // coded. This allows the same function definition to be specialized to
    // different devices depending on the `PartitionedCallOp` device.
    FunctionLibraryRuntime* flr = GetFLR(options.target);
    if (flr == nullptr) {
      return errors::InvalidArgument(
          "Cannot instantiate multi-device function with target device ",
          options.target);
    }
    default_device = flr->device();
  }
  const std::shared_ptr<DeviceSet> dev_set = device_set();

  TF_RETURN_IF_ERROR(
      SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
  // The default device is only forwarded to PinArgsAndRets when soft placement
  // is allowed.
  TF_RETURN_IF_ERROR(PinArgsAndRets(
      options.input_devices, options.output_devices, *dev_set, arg_nodes,
      ret_nodes,
      options.config_proto.allow_soft_placement() ? default_device : nullptr));

  // NOTE: `ret_node_names.size()` is read here, before `ret_node_names` is
  // moved into `options.optimize_graph_fn` further down.
  auto data = absl::make_unique<MultiDeviceFunctionData>(
      function_name, function_key, ret_node_names.size(),
      lib_def->ReachableDefinitions(*fdef), std::move(ret_types));

  // Do not run function/graph optimization passes for component functions,
  // since they have already processed the main function.
  const bool should_run_optimization_passes = !options.is_component_function;
  if (!should_run_optimization_passes) {
    VLOG(1) << "Skipping function/graph optimization passes when instantiating "
               "component function "
            << function_name;
  }

  // Mapping from a function body node name to the control output name.
  std::unordered_map<string, string> node_name_to_control_ret;

  bool control_rets_updated = false;
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
        *dev_set, options.config_proto, &graph, &data->lib_def_,
        &control_ret_node_names, &control_rets_updated));
  }

  if (control_rets_updated) {
    // Function graph pass may have resulted in different nodes/node names for
    // control rets.
    for (const auto& control_ret : control_ret_node_names) {
      node_name_to_control_ret.emplace(control_ret, control_ret);
    }
  } else {
    // Otherwise use the FunctionDef's control_ret map, keyed by its mapped
    // value (used here as the node name).
    for (const auto& control_ret : fdef->control_ret()) {
      node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
    }
  }

  GraphOptimizationPassOptions optimization_options;
  // TODO(iga): Thread other relevant options from SessionOptions.
  SessionOptions session_options;
  session_options.env = env_;
  session_options.config = options.config_proto;
  optimization_options.session_options = &session_options;
  optimization_options.graph = &graph;
  optimization_options.flib_def = &data->lib_def_;
  optimization_options.device_set = dev_set.get();
  optimization_options.is_function_graph = true;

  DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
  }

  // TODO(b/124993244): Smartly merge options in nested defuns, and raise
  // exceptions/warnings in case where nested function call options are ignored.
  DumpGraph("Before calling Placer", graph.get());
  Placer placer(graph.get(), function_name, optimization_options.flib_def,
                dev_set.get(), default_device,
                options.config_proto.allow_soft_placement(),
                options.config_proto.log_device_placement());
  TF_RETURN_IF_ERROR(placer.Run());

  DumpGraph("Before running POST_PLACEMENT passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
  }

  Device* cpu_device;
  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));

  if (options.optimize_graph_fn) {
    DumpGraph("Before running graph optimization fn", graph.get());
    // `ret_node_names` and `control_ret_node_names` are moved into the
    // callback here and must not be read after this point.
    Status status = options.optimize_graph_fn(
        std::move(ret_node_names), std::move(control_ret_node_names),
        &data->lib_def_, *dev_set, cpu_device, &graph);
    // A failure of the caller-provided optimization is deliberately
    // non-fatal: it is logged and instantiation continues.
    if (!status.ok()) {
      LOG(WARNING) << "Ignoring multi-device function optimization failure: "
                   << status.ToString();
    }
    DumpGraph("After optimization", graph.get());
  }

  DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
  }

  // Expand the nodes assigned to a CompositeDevice before graph partition to
  // avoid generating a subgraph on a virtual device for execution.
  // This transformation should happen as late as possible, so that as many
  // graph optimization passes as possible (e.g. PRE_PLACEMENT, PLACER,
  // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) run on the smaller, unexpanded
  // graph.
  TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
      options.composite_devices, graph.get()));

  // Report the fully optimized (but not yet partitioned) graph.
  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectOptimizedGraph(def);
  }

  VLOG(4) << "Main function graph to be partitioned:";
  VLOG(4) << DebugString(graph->ToGraphDefDebug());

  // Split the graph into subgraphs keyed by a string that is used below as the
  // name of the device the subgraph targets. `graph` is moved from here.
  std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(*dev_set, std::move(graph), &subgraphs));

  for (const auto& pair : subgraphs) {
    DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
                              pair.first, ")"),
              pair.second.get());
  }
  // The single-graph fields are cleared and the passes are pointed at the
  // partitioned subgraphs instead.
  optimization_options.graph = nullptr;
  optimization_options.device_set = nullptr;
  optimization_options.partition_graphs = &subgraphs;
  // Normally POST_PARTITIONING passes are run by distributed workers.
  // Distributed workers are currently not supported in this code path, so we
  // run the passes here.
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
  }
  for (const auto& pair : subgraphs) {
    const auto* optimized_subgraph = pair.second.get();
    DumpGraph(
        strings::StrCat("After all optimization passes (", pair.first, ")"),
        optimized_subgraph);
    if (VLOG_IS_ON(1)) {
      // The subgraph's address is part of the dump name.
      DumpGraphDefToFile(
          strings::StrCat("pflr_after_all_optimization_passes_",
                          reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
                          pair.first),
          optimized_subgraph->ToGraphDefDebug());
    }
  }

  // Report every partitioned subgraph.
  if (options.graph_collector != nullptr) {
    for (const auto& pair : subgraphs) {
      GraphDef def;
      pair.second->ToGraphDef(&def);
      *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
      options.graph_collector->CollectPartitionedGraph(def);
    }
  }

  // We must preserve control returns in each of the function components,
  // otherwise after function inlining we might prune side-effectful nodes.
  const auto control_ret =
      [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
    const auto it = node_name_to_control_ret.find(n->name());
    return it != node_name_to_control_ret.end()
               ? absl::make_optional<string>(it->second)
               : absl::nullopt;
  };

  // Index of the current subgraph; selects its slot in `instantiate_status`.
  int i = 0;
  // Component function names are derived from `function_name` plus a random
  // 64-bit value, so that one function does not reuse the partition functions
  // instantiated by another function.
  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
  FunctionNameGenerator name_generator(
      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
  auto subgraph_size = subgraphs.size();
  // One status slot per subgraph; each instantiation task writes only its own.
  gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
  // Counts outstanding component instantiations; every exit path of a task
  // below decrements it exactly once.
  BlockingCounter counter(static_cast<int>(subgraph_size));
  auto runner = [this, subgraph_size](std::function<void()> fn) {
    // NOTE: Only use thread pool to instantiate sub-function when there are
    // more than 8 sub-functions. We want to avoid cost of switching thread when
    // there are only a few sub-functions.
    if (default_thread_pool_ != nullptr && subgraph_size > 8) {
      default_thread_pool_->Schedule(fn);
    } else {
      fn();
    }
  };
  for (const auto& pair : subgraphs) {
    Status* status = &instantiate_status[i];
    string unique_name = name_generator.GetName();
    // Create this device's entry in `glue_` on the current thread, before the
    // task is handed to `runner`.
    ComponentFunctionData* comp_data = &data->glue_[pair.first];
    // The task captures several locals by reference (`pair`, `control_ret`,
    // `options`, `counter`, `data`); `counter.Wait()` below is reached before
    // this function returns on the paths after this loop.
    runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
            &control_ret, &options, status, &counter, &data] {
      const string& target = pair.first;

      const string& device_type =
          dev_set->FindDeviceByName(target)->device_type();
      Graph* subgraph = pair.second.get();

      // Update the subgraph's arg/retval metadata and record the index and
      // allocator-attribute vectors in `comp_data`.
      status->Update(UpdateArgAndRetvalMetadata(
          subgraph, device_type, &comp_data->arg_indices,
          &comp_data->ret_indices, &comp_data->arg_alloc_attrs,
          &comp_data->ret_alloc_attrs));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      // Convert the subgraph into a FunctionDef named `unique_name`,
      // preserving control returns via `control_ret`.
      FunctionDef shard;
      status->Update(
          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      status->Update(data_lib_def->AddFunctionDef(shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      // Options for the component function: only a subset of the caller's
      // options is forwarded.
      FunctionLibraryRuntime::InstantiateOptions opts;
      opts.executor_type = options.executor_type;
      opts.target = target;
      opts.lib_def = data_lib_def;
      opts.create_kernels_eagerly = options.create_kernels_eagerly;
      opts.state_handle = options.state_handle;
      auto attrs = AttrSlice(&shard.attr());
      VLOG(1) << "Start instantiating component function " << unique_name
              << " on device " << target;
      VLOG(4) << DebugString(shard);

      // Heap-allocated so that it stays valid until `done` runs (the remote
      // path invokes `done` through a callback); `done` deletes it.
      auto* component_handle = new FunctionLibraryRuntime::Handle;
      auto done = [this, status, unique_name, comp_data, component_handle,
                   &data, &counter](const Status& s) {
        status->Update(s);

        VLOG(1) << "Finished instantiating component function " << unique_name
                << " with handle " << *component_handle << " status: " << s;
        if (status->ok()) {
          {
            // A single cross-process component marks the whole multi-device
            // function as cross-process.
            mutex_lock l(mu_);
            if (function_data_[*component_handle]->is_cross_process()) {
              data->is_cross_process_ = true;
            }
          }
          comp_data->handle = *component_handle;
        }
        delete component_handle;
        counter.DecrementCount();
      };

      FunctionLibraryRuntime* flr = GetFLR(opts.target);
      if (flr != nullptr) {
        // Initialize local function synchronously.
        Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
        done(s);
      } else {
        opts.ret_indices = comp_data->ret_indices;
        // Initialize remote function asynchronously.
        InstantiateRemote(unique_name, attrs, opts, component_handle, done);
      }
    });
    i += 1;
  }
  // Block until every component instantiation has reported its status.
  counter.Wait();
  StatusGroup group;
  for (auto& status : instantiate_status) {
    group.Update(status);
  }
  TF_RETURN_IF_ERROR(group.as_summary_status());

  // Register the new multi-device function under `function_key`.
  *handle = AddMultiDeviceHandle(std::move(data), function_key);
  VLOG(2) << "Instantiated MultiDevice function \"" << function_name
          << "\" with handle " << *handle;
  return Status::OK();
}
1015 
GetOutputDevices(FunctionLibraryRuntime::Handle handle,std::vector<Device * > * output_devices) const1016 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
1017     FunctionLibraryRuntime::Handle handle,
1018     std::vector<Device*>* output_devices) const {
1019   MultiDeviceFunctionData* data = IsMultiDevice(handle);
1020   if (data == nullptr) {
1021     return errors::InvalidArgument(
1022         "Failed for find multi-device function handle ", handle);
1023   }
1024 
1025   for (const auto& pair : data->glue_) {
1026     const ComponentFunctionData& comp_data = pair.second;
1027     DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
1028     if (comp_data.ret_indices.empty()) {
1029       continue;
1030     }
1031 
1032     const string& target = pair.first;
1033     FunctionLibraryRuntime* target_flr = GetFLR(target);
1034     Device* target_device = nullptr;
1035     Device* host = nullptr;
1036     if (target_flr == nullptr) {
1037       if (!data->has_remote_outputs) {
1038         data->has_remote_outputs = true;
1039       }
1040       target_device = device_set()->FindDeviceByName(target);
1041       string remote_host;
1042       TF_RETURN_IF_ERROR(
1043           DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
1044       host = device_set()->FindDeviceByName(remote_host);
1045     } else {
1046       target_device = target_flr->device();
1047     }
1048     output_devices->resize(data->num_outputs_);
1049     for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
1050       int ret_index = comp_data.ret_indices[j];
1051       if (data->ret_types_[ret_index] == DT_RESOURCE) {
1052         (*output_devices)[ret_index] = target_device;
1053       } else {
1054         (*output_devices)[ret_index] =
1055             comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
1056       }
1057     }
1058   }
1059 
1060   return Status::OK();
1061 }
1062 
RunMultiDevice(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,std::vector<FunctionRet> * rets,std::vector<std::unique_ptr<CleanUpItem>> * cleanup_items,FunctionLibraryRuntime::DoneCallback done,std::function<Status (const ComponentFunctionData & comp_data,InternalArgs * args)> get_component_args) const1063 void ProcessFunctionLibraryRuntime::RunMultiDevice(
1064     const FunctionLibraryRuntime::Options& opts,
1065     FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
1066     std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
1067     FunctionLibraryRuntime::DoneCallback done,
1068     std::function<Status(const ComponentFunctionData& comp_data,
1069                          InternalArgs* args)>
1070         get_component_args) const {
1071   if (opts.create_rendezvous) {
1072     // FLR->Run() is the default entry point. It checks for cancellation,
1073     // creates rendezvous, etc.
1074     // Letting create_rendezvous through will do the wrong thing - each
1075     // component function will get a separate rendezvous created by its FLR.
1076     done(
1077         errors::Internal("Cannot call ProcessFunctionLibraryRuntime::Run with "
1078                          "create_rendezvous=true. Please run the function "
1079                          "using FunctionLibraryRuntime::Run"));
1080     return;
1081   }
1082 
1083   const MultiDeviceFunctionData* data = IsMultiDevice(handle);
1084   if (data == nullptr) {
1085     done(errors::NotFound("Multi-device function handle ", handle,
1086                           "not found. Was the function instantiated?"));
1087     return;
1088   }
1089 
1090   VLOG(1) << "Running multi-device function " << data->function_name_;
1091   VLOG(4) << "    with " << opts.DebugString();
1092 
1093   if (data->glue_.empty()) {
1094     // Trivial case where the function body is empty.
1095     done(Status::OK());
1096     return;
1097   }
1098 
1099   // Check whether we have the right rendezvous.
1100   if (opts.rendezvous && data->is_cross_process_ &&
1101       !opts.rendezvous->is_cross_process()) {
1102     done(errors::InvalidArgument(
1103         "Running a cross process function ", data->function_name_,
1104         " without an appropriate cross process Rendezvous."));
1105     return;
1106   }
1107 
1108   // A locally created cancellation manager, used only when the caller does not
1109   // provide one in argument.
1110   std::shared_ptr<CancellationManager> local_cm;
1111   CancellationManager* cm = opts.cancellation_manager;
1112   if (cm == nullptr) {
1113     local_cm = std::make_shared<CancellationManager>();
1114     cm = local_cm.get();
1115   }
1116 
1117   auto* refcounted_done = new ReffedStatusCallback(std::move(done));
1118   for (int i = 0; i < data->glue_.size(); ++i) {
1119     refcounted_done->Ref();
1120   }
1121 
1122   FunctionLibraryRuntime::Options opts_copy = opts;
1123   for (const auto& pair : data->glue_) {
1124     const string& target = pair.first;
1125     const ComponentFunctionData& comp_data = pair.second;
1126     FunctionLibraryRuntime::Handle handle = pair.second.handle;
1127 
1128     opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
1129     opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
1130     opts_copy.cancellation_manager = cm;
1131 
1132     InternalArgs comp_args;
1133     Status s = get_component_args(comp_data, &comp_args);
1134     if (!s.ok()) {
1135       VLOG(2) << "Failed to get component function arguments: " << s;
1136       refcounted_done->UpdateStatus(s);
1137       refcounted_done->Unref();
1138       cm->StartCancel();
1139       continue;
1140     }
1141     std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
1142     rets->resize(data->num_outputs_);
1143 
1144     auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
1145                                   cm, local_cm, data, handle,
1146                                   target](const Status& status) {
1147       if (!status.ok()) {
1148         VLOG(2) << "Component function execution on target " << target
1149                 << " from " << data->function_name_ << " with handle " << handle
1150                 << " failed: " << status;
1151         const string function_and_msg = strings::StrCat(
1152             errors::FormatFunctionForError(data->function_name_), " ",
1153             status.error_message());
1154         refcounted_done->UpdateStatus(Status(status.code(), function_and_msg));
1155         // Cancel the execution of other component functions.
1156         cm->StartCancel();
1157       } else {
1158         VLOG(2) << "Component function execution on target " << target
1159                 << " from " << data->function_name_ << " with handle " << handle
1160                 << " succeeded.";
1161         for (int i = 0; i < comp_rets->size(); ++i) {
1162           (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
1163         }
1164       }
1165       delete comp_rets;
1166       // refcounted_done is thread-safe
1167       refcounted_done->Unref();
1168     };
1169 
1170     FunctionLibraryRuntime* flr = GetFLR(target);
1171     if (flr != nullptr) {
1172       opts_copy.remote_execution = false;
1173       // When target device has private thread pool, use the target device
1174       // runner
1175       thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
1176       opts_copy.runner = (pool == nullptr) ? opts_copy.runner : flr->runner();
1177 
1178       VLOG(1) << "Running component function on device " << target << " from "
1179               << data->function_name_ << " with handle " << handle;
1180       VLOG(4) << "    with " << opts_copy.DebugString();
1181 
1182       std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
1183       flr->Run(
1184           opts_copy, handle, GetLocalArgs(comp_args.args), comp_tensor_rets,
1185           TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
1186                                             std::move(component_fn_callback)));
1187     } else {
1188       opts_copy.remote_execution = true;
1189 
1190       VLOG(1) << "Running component function on device " << target << " from "
1191               << data->function_name_ << " with handle " << handle;
1192       VLOG(4) << "    with " << opts_copy.DebugString();
1193 
1194       RunInternal(opts_copy, handle, comp_args.args, comp_rets, cleanup_items,
1195                   std::move(component_fn_callback));
1196     }
1197   }
1198   refcounted_done->Unref();
1199 }
1200 
Instantiate(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle)1201 Status ProcessFunctionLibraryRuntime::Instantiate(
1202     const string& function_name, AttrSlice attrs,
1203     const FunctionLibraryRuntime::InstantiateOptions& options,
1204     FunctionLibraryRuntime::Handle* handle) {
1205   if (options.is_multi_device_function) {
1206     return InstantiateMultiDevice(function_name, attrs, options, handle);
1207   }
1208 
1209   *handle = kInvalidHandle;
1210   FunctionLibraryRuntime* flr = GetFLR(options.target);
1211   if (flr != nullptr) {
1212     return flr->Instantiate(function_name, attrs, options, handle);
1213   }
1214 
1215   Status status;
1216   Notification notification;
1217   InstantiateRemote(function_name, attrs, options, handle,
1218                     [&status, &notification](const Status& s) {
1219                       status = s;
1220                       notification.Notify();
1221                     });
1222   notification.WaitForNotification();
1223   return status;
1224 }
1225 
IsCrossProcess(FunctionLibraryRuntime::Handle handle,bool * is_cross_process) const1226 Status ProcessFunctionLibraryRuntime::IsCrossProcess(
1227     FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const {
1228   tf_shared_lock l(mu_);
1229   const auto& mdevice_it = mdevice_data_.find(handle);
1230   if (mdevice_it != mdevice_data_.end()) {
1231     *is_cross_process = mdevice_it->second->is_cross_process_;
1232     return Status::OK();
1233   }
1234   const auto& it = function_data_.find(handle);
1235   if (it != function_data_.end()) {
1236     *is_cross_process = it->second->is_cross_process();
1237     return Status::OK();
1238   }
1239   return errors::InvalidArgument("Handle ", handle, " not found.");
1240 }
1241 
InstantiateRemote(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle,FunctionLibraryRuntime::DoneCallback done)1242 void ProcessFunctionLibraryRuntime::InstantiateRemote(
1243     const string& function_name, AttrSlice attrs,
1244     const FunctionLibraryRuntime::InstantiateOptions& options,
1245     FunctionLibraryRuntime::Handle* handle,
1246     FunctionLibraryRuntime::DoneCallback done) {
1247   if (parent_ == nullptr) {
1248     done(errors::Internal(
1249         "Currently don't support instantiating functions on device: ",
1250         options.target));
1251     return;
1252   }
1253   auto target = options.target;
1254   VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target;
1255   string function_key = Canonicalize(function_name, attrs, options);
1256   FunctionData* f;
1257   {
1258     mutex_lock l(mu_);
1259     FunctionLibraryRuntime::Handle h =
1260         gtl::FindWithDefault(table_, function_key, kInvalidHandle);
1261     if (h == kInvalidHandle || function_data_.count(h) == 0) {
1262       h = AddHandleLocked(function_key, target, kInvalidHandle);
1263     }
1264     f = function_data_[h].get();
1265     *handle = h;
1266   }
1267   f->DistributedInit(
1268       parent_, function_name,
1269       options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options,
1270       [this, function_name, target, handle, done](const Status& s) {
1271         VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
1272                 << " on: " << target << " with handle: " << *handle
1273                 << " (this: " << this << ")";
1274         done(s);
1275       });
1276 }
1277 
RemoveHandle(FunctionLibraryRuntime::Handle handle)1278 Status ProcessFunctionLibraryRuntime::RemoveHandle(
1279     FunctionLibraryRuntime::Handle handle) {
1280   mutex_lock l(mu_);
1281   table_.erase(function_data_[handle]->function_key());
1282   function_data_.erase(handle);
1283   return Status::OK();
1284 }
1285 
ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle)1286 Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
1287     FunctionLibraryRuntime::Handle handle) {
1288   std::unique_ptr<MultiDeviceFunctionData> mdata;
1289   {
1290     mutex_lock l(mu_);
1291     auto it = mdevice_data_.find(handle);
1292     --it->second->instantiation_counter_;
1293     if (it->second->instantiation_counter_ != 0) {
1294       return Status::OK();
1295     }
1296     mdata = std::move(it->second);
1297     table_.erase(mdata->function_key_);
1298     mdevice_data_.erase(it);
1299   }
1300 
1301   // If we are here we are releasing the last instantiation of `handle`.
1302   // Release all component function handles.
1303   Status overall_status;
1304   for (const auto& it : mdata->glue_) {
1305     const string& device = it.first;
1306     FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
1307     FunctionLibraryRuntime* flr = GetFLR(device);
1308     if (flr == nullptr) {
1309       // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
1310       // parent is not null.
1311       if (parent_ != nullptr) {
1312         return errors::Unimplemented(
1313             "Releasing a multi-device component handle on a remote device is "
1314             "not yet implemented.");
1315       }
1316       return errors::InvalidArgument(
1317           "Failed to find FunctionLibraryRuntime for device ", device,
1318           " when releasing multi-device function handle ", handle);
1319     }
1320     Status status = flr->ReleaseHandle(flr_handle);
1321     if (!status.ok()) {
1322       overall_status = status;
1323     }
1324   }
1325 
1326   return overall_status;
1327 }
1328 
ReleaseHandle(FunctionLibraryRuntime::Handle handle)1329 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
1330     FunctionLibraryRuntime::Handle handle) {
1331   // Return directly if all function handles has already been released.
1332   if (flr_map_ == nullptr) return Status::OK();
1333 
1334   if (IsMultiDevice(handle)) {
1335     return ReleaseMultiDeviceHandle(handle);
1336   }
1337 
1338   FunctionLibraryRuntime* flr = nullptr;
1339   string target_device;
1340   {
1341     mutex_lock l(mu_);
1342     CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
1343     target_device = function_data_[handle]->target_device();
1344   }
1345   flr = GetFLR(target_device);
1346   if (flr != nullptr) {
1347     return flr->ReleaseHandle(handle);
1348   }
1349   return errors::InvalidArgument("Handle not found: ", handle);
1350 }
1351 
1352 FunctionLibraryRuntime::DoneCallback
ApplyCleanUpToDoneCallback(std::vector<std::unique_ptr<CleanUpItem>> * items,FunctionLibraryRuntime::DoneCallback done,const int64 step_id,const Rendezvous * created_rendezvous) const1353 ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
1354     std::vector<std::unique_ptr<CleanUpItem>>* items,
1355     FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
1356     const Rendezvous* created_rendezvous) const {
1357   return
1358       [this, items, done = std::move(done), step_id,
1359        created_rendezvous](const Status& status) {
1360         if (created_rendezvous) {
1361           DCHECK(rendezvous_factory_);
1362           created_rendezvous->Unref();
1363           Status s = rendezvous_factory_.CleanUp(step_id);
1364           if (!s.ok()) {
1365             LOG(ERROR) << s;
1366           }
1367         }
1368         auto* local_status = new Status(status);
1369         CleanUp(items, [local_status, done](const Status& cleanup_status) {
1370           local_status->Update(cleanup_status);
1371           done(*local_status);
1372           delete local_status;
1373         });
1374         delete items;
1375       };
1376 }
1377 
CreateRendezvous(const FunctionLibraryRuntime::Options & opts,Rendezvous ** created_rendezvous) const1378 Status ProcessFunctionLibraryRuntime::CreateRendezvous(
1379     const FunctionLibraryRuntime::Options& opts,
1380     Rendezvous** created_rendezvous) const {
1381   if (rendezvous_factory_) {
1382     return rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
1383   } else {
1384     return errors::FailedPrecondition(
1385         "The caller does not provide a rendezvous and "
1386         "ProcessFunctionLibraryRuntime was created without a rendezvous "
1387         "factory.");
1388   }
1389 }
1390 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done) const1391 void ProcessFunctionLibraryRuntime::Run(
1392     const FunctionLibraryRuntime::Options& opts,
1393     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
1394     std::vector<Tensor>* rets,
1395     FunctionLibraryRuntime::DoneCallback done) const {
1396   FunctionLibraryRuntime::Options new_opts = opts;
1397   Rendezvous* created_rendezvous = nullptr;
1398   if (!opts.rendezvous) {
1399     Status s = CreateRendezvous(opts, &created_rendezvous);
1400     if (!s.ok()) {
1401       done(s);
1402       return;
1403     }
1404     new_opts.rendezvous = created_rendezvous;
1405     new_opts.create_rendezvous = false;
1406   }
1407 
1408   auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
1409   done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
1410                                     new_opts.step_id, created_rendezvous);
1411   std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
1412   done = [rets, function_rets, done = std::move(done)](const Status& s) {
1413     Status status = s;
1414     if (status.ok()) {
1415       for (const auto& ret : *function_rets) {
1416         if (ret.index() == 0) {
1417           rets->push_back(absl::get<Tensor>(ret));
1418         } else {
1419           status.Update(errors::Internal(
1420               "Expect a Tensor as a function output but got a TensorShape."));
1421           break;
1422         }
1423       }
1424     }
1425     delete function_rets;
1426     done(status);
1427   };
1428   bool multi_device;
1429   {
1430     tf_shared_lock l(mu_);
1431     multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
1432   }
1433   if (multi_device) {
1434     auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1435                                       InternalArgs* comp_args) -> Status {
1436       // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
1437       for (const auto& it : comp_data.arg_indices) {
1438         if (it.index >= args.size()) {
1439           return errors::InvalidArgument(
1440               "index ", it.index, " is out of range [0, ", args.size(), ")");
1441         }
1442         if (it.sub_index >= 0) {
1443           const Tensor& t = args[it.index];
1444           if (t.dtype() != DT_RESOURCE) {
1445             return errors::InvalidArgument("Got unexpected sub_index ",
1446                                            it.sub_index, " for argument ",
1447                                            it.index);
1448           }
1449           const auto& handles = t.flat<ResourceHandle>();
1450           if (it.sub_index >= handles.size()) {
1451             return errors::InvalidArgument(
1452                 "Sub_index ", it.sub_index, "is out of range [0,",
1453                 handles.size(), ") for argument ", it.index);
1454           }
1455           comp_args->args.push_back(Tensor(handles(it.sub_index)));
1456         } else {
1457           comp_args->args.push_back(args[it.index]);
1458         }
1459       }
1460       return Status::OK();
1461     };
1462     return RunMultiDevice(new_opts, handle, function_rets, cleanup_items,
1463                           std::move(done), std::move(get_component_args));
1464   }
1465   std::vector<FunctionArg> local_args;
1466   for (const auto& tensor : args) {
1467     local_args.push_back(tensor);
1468   }
1469   RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
1470               std::move(done));
1471 }
1472 
RunInternal(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<FunctionArg> args,std::vector<FunctionRet> * rets,std::vector<std::unique_ptr<CleanUpItem>> * cleanup_items,FunctionLibraryRuntime::DoneCallback done) const1473 void ProcessFunctionLibraryRuntime::RunInternal(
1474     const FunctionLibraryRuntime::Options& opts,
1475     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
1476     std::vector<FunctionRet>* rets,
1477     std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
1478     FunctionLibraryRuntime::DoneCallback done) const {
1479   FunctionLibraryRuntime* flr = nullptr;
1480   string target_device;
1481   FunctionLibraryRuntime::LocalHandle local_handle;
1482   {
1483     tf_shared_lock l(mu_);
1484     auto iter = function_data_.find(handle);
1485     if (iter == function_data_.end()) {
1486       done(errors::NotFound("Handle: ", handle, " not found."));
1487       return;
1488     }
1489     FunctionData* function_data = iter->second.get();
1490     target_device = function_data->target_device();
1491     local_handle = function_data->local_handle();
1492   }
1493 
1494   if (!opts.remote_execution) {
1495     done(
1496         errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
1497                                 "only be called for multi-device functions or "
1498                                 "for remote execution."));
1499     return;
1500   }
1501 
1502   flr = GetFLR(target_device);
1503   if (flr != nullptr) {
1504     auto rendezvous = opts.rendezvous;
1505     string source_device = opts.source_device;
1506     DeviceContext* device_context;
1507     Status s = GetDeviceContext(source_device, &device_context);
1508     if (!s.ok()) {
1509       done(s);
1510       return;
1511     }
1512     int64 src_incarnation, target_incarnation;
1513     s = GetDeviceIncarnation(source_device, &src_incarnation);
1514     s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
1515     if (!s.ok()) {
1516       done(s);
1517       return;
1518     }
1519 
1520     std::vector<Tensor> local_args = GetLocalArgs(args);
1521 
1522     // Send the args over to the target device.
1523     s = SendTensors(source_device, target_device, "arg_", src_incarnation,
1524                     local_args, device_context, opts.args_alloc_attrs,
1525                     rendezvous);
1526     if (!s.ok()) {
1527       done(s);
1528       return;
1529     }
1530     const std::vector<AllocatorAttributes>& rets_alloc_attrs =
1531         opts.rets_alloc_attrs;
1532     std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
1533     flr->Run(opts, handle, local_args, remote_rets,
1534              [source_device, target_device, target_incarnation, rendezvous,
1535               device_context, rets_alloc_attrs, remote_rets, rets,
1536               done = std::move(done)](const Status& status) mutable {
1537                if (!status.ok()) {
1538                  delete remote_rets;
1539                  done(status);
1540                  return;
1541                }
1542                int64 num_returns = remote_rets->size();
1543                delete remote_rets;
1544                // Now receive the return values from the target.
1545                std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
1546                ReceiveTensorsAsync(target_device, source_device, "ret_",
1547                                    target_incarnation, num_returns,
1548                                    device_context, rets_alloc_attrs, rendezvous,
1549                                    recv_tensors,
1550                                    TensorsToFunctionRetsDoneCallback(
1551                                        rets, recv_tensors, std::move(done)));
1552              });
1553     return;
1554   }
1555   if (parent_ != nullptr) {
1556     auto cleanup_item = absl::make_unique<CleanUpItem>();
1557     cleanup_item->device = target_device;
1558     cleanup_item->step_id = opts.step_id;
1559     cleanup_item->local_handle = local_handle;
1560     cleanup_items->emplace_back(std::move(cleanup_item));
1561     parent_->Run(opts, local_handle, args, rets, std::move(done));
1562     return;
1563   }
1564   done(errors::Internal("Could not find device"));
1565 }
1566 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,CallFrameInterface * frame,FunctionLibraryRuntime::DoneCallback done) const1567 void ProcessFunctionLibraryRuntime::Run(
1568     const FunctionLibraryRuntime::Options& opts,
1569     FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
1570     FunctionLibraryRuntime::DoneCallback done) const {
1571   std::vector<Tensor> args;
1572   args.reserve(frame->num_args());
1573   for (size_t i = 0; i < frame->num_args(); ++i) {
1574     const Tensor* arg;
1575     Status s = frame->GetArg(i, &arg);
1576     args.emplace_back(*arg);
1577     if (!s.ok()) {
1578       done(s);
1579     }
1580   }
1581   std::vector<Tensor>* rets = new std::vector<Tensor>;
1582   rets->reserve(frame->num_retvals());
1583 
1584   Run(opts, handle, args, rets,
1585 
1586       [frame, rets, done = std::move(done)](const Status& status) {
1587         std::unique_ptr<std::vector<Tensor>> rets_releaser(rets);
1588 
1589         if (!status.ok()) {
1590           done(status);
1591           return;
1592         }
1593 
1594         if (rets->size() != frame->num_retvals()) {
1595           done(errors::Internal(
1596               "Number of return values from function (", rets->size(),
1597               ") did not match expected number of return values (",
1598               frame->num_retvals(), ")."));
1599           return;
1600         }
1601 
1602         for (size_t i = 0; i < frame->num_retvals(); ++i) {
1603           Status s = frame->SetRetval(i, (*rets)[i]);
1604           if (!s.ok()) {
1605             done(s);
1606             return;
1607           }
1608         }
1609         done(Status::OK());
1610       });
1611 }
1612 
RunSync(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets) const1613 Status ProcessFunctionLibraryRuntime::RunSync(
1614     const FunctionLibraryRuntime::Options& opts,
1615     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
1616     std::vector<Tensor>* rets) const {
1617   Notification n;
1618   Status s;
1619   Run(opts, handle, args, rets, [&n, &s](const Status& status) {
1620     s.Update(status);
1621     n.Notify();
1622   });
1623   n.WaitForNotification();
1624   return s;
1625 }
1626 
RunSync(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,CallFrameInterface * frame) const1627 Status ProcessFunctionLibraryRuntime::RunSync(
1628     const FunctionLibraryRuntime::Options& opts,
1629     FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
1630   Notification n;
1631   Status s;
1632   Run(opts, handle, frame, [&n, &s](const Status& status) {
1633     s.Update(status);
1634     n.Notify();
1635   });
1636   n.WaitForNotification();
1637   return s;
1638 }
1639 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,const FunctionArgsInterface & args,std::vector<FunctionRet> * rets,FunctionLibraryRuntime::DoneCallback done) const1640 void ProcessFunctionLibraryRuntime::Run(
1641     const FunctionLibraryRuntime::Options& opts,
1642     FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
1643     std::vector<FunctionRet>* rets,
1644     FunctionLibraryRuntime::DoneCallback done) const {
1645   bool has_remote_outputs = false;
1646   const MultiDeviceFunctionData* data = IsMultiDevice(handle);
1647   if (data != nullptr) {
1648     has_remote_outputs = data->has_remote_outputs;
1649   }
1650   if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
1651     const std::vector<Tensor> local_inputs = args.GetLocalTensors();
1652     std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
1653     return Run(
1654         opts, handle, local_inputs, tensor_rets,
1655         TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done)));
1656   }
1657 
1658   FunctionLibraryRuntime::Options new_opts = opts;
1659   Rendezvous* created_rendezvous = nullptr;
1660   if (!opts.rendezvous) {
1661     Status s = CreateRendezvous(opts, &created_rendezvous);
1662     if (!s.ok()) {
1663       done(s);
1664       return;
1665     }
1666     new_opts.rendezvous = created_rendezvous;
1667     new_opts.create_rendezvous = false;
1668   }
1669 
1670 #if defined(IS_MOBILE_PLATFORM)
1671   done(errors::Unimplemented(
1672       "Remote inputs are not available on mobile devices."));
1673   return;
1674 #else   // !IS_MOBILE_PLATFORM
1675   auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
1676   done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
1677                                     created_rendezvous);
1678 
1679   auto get_component_args = [&args](const ComponentFunctionData& comp_data,
1680                                     InternalArgs* comp_args) -> Status {
1681     for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
1682       const FunctionArgIndex index = comp_data.arg_indices.at(i);
1683       Tensor tensor;
1684       if (args.GetLocalArg(index, &tensor).ok()) {
1685         comp_args->args.push_back(std::move(tensor));
1686       } else {
1687         eager::RemoteTensorHandle remote_handle;
1688         TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
1689         comp_args->remote_args.emplace_back(
1690             absl::make_unique<eager::RemoteTensorHandle>(
1691                 std::move(remote_handle)));
1692         comp_args->args.push_back(comp_args->remote_args.back().get());
1693       }
1694     }
1695     return Status::OK();
1696   };
1697   return RunMultiDevice(new_opts, handle, rets, cleanup_items, std::move(done),
1698                         std::move(get_component_args));
1699 #endif  // !IS_MOBILE_PLATFORM
1700 }
1701 
CleanUp(std::vector<std::unique_ptr<CleanUpItem>> * items,FunctionLibraryRuntime::DoneCallback done) const1702 void ProcessFunctionLibraryRuntime::CleanUp(
1703     std::vector<std::unique_ptr<CleanUpItem>>* items,
1704     FunctionLibraryRuntime::DoneCallback done) const {
1705   auto* refcounted_done = new ReffedStatusCallback(std::move(done));
1706   for (auto& item : *items) {
1707     refcounted_done->Ref();
1708     auto* flr = GetFLR(item->device);
1709     if (flr != nullptr) {
1710       // TODO(fishx): cleanup state for local execution.
1711       refcounted_done->UpdateStatus(
1712           errors::Internal("Cleanup items shouldn't contain local item."));
1713       refcounted_done->Unref();
1714     } else if (parent_ != nullptr) {
1715       parent_->CleanUp(item->step_id, item->local_handle,
1716                        [refcounted_done](const Status& status) {
1717                          if (!status.ok()) {
1718                            refcounted_done->UpdateStatus(status);
1719                          }
1720                          // refcounted_done is thread-safe
1721                          refcounted_done->Unref();
1722                        });
1723     } else {
1724       refcounted_done->UpdateStatus(
1725           errors::Internal("Could not find device in cleanup."));
1726       refcounted_done->Unref();
1727     }
1728   }
1729   refcounted_done->Unref();
1730 }
1731 
Clone(Env * env,int graph_def_version,const OptimizerOptions & optimizer_options,std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,bool skip_flib_def) const1732 Status ProcessFunctionLibraryRuntime::Clone(
1733     Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
1734     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1735     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1736     bool skip_flib_def) const {
1737   if (skip_flib_def) {
1738     *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(
1739         lib_def_->default_registry(), FunctionDefLibrary{});
1740   } else {
1741     *out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_);
1742   }
1743   *out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
1744       device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
1745       out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
1746       session_metadata_, rendezvous_factory_);
1747   {
1748     tf_shared_lock l(mu_);
1749     for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
1750   }
1751   return Status::OK();
1752 }
1753 
1754 }  // namespace tensorflow
1755