/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
17 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM
37 
38 namespace tensorflow {
39 
40 class FunctionArgsInterface {
41  public:
~FunctionArgsInterface()42   virtual ~FunctionArgsInterface() {}
43 
44   virtual bool HasRemoteOrPackedInputs() const = 0;
45 
46   virtual Status GetLocalArg(const FunctionArgIndex& index,
47                              Tensor* val) const = 0;
48 
49   virtual std::vector<Tensor> GetLocalTensors() const = 0;
50 
51 #if !defined(IS_MOBILE_PLATFORM)
GetRemoteArg(const FunctionArgIndex & index,eager::RemoteTensorHandle * val)52   virtual Status GetRemoteArg(const FunctionArgIndex& index,
53                               eager::RemoteTensorHandle* val) const {
54     return errors::Unimplemented(
55         "Serializing a remote argument is not implemented.");
56   }
57 #endif  // IS_MOBILE_PLATFORM
58 };
59 
60 // A class that stores all the FunctionLibraryRuntime objects, one per device.
61 class ProcessFunctionLibraryRuntime {
62  public:
63   // Creates FunctionLibraryRuntime objects for each device in the provided
64   // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
65   // (if provided) outlive this object.
66   ProcessFunctionLibraryRuntime(
67       const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
68       int graph_def_version, const FunctionLibraryDefinition* lib_def,
69       const OptimizerOptions& optimizer_options,
70       thread::ThreadPool* thread_pool = nullptr,
71       DistributedFunctionLibraryRuntime* parent = nullptr,
72       const SessionMetadata* session_metadata = nullptr,
73       Rendezvous::Factory rendezvous_factory = Rendezvous::Factory());
74 
~ProcessFunctionLibraryRuntime()75   ~ProcessFunctionLibraryRuntime() {
76     // Deleting the FunctionLibraryRuntime map will delete the function handles
77     // registered in it, which may call ReleaseHandle in this class again to
78     // release their sub-function. These circular calls may cause segfault
79     // since the flr_map_ may have already been deleted. Explicitly releasing
80     // flr_map_ here and checking flr_map_ in ReleaseHandle to avoid this.
81     flr_map_.reset();
82   }
83 
84   // Sends `tensors_to_send` from `source_device` to `target_device` using
85   // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
86   // Rendezvous. `device_context` should be the DeviceContext of the device
87   // doing the sending. `alloc_attrs` should either be empty or be the size of
88   // `tensors_to_send` and indicates how the input tensors are allocated. Method
89   // takes references on each of the `tensors_to_send`. Method doesn't block.
90   static Status SendTensors(const string& source_device,
91                             const string& target_device,
92                             const string& key_prefix, int64 src_incarnation,
93                             gtl::ArraySlice<Tensor> tensors_to_send,
94                             DeviceContext* device_context,
95                             const std::vector<AllocatorAttributes>& alloc_attrs,
96                             RendezvousInterface* rendezvous);
97 
98   // Receives `received_tensors` from `target_device` (originally sent from
99   // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
100   // keys to be retrieved. `device_context` should be for the device receiving
101   // the tensors. `alloc_attrs` indicates how to allocate the received
102   // tensors and should either be empty or `num_tensors` in size. Method doesn't
103   // block and calls `done` when `num_tensors` are fetched.
104   static void ReceiveTensorsAsync(
105       const string& source_device, const string& target_device,
106       const string& key_prefix, int64 src_incarnation, int64 num_tensors,
107       DeviceContext* device_context,
108       const std::vector<AllocatorAttributes>& alloc_attrs,
109       RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
110       StatusCallback done);
111 
112   static const char kDefaultFLRDevice[];
113   // Returns the FunctionLibraryRuntime for the corresponding device_name.
114   FunctionLibraryRuntime* GetFLR(const string& device_name) const;
115 
116   // Returns the return types for the function identified by handle `h`.
117   Status GetRetTypes(FunctionLibraryRuntime::Handle h,
118                      DataTypeVector* ret_types);
119 
120   // Returns the device incarnation for the given device_name.
121   Status GetDeviceIncarnation(const string& device_name,
122                               int64* incarnation) const;
123 
124   // For a given canonicalized key signature of the function instantiated
125   // on device `device_name` and a `local_handle`, creates a handle and returns
126   // that value. Uses core/common_runtime/framework/function.h::Canonicalize
127   // to canonicalize the function signature.
128   FunctionLibraryRuntime::Handle AddHandle(
129       const string& function_key, const string& device_name,
130       FunctionLibraryRuntime::LocalHandle local_handle);
131 
132   // Returns a handle if found for the given key, else returns kInvalidHandle.
133   FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const;
134 
135   // For the given handle instantiated on device `device_name` returns the local
136   // index of instantiation of that function. If the function was not
137   // instantiated on `device_name` or the function is multi-device,
138   // returns kInvalidLocalHandle.
139   //
140   // If `include_multi_device` is true and `handle` is a multi-device function
141   // with a single component that is placed on `device_name`, then this method
142   // will return the local handle for that component.
143   FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
144       const string& device_name, FunctionLibraryRuntime::Handle handle,
145       bool include_multi_device = false) const;
146 
147   // Fills `output_devices` with the devices on which the results will
148   // be produced. If some output is produced on CPU, the corresponding Device*
149   // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device*
150   // is set to the device backing the resource.
151   // REQUIRES: `handle` identifies a multi-device function.
152   Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
153                           std::vector<Device*>* output_devices) const;
154 
155   // Returns true if function with handle `handle` was instantiated on device
156   // `device_name`. Returns false for multi-device functions.
157   bool IsInstantiatedOnDevice(const string& device_name,
158                               FunctionLibraryRuntime::Handle handle) const;
159 
160   // Instantiates the function. See framework/function.h for more details.
161   // Allows for function_name to be instantiated on different devices
162   // as specified in attrs.
163   Status Instantiate(const string& function_name, AttrSlice attrs,
164                      const FunctionLibraryRuntime::InstantiateOptions& options,
165                      FunctionLibraryRuntime::Handle* handle);
166 
167   // Returns whether the function represented by the given handle needs to
168   // execute cross process.
169   Status IsCrossProcess(FunctionLibraryRuntime::Handle handle,
170                         bool* is_cross_process) const;
171 
172   // Delegates to the local FLR that owns state corresponding to `handle` and
173   // tells it to release it. If the `handle` isn't needed at all, the local FLR
174   // might call RemoveHandle on this to get rid of the state owned by the Proc
175   // FLR.
176   // For multi-device functions, calls ReleaseHandle on local FLRs for each
177   // component function that is part of this multi-device function.
178   // Each local FLR might call RemoveHandle on this.
179   Status ReleaseHandle(FunctionLibraryRuntime::Handle handle);
180 
181   // Runs the function with given `handle`. Function could have been
182   // instantiated on any device. More details in framework/function.h
183   void Run(const FunctionLibraryRuntime::Options& opts,
184            FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
185            std::vector<Tensor>* rets,
186            FunctionLibraryRuntime::DoneCallback done) const;
187   void Run(const FunctionLibraryRuntime::Options& opts,
188            FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
189            FunctionLibraryRuntime::DoneCallback done) const;
190 
191   void Run(const FunctionLibraryRuntime::Options& opts,
192            FunctionLibraryRuntime::Handle handle,
193            const FunctionArgsInterface& args, std::vector<FunctionRet>* rets,
194            FunctionLibraryRuntime::DoneCallback done) const;
195 
196   Status RunSync(const FunctionLibraryRuntime::Options& opts,
197                  FunctionLibraryRuntime::Handle handle,
198                  gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const;
199   Status RunSync(const FunctionLibraryRuntime::Options& opts,
200                  FunctionLibraryRuntime::Handle handle,
201                  CallFrameInterface* frame) const;
202 
device_mgr()203   const DeviceMgr* device_mgr() { return device_mgr_; }
204 
device_set()205   const std::shared_ptr<DeviceSet> device_set() const {
206     tf_shared_lock l(mu_);
207     return device_set_;
208   }
209 
210   // Initialize the set of local and remote devices and corresponding flr for op
211   // device selection.
212   void InitializeDeviceAndFlr();
213 
config()214   const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
215 
GetFunctionLibraryDefinition()216   const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const {
217     return lib_def_;
218   }
219 
220   // Add a CompositeDevice to `device_set_`
AddCompositeDevice(CompositeDevice * d)221   void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
222     mutex_lock l(mu_);
223     device_set_->AddDevice(d);
224     composite_devices_.push_back(d);
225   }
226 
227  protected:
228   friend class FunctionLibraryRuntimeImpl;
229 
230   struct InternalArgs {
231     std::vector<FunctionArg> args;
232 #if !defined(IS_MOBILE_PLATFORM)
233     // Holds the RemoteTensorHandles referred by args.
234     std::vector<std::unique_ptr<eager::RemoteTensorHandle>> remote_args;
235 #endif  // IS_MOBILE_PLATFORM
236   };
237 
238   // Structure to keep track of how a component function (a single-device
239   // piece of a multi-device function) fits into the multi-device function.
240   struct ComponentFunctionData {
241     // The handle for the instantiated component function.
242     FunctionLibraryRuntime::Handle handle;
243     // arg_indices.size() is the number of arguments to the component function.
244     // The i-th argument of the component function comes from the
245     // `arg_indices[i]`-th argument of the multi-device function.
246     std::vector<FunctionArgIndex> arg_indices;
247     // ret_indices.size() is the number of return values of the component
248     // function.  The i-th return value of the component function goes to the
249     // `ret_indices[i]`-th return value of the multi-device function.
250     std::vector<int> ret_indices;
251     // arg_alloc_attrs[i] are the allocator attributes of the i-th argument to
252     // the component function.
253     std::vector<AllocatorAttributes> arg_alloc_attrs;
254     // ret_alloc_attrs[i] are the allocator attributes of the i-th return value
255     // of the component function.
256     std::vector<AllocatorAttributes> ret_alloc_attrs;
257   };
258 
259   // Data structure holding information for a single instantiated multi-device
260   // function.
261   // The fields are filled in during instantiation. Once the object is
262   // added to mdevice_data_, all fields are constant.
263   struct MultiDeviceFunctionData {
MultiDeviceFunctionDataMultiDeviceFunctionData264     MultiDeviceFunctionData(const string& function_name,
265                             const string& function_key, int num_outputs,
266                             FunctionLibraryDefinition&& lib_def,
267                             DataTypeVector ret_types)
268         : function_name_(function_name),
269           function_key_(function_key),
270           instantiation_counter_(1),
271           lib_def_(std::move(lib_def)),
272           num_outputs_(num_outputs),
273           ret_types_(std::move(ret_types)),
274           is_cross_process_(false),
275           has_remote_outputs(false) {}
276 
277     const string function_name_;
278     const string function_key_;
279     uint64 instantiation_counter_;
280     // A library that contains definitions of component functions and their
281     // transitive dependencies.
282     FunctionLibraryDefinition lib_def_;
283     // Stored here to resize the output tensor vector when function is run.
284     const int num_outputs_;
285     DataTypeVector ret_types_;
286 
287     // Indicates whether this function needs to execute cross process.
288     bool is_cross_process_;
289     // Indicates whether this function has remote outputs.
290     bool has_remote_outputs;
291 
292     // Maps the device name to the information about the component function
293     // be run on this device.
294     std::unordered_map<string, ComponentFunctionData> glue_;
295   };
296 
297   struct CleanUpItem {
298     string device;
299     uint64 step_id;
300     FunctionLibraryRuntime::Handle local_handle;
301   };
302 
303   // If `handle` represents a multi-device function, returns the multi-device
304   // data associated with `handle`. Else, nullptr.
305   MultiDeviceFunctionData* IsMultiDevice(
306       FunctionLibraryRuntime::Handle handle) const;
307 
308   void RunMultiDevice(
309       const FunctionLibraryRuntime::Options& opts,
310       FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
311       std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
312       FunctionLibraryRuntime::DoneCallback done,
313       std::function<Status(const ComponentFunctionData& comp_data,
314                            InternalArgs* args)>
315           get_component_args) const;
316 
317   Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts,
318                           Rendezvous** created_rendezvous) const;
319 
320   FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
321       std::vector<std::unique_ptr<CleanUpItem>>* items,
322       FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
323       const Rendezvous* rendezvous) const;
324 
325   DistributedFunctionLibraryRuntime* const parent_;
326 
327  private:
328   FunctionLibraryRuntime::Handle AddHandleLocked(
329       const string& function_key, const string& device_name,
330       FunctionLibraryRuntime::LocalHandle local_handle)
331       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
332 
333   // For a given device_name, returns a DeviceContext for copying
334   // tensors to/from the device.
335   Status GetDeviceContext(const string& device_name,
336                           DeviceContext** device_context) const;
337 
338   // Looks up the information for the given `handle` and returns the name
339   // of the device where the function is registered.
340   string GetDeviceName(FunctionLibraryRuntime::Handle handle) const;
341 
342   // Removes handle from the state owned by this object.
343   Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
344 
345   // Clones ProcessFunctionLibraryRuntime and FunctionLibraryDefinition
346   // (transferring ownership of both to the caller). Note that the
347   // ProcessFunctionLibraryRuntime borrows a pointer to the
348   // FunctionLibraryDefinition and so the FunctionLibraryDefinition should
349   // outlive the ProcessFunctionLibraryRuntime.
350   //
351   // The `skip_flib_def` argument controls whether the method should clone the
352   // FunctionLibraryDefinition (default behavior) or return an empty function
353   // library. The latter is used by tf.data, which manages
354   // FunctionLibraryDefinitions for its functions independently (and passes
355   // these into the FunctionLibraryRuntime through an overlay), to avoid linear
356   // runtime w.r.t. to number of functions in the current function library.
357   Status Clone(Env* env, int graph_def_version,
358                const OptimizerOptions& optimizer_options,
359                std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
360                std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
361                bool skip_flib_def = false) const;
362 
363   Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle);
364 
365   Status InstantiateMultiDevice(
366       const string& function_name, AttrSlice attrs,
367       const FunctionLibraryRuntime::InstantiateOptions& options,
368       FunctionLibraryRuntime::Handle* handle);
369 
370   void InstantiateRemote(
371       const string& function_name, AttrSlice attrs,
372       const FunctionLibraryRuntime::InstantiateOptions& options,
373       FunctionLibraryRuntime::Handle* handle,
374       FunctionLibraryRuntime::DoneCallback done);
375 
376   FunctionLibraryRuntime::Handle AddMultiDeviceHandle(
377       const std::unique_ptr<MultiDeviceFunctionData> data,
378       const string& function_key);
379 
380   // TODO(iga): Reword
381   // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
382   // corresponding resource lives. This ensures that the Placer assigns ops that
383   // access these resources to the appropriate devices.
384   Status PinArgsAndRets(const std::vector<string>& input_devices,
385                         const std::vector<string>& output_devices,
386                         const DeviceSet& device_set,
387                         const std::vector<Node*>& arg_nodes,
388                         const std::vector<Node*>& ret_nodes,
389                         Device* default_device) const;
390 
391   void RunInternal(const FunctionLibraryRuntime::Options& opts,
392                    FunctionLibraryRuntime::Handle handle,
393                    gtl::ArraySlice<FunctionArg> args,
394                    std::vector<FunctionRet>* rets,
395                    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
396                    FunctionLibraryRuntime::DoneCallback done) const;
397 
398   void CleanUp(std::vector<std::unique_ptr<CleanUpItem>>* items,
399                FunctionLibraryRuntime::DoneCallback done) const;
400 
401   // Data structure holding information for a single instantiated remote
402   // (to be executed on `target_device`) function.
403   class FunctionData {
404    public:
FunctionData(const string & target_device,FunctionLibraryRuntime::LocalHandle local_handle,const string & function_key)405     FunctionData(const string& target_device,
406                  FunctionLibraryRuntime::LocalHandle local_handle,
407                  const string& function_key)
408         : target_device_(target_device),
409           local_handle_(local_handle),
410           function_key_(function_key) {}
411 
target_device()412     const string& target_device() { return target_device_; }
function_key()413     const string& function_key() { return function_key_; }
414 
local_handle()415     FunctionLibraryRuntime::LocalHandle local_handle() {
416       mutex_lock l(mu_);
417       return local_handle_;
418     }
419 
420     // Initializes the FunctionData object by potentially making an Initialize
421     // call to the DistributedFunctionLibraryRuntime.
422     void DistributedInit(
423         DistributedFunctionLibraryRuntime* parent, const string& function_name,
424         const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
425         const FunctionLibraryRuntime::InstantiateOptions& options,
426         FunctionLibraryRuntime::DoneCallback done);
427 
is_cross_process()428     bool is_cross_process() {
429       mutex_lock l(mu_);
430       return is_cross_process_;
431     }
432 
433    private:
434     mutex mu_;
435 
436     const string target_device_;
437     FunctionLibraryRuntime::LocalHandle local_handle_ TF_GUARDED_BY(mu_);
438     const string function_key_;
439     bool is_cross_process_ TF_GUARDED_BY(mu_) = false;
440     bool init_started_ TF_GUARDED_BY(mu_) = false;
441     Status init_result_ TF_GUARDED_BY(mu_);
442     Notification init_done_;
443   };
444 
445   mutable mutex mu_;
446 
447   Env* const env_;
448   const absl::optional<const ConfigProto> config_;
449   const DeviceMgr* const device_mgr_;
450   const FunctionLibraryDefinition* lib_def_;
451   thread::ThreadPool* default_thread_pool_;
452 
453   // Cluster update can reinitialize the device_set_ due to remote device
454   // changes. At the same time, InstantiateMultiDevice can use the cached
455   // devices to instantiate multi-worker functions. Function instantiation would
456   // fail if it spans the changed remote devices.
457   std::shared_ptr<DeviceSet> device_set_ TF_GUARDED_BY(mu_);
458 
459   // Composite devices owned by a EagerContext.
460   std::vector<CompositeDevice*> composite_devices_ TF_GUARDED_BY(mu_);
461 
462   // Holds all the function instantiations. Maps function_keys to handles.
463   std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
464       TF_GUARDED_BY(mu_);
465 
466   // Function data for instantiated remote functions.
467   std::unordered_map<FunctionLibraryRuntime::Handle,
468                      std::unique_ptr<FunctionData>>
469       function_data_ TF_GUARDED_BY(mu_);
470 
471   // Function data for instantiated multi-device functions.
472   std::unordered_map<FunctionLibraryRuntime::Handle,
473                      std::unique_ptr<MultiDeviceFunctionData>>
474       mdevice_data_ TF_GUARDED_BY(mu_);
475 
476   std::unique_ptr<
477       std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>>>
478       flr_map_;
479   int next_handle_ TF_GUARDED_BY(mu_);
480   const SessionMetadata* const session_metadata_;
481   const Rendezvous::Factory rendezvous_factory_;
482 
483   const OptimizerOptions optimizer_options_;
484   const int graph_def_version_;
485 };
486 
487 }  // namespace tensorflow
488 
489 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
490