1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
17 
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
26 
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/c/eager/immediate_execution_context.h"
30 #include "tensorflow/core/common_runtime/composite_device.h"
31 #include "tensorflow/core/common_runtime/device_factory.h"
32 #include "tensorflow/core/common_runtime/device_mgr.h"
33 #include "tensorflow/core/common_runtime/eager/custom_device.h"
34 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
35 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
40 #include "tensorflow/core/example/example.pb.h"
41 #include "tensorflow/core/framework/collective.h"
42 #include "tensorflow/core/framework/function.h"
43 #include "tensorflow/core/framework/log_memory.h"
44 #include "tensorflow/core/framework/rendezvous.h"
45 #include "tensorflow/core/framework/tensor.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow/core/lib/core/threadpool.h"
49 #include "tensorflow/core/lib/gtl/flatmap.h"
50 #include "tensorflow/core/lib/gtl/flatset.h"
51 #include "tensorflow/core/lib/gtl/inlined_vector.h"
52 #include "tensorflow/core/lib/gtl/map_util.h"
53 #include "tensorflow/core/platform/casts.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/fingerprint.h"
56 #include "tensorflow/core/platform/mutex.h"
57 #include "tensorflow/core/platform/platform.h"
58 #include "tensorflow/core/platform/status.h"
59 #include "tensorflow/core/platform/thread_annotations.h"
60 #include "tensorflow/core/public/session_options.h"
61 #include "tensorflow/core/public/version.h"
62 #include "tensorflow/core/util/device_name_utils.h"
63 
64 // "tensorflow/core/platform/platform.h" must be included first before using
65 // IS_MOBILE_PLATFORM.
66 #if !defined(IS_MOBILE_PLATFORM)
67 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
68 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
69 #include "tensorflow/core/distributed_runtime/server_lib.h"
70 #include "tensorflow/core/distributed_runtime/worker_cache.h"
71 #include "tensorflow/core/distributed_runtime/worker_env.h"
72 #endif  // !IS_MOBILE_PLATFORM
73 
74 namespace tensorflow {
75 
namespace eager {
// This forward declaration is needed because of a circular dependency:
// Context -> RemoteMgr -> TensorHandle -> Context.
// TODO(fishx): Remove this once we remove Context dependency in TensorHandle.
class RemoteMgr;
}  // namespace eager
82 
// Forward declarations. NOTE(review): neither name is referenced elsewhere in
// this header; they are presumably kept for files that include it and rely on
// them transitively — confirm before removing.
class TensorHandle;
class EagerOperation;
85 
86 class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
87  public:
88   static constexpr uint64 kInvalidContextId = 0;
89 
NewContextId()90   static uint64 NewContextId() {
91     uint64 context_id = random::New64();
92     while (context_id == kInvalidContextId) {
93       context_id = random::New64();
94     }
95     return context_id;
96   }
97 
98   EagerContext(const SessionOptions& opts,
99                ContextDevicePlacementPolicy default_device_placement_policy,
100                bool async, const DeviceMgr* device_mgr, bool device_mgr_owned,
101                Rendezvous* rendezvous,
102                DistributedFunctionLibraryRuntime* cluster_flr = nullptr);
103 
Release()104   void Release() override { Unref(); }
105 
106   AbstractTensorInterface* CreateInt64Scalar(int64 value) override;
107   AbstractTensorInterface* CreateUint64Scalar(uint64 value) override;
108   AbstractTensorInterface* CreateInt32Scalar(int32 value) override;
109   AbstractTensorInterface* CreateFloatScalar(float value) override;
110   AbstractTensorInterface* CreateDoubleScalar(double value) override;
111   AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override;
112   AbstractTensorInterface* CreateStringScalar(
113       tensorflow::tstring value) override;
114   AbstractTensorInterface* CreateComplex128Scalar(
115       tensorflow::complex128 value) override;
116   AbstractTensorInterface* CreateBoolScalar(bool value) override;
117 
118   AbstractTensorInterface* CreateTensor(
119       DataType dtype, absl::Span<const int64> dim_sizes) override;
120   AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
121                                         int num_dims, void* data, size_t len,
122                                         MemoryReleaser memory_releaser,
123                                         void* memory_releaser_arg) override;
124 
125   ImmediateExecutionTensorHandle* CreateLocalHandle(
126       AbstractTensorInterface* t) override;
127   // Create an abstract tensor handle from tensorflow::Tensor.
128   ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
129       tensorflow::Tensor& t, const char* d_name) override;
130   ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
131       ImmediateExecutionTensorHandle* handle, const char* device_name,
132       Status* status) override;
133   ImmediateExecutionOperation* CreateOperation() override;
134 
135   // This is a virtual helper function to convert TFRT TensorHandle to
136   // tensorflow::TensorHandle. In current runtime EagerContext, just forward
137   // the input since the input tensor handle is already a
138   // tensorflow::TensorHandle.
139   ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
140       ImmediateExecutionTensorHandle* handle) override;
141 
142   Status RegisterFunction(AbstractFunction* f) override;
143 
144   bool UsesTFRT() override;
145 
146   void ListDevices(std::vector<DeviceAttributes>* devices) override;
147 
148   // Returns the function library runtime for the given device.
func_lib(const Device * d)149   FunctionLibraryRuntime* func_lib(const Device* d) const {
150     return pflr_->GetFLR(d->name());
151   }
152 
pflr()153   ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }
154 
runner()155   std::function<void(std::function<void()>)>* runner() { return &runner_; }
156 
157   // Specify a executor for this thread.
158   void SetExecutorForThread(EagerExecutor* executor) override;
159 
prioritized_device_type_list()160   const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
161       const {
162     mutex_lock l(device_type_list_mu_);
163     return prioritized_device_type_list_;
164   }
165 
166   // Clear pending nodes in thread executors and kernel caches.
167   void ClearCachesAndThreadExecutors() override;
168   // Clear pending nodes in default executor and kernel caches.
169   void ClearCachesAndDefaultExecutor();
170 
171   // Sets the device placement policy for the current thread.
172   void SetThreadLocalDevicePlacementPolicy(
173       ContextDevicePlacementPolicy policy) override;
174 
175   // Returns the device placement policy for the current thread.
176   ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;
177 
178   // Select an appropriate device for an operation.
179   //
180   // Given the preferred device for the operation, and the node_def, finds the
181   // best suitable device for the operation in this context.
182   //
183   // The preferred device is specified as a `ParsedName` containing the elements
184   // (details) that the resulting device should match. If there are no such
185   // devices, and the context currently allows soft device placement, a suitable
186   // device not matching `preferred` will be chosen.
187   //
188   // The chosen device is stored in the `device` argument. The argument is not
189   // modified unless this method returns `Status::OK()`.
190   Status SelectDevice(DeviceNameUtils::ParsedName preferred,
191                       const NodeDef& ndef, Device** out) const;
192 
193   bool FindFunctionByName(const string& name) const;
194 
195   Status FindFunctionOpData(const string& name,
196                             const tensorflow::OpRegistrationData** op_data);
197 
198   const FunctionDef* FindFunctionDef(const string& name) const override;
199 
HostCPU()200   Device* HostCPU() const { return host_cpu_device_; }
CanonicalDevice(Device * d)201   Device* CanonicalDevice(Device* d) const {
202     return HostCPU() == d ? nullptr : d;
203   }
HostCPUParsedName()204   const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
205     return HostCPU()->parsed_name();
206   }
207 
HostCPUName()208   const string& HostCPUName() const override { return HostCPU()->name(); }
209 
GetGraphCollector()210   GraphCollector* GetGraphCollector() { return &graph_collector_; }
211 
212   EagerExecutor& Executor() override;
213 
214   // Add the given `fdef` to the local FunctionLibraryDefinition. And add an
215   // entry to the KernelAndDevice cache for it if it's not exist.
216   Status AddFunctionDef(const FunctionDef& fdef) override;
217 
218   Status AddFunctionDefWithStackTraces(
219       const FunctionDef& fdef, const StackTracesMap& stack_traces) override;
220 
221   // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add
222   // it to the local FunctionLibraryDefinition as well, but no need to add it
223   // to the KernelAndDevice cache since they won't be executed as
224   // KernelAndDevices.
225   Status AddFunctionDef(const FunctionDef& fdef,
226                         const FunctionDefLibrary& library,
227                         bool add_to_local_only = false,
228                         const StackTracesMap& stack_traces = {});
229 
230   const FunctionDef* GetFunctionDef(const string& function_name);
231 
232   std::vector<string> ListFunctionNames() override;
233 
234   Status RemoveFunction(const string& func) override;
235 
236   // Wait for pending nodes to be finished in local executors (including context
237   // default executor and thread executors) and executors on remote workers.
238   // Return combined status of remote executors. If there are multiple errors,
239   // the Status code will be the same as the first remote executor that has
240   // errors, and the error message will be combined from all executors.
241   Status SyncExecutors();
242 
AsyncWait()243   Status AsyncWait() override { return SyncExecutors(); }
244 
245   core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
246 
247   void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
248 
LogDevicePlacement()249   bool LogDevicePlacement() const { return log_device_placement_; }
SetLogDevicePlacement(bool enable)250   void SetLogDevicePlacement(bool enable) override {
251     log_device_placement_ = enable;
252   }
253 
254   // When tensor transfer across functions/eager executions using send/recv ops
255   // are required, `reuse_rendezvous_for_functions_` can be set to true so that
256   // function executions and eager executions use the same rendezvous instance,
257   // instead of creating new instance per function calls.
SetReuseRendezvousForFunctions(bool reuse_rendezvous_for_functions)258   void SetReuseRendezvousForFunctions(bool reuse_rendezvous_for_functions) {
259     reuse_rendezvous_for_functions_ = reuse_rendezvous_for_functions;
260   }
GetReuseRendezvousForFunctions()261   bool GetReuseRendezvousForFunctions() const {
262     return reuse_rendezvous_for_functions_;
263   }
264 
AllowSoftPlacement()265   bool AllowSoftPlacement() const { return allow_soft_placement_; }
SetAllowSoftPlacement(bool enable)266   void SetAllowSoftPlacement(bool enable) override {
267     allow_soft_placement_ = enable;
268   }
LogMemory()269   bool LogMemory() const { return log_memory_; }
270 
GetRendezvous()271   Rendezvous* GetRendezvous() const { return rendezvous_; }
272 
273   // Returns a function which maps from step_id to rendezvous. This closure
274   // respects the value of `SetReuseRendezvousForFunctions` at the time the
275   // closure was created, which allows the setting to be toggled around async op
276   // launches.
277   //
278   // The caller of the returned function owns a reference to the resulting
279   // Rendezvous.
RendezvousCreator()280   std::function<Rendezvous*(int64)> RendezvousCreator() {
281     if (reuse_rendezvous_for_functions_) {
282       return [this](int64 step_id) {
283         // Increment reference count as `rendezvous_` will be unref'ed after
284         // function execution.
285         rendezvous_->Ref();
286         return rendezvous_;
287       };
288     } else {
289       return [this](int64 step_id) { return CreateRendezvous(step_id); };
290     }
291   }
292 
collective_executor_mgr()293   CollectiveExecutorMgrInterface* collective_executor_mgr() {
294     return collective_executor_mgr_.Get();
295   }
GetCollectiveExecutorHandle()296   std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
297     return std::unique_ptr<CollectiveExecutor::Handle>(
298         new CollectiveExecutor::Handle(
299             collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
300   }
301 
local_device_mgr()302   const tensorflow::DeviceMgr* local_device_mgr() const {
303     return local_device_manager_.Get();
304   }
remote_device_mgr()305   const tensorflow::DynamicDeviceMgr* remote_device_mgr() const {
306     return remote_device_manager_.Get();
307   }
308 
GetOwnedRemoteDeviceMgr()309   tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() {
310     return remote_device_manager_.GetOwned();
311   }
312 
ListLocalTfDevices()313   std::vector<Device*> ListLocalTfDevices() override {
314     return local_device_mgr()->ListDevices();
315   }
316 
317   // TODO(apassos) clean up RunMetadata storage.
MetadataMu()318   mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
319   bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
320   void SetShouldStoreGraphs(bool value) override;
RunMetadataProto()321   RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
322     return run_metadata_.get();
323   }
324   std::unique_ptr<RunMetadata> ExportRunMetadata() override
325       TF_LOCKS_EXCLUDED(metadata_mu_);
326 
327   void StartStep() override;
328   void EndStep() override;
329   ScopedStepContainer* StepContainer();
330 
FuncLibDef()331   FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
332 
333 #if !defined(IS_MOBILE_PLATFORM)
334   // Assign the EagerClient pointer to `client` based on the given device / task
335   // name, and increment the refcount of the client. The reference ownership is
336   // transferred to the caller, and the unref should automatically happen when
337   // destructing the RefCountPtr object at the caller's side.
338   // `client` must not be initialized or holding a reference of another object
339   // before calling this method.
340   Status GetClient(Device* device,
341                    core::RefCountPtr<eager::EagerClient>* client);
342   Status GetClient(const DeviceNameUtils::ParsedName& device_name,
343                    core::RefCountPtr<eager::EagerClient>* client);
344   Status GetClient(const string& remote_task,
345                    core::RefCountPtr<eager::EagerClient>* client);
346 
347   uint64 GetContextId() const;
348   uint64 GetContextViewId() const;
349   void IncrementContextViewId();
350 
351   // TODO(nareshmodi): Encapsulate remote state into a separate
352   // class/struct.
353   //
354   // Enables the eager context to communicate with remote devices. When
355   // initializing with this method, this context will be the primary context,
356   // which will kill all its remote contexts in shutdown.
357   //
358   // - server: A ServerInterface that exports the tensorflow.WorkerService.
359   // Note that this class expects the server to already have been started.
360   // - remote_eager_workers: A cache from which we can get "EagerClient"s to
361   // communicate with remote eager services.
362   // - remote_device_mgr: A DeviceMgr* which contains all remote devices
363   // (should contain no local devices).
364   // - remote_contexts: A vector containing task names.
365   Status InitializeRemoteMaster(
366       std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
367       std::shared_ptr<WorkerSession> worker_session,
368       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
369       std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
370       const std::vector<string>& remote_contexts, uint64 context_id,
371       Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs,
372       DistributedFunctionLibraryRuntime* cluster_flr,
373       std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
374           remote_mgr);
375 
376   // Update an existing master context with a new set of remote workers (i.e., a
377   // new "view" of cluster membership. Similar to InitializeRemoteMaster but
378   // this will keep the current context_id and increment a context_view_id, will
379   // keep the current resource manager so that resources from the previous view
380   // can still be accessed, and will automatically register existing functions
381   // if there are newly added hosts.
382   Status UpdateRemoteMaster(
383       uint64 context_id,
384       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
385       const std::vector<string>& add_remote_contexts,
386       const std::vector<string>& remove_remote_contexts);
387 
388   // Similar with InitializeRemoteMaster but this context will not kill remote
389   // contexts in shutdown.
390   Status InitializeRemoteWorker(
391       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
392       DynamicDeviceMgr* remote_device_mgr,
393       const std::vector<string>& remote_contexts, uint64 context_id,
394       uint64 context_view_id,
395       std::function<Rendezvous*(const int64)> rendezvous_creator,
396       DistributedFunctionLibraryRuntime* cluster_flr,
397       std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
398           remote_mgr,
399       std::function<void()> resource_deallocator);
400 
401   // Similar with InitializeRemoteWorker but will reuse existing context and
402   // increment context_view_id.
403   Status UpdateRemoteWorker(
404       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
405       const std::vector<string>& remote_contexts, uint64 context_id);
406 
407   Status StoreCollectiveOpsServer(
408       std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr,
409       CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
410 
411   // For the specified remote worker, preprocess and set its device filters.
412   Status SetRemoteDeviceFilters(const string& remote_worker,
413                                 const std::vector<string>& device_filters);
414 
415   // For the specified remote worker, apply the stored device filters to the
416   // list of device attributes following these rules:
417   // (1) if the remote worker does not have device filters, all devices are
418   //     visible to the worker;
419   // (2) if the device is on the remote worker, then it is visible;
420   // (3) if the device matches at least one device filter, then it is visible.
421   // The result is saved as a boolean vector of the same length (i.e.,
422   // filtered_device_mask) indicating whether each of the devices is visible to
423   // the remote worker.
424   void FilterDevicesForRemoteWorkers(
425       const string& remote_worker,
426       const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
427       std::vector<bool>* filtered_device_mask);
428 
429   // TODO(fishx): Remove the custom deleter once we remove forward declaration.
430   const std::unique_ptr<eager::RemoteMgr,
431                         std::function<void(eager::RemoteMgr*)>>&
RemoteMgr()432   RemoteMgr() {
433     return remote_mgr_;
434   }
435 
436   // If true, then tensors should be shipped across processes via the
437   // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
438   // used instead (which in-turn use WorkerService.RecvTensor RPCs).
UseSendTensorRPC()439   bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
440 
GetServer()441   tensorflow::ServerInterface* GetServer() { return server_.get(); }
442 
443   // For LLVM style RTTI.
classof(const AbstractContext * ptr)444   static bool classof(const AbstractContext* ptr) {
445     return ptr->getKind() == kEager;
446   }
447 
448   // Function to support distributed C API.
SetDistributedManager(std::unique_ptr<ImmediateExecutionDistributedManager> distributed)449   void SetDistributedManager(
450       std::unique_ptr<ImmediateExecutionDistributedManager> distributed)
451       override {
452     distributed_manager_ = std::move(distributed);
453   }
GetDistributedManager()454   ImmediateExecutionDistributedManager* GetDistributedManager() override {
455     return distributed_manager_.get();
456   }
457 #endif  // IS_MOBILE_PLATFORM
458 
459   // Closes remote eager contexts, waits for all RPCs to finish, and
460   // destroys the EagerClientCache. No RPCs can be made through this context
461   // after this method has been called.
462   // This method exists to aid a clean shutdown. It causes all RPCs to finish
463   // and remote TensorHandles to release their references to this context.
464   // To avoid deadlocks, this method must not be called on the thread
465   // processing RPCs because it makes RPCs and waits for their completion.
466   //
467   // On mobile, it just cleans the caches.
468   void WaitForAndCloseRemoteContexts();
469 
PinSmallOpsToCPU()470   bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }
471 
TFEnv()472   tensorflow::Env* TFEnv() const { return env_; }
473 
474   Status FindDeviceFromName(const char* device_name, Device** device) const;
475 
476   Status FindCompositeDeviceFromName(StringPiece device_name,
477                                      CompositeDevice** device) const;
478 
479   Status RegisterCustomDevice(const string& name,
480                               std::unique_ptr<CustomDevice> device) override;
481 
GetCustomDeviceOpHandler()482   CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
483     return custom_device_op_handler_;
484   };
485 
486   // Find or create a composite device with the given `underlying_devices` and
487   // `device_name` (if not empty).
488   Status FindOrCreateCompositeDevice(
489       const std::vector<string>& underlying_devices, const string& device_name,
490       CompositeDevice** composite_device);
491 
492   bool OnSameTask(const Device* first, const Device* second) const;
493   // Gets the CPU device on the task of device.
494   Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
495 
session_options()496   const SessionOptions& session_options() const { return opts_; }
497   void InitPrioritizedDeviceTypeList();
498 
499  private:
CreateRendezvous(int64 step_id)500   Rendezvous* CreateRendezvous(int64 step_id) const {
501     if (rendezvous_creator_ != nullptr) {
502       return rendezvous_creator_(step_id);
503     }
504 
505 #if !defined(IS_MOBILE_PLATFORM)
506     if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
507       auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id);
508       remote_r->Initialize(worker_session_.get()).IgnoreError();
509       return remote_r;
510     }
511 #endif
512 
513     if (remote_device_mgr() == nullptr) {
514       return new IntraProcessRendezvous(local_device_mgr());
515     }
516 
517     return nullptr;
518   }
519 
520   ~EagerContext() override;
521 
522   Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
523   Status RegisterExistingFunctionsOnRemoteWorkers(
524       const std::vector<string>& remote_workers);
525 
526   void ResetPFLR(const DeviceMgr* device_mgr, Env* env,
527                  const ConfigProto* config, int graph_def_version,
528                  const FunctionLibraryDefinition* lib_def,
529                  const OptimizerOptions& optimizer_options,
530                  thread::ThreadPool* thread_pool = nullptr,
531                  DistributedFunctionLibraryRuntime* cluster_flr = nullptr);
532 
533   void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);
534 
535   void ClearResourceContainer(const string& name);
536 
537   template <typename T>
538   struct OwnedOrUnownedHelper {
539    public:
OwnedOrUnownedHelperOwnedOrUnownedHelper540     OwnedOrUnownedHelper() {}
541     explicit OwnedOrUnownedHelper(T* object, const bool owned = false) {
542       Reset(object, owned);
543     }
544 
ResetOwnedOrUnownedHelper545     void Reset(std::unique_ptr<T> object) {
546       owned_object = std::move(object);
547       unowned_object_ptr = nullptr;
548     }
549 
550     void Reset(T* object, const bool owned = false) {
551       if (owned) {
552         owned_object.reset(object);
553         unowned_object_ptr = nullptr;
554       } else {
555         owned_object.reset(nullptr);
556         unowned_object_ptr = object;
557       }
558     }
559 
OwnedOwnedOrUnownedHelper560     bool Owned() const { return owned_object != nullptr; }
561 
GetOwnedOwnedOrUnownedHelper562     T* GetOwned() const { return owned_object.get(); }
GetOwnedOrUnownedHelper563     T* Get() const {
564       return owned_object ? owned_object.get() : unowned_object_ptr;
565     }
566 
567     std::unique_ptr<T> owned_object = nullptr;
568     T* unowned_object_ptr = nullptr;
569   };
570 
571   SessionOptions opts_;
572   const ContextDevicePlacementPolicy default_device_placement_policy_;
573 
574   // Note: we cannot use C++11 thread_local here as there is no concept of a
575   // thread-local-object-local variable in C++11.
576   mutable mutex policy_map_mu_;
577   std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
578       device_placement_policy_ TF_GUARDED_BY(policy_map_mu_);
579 
580   OwnedOrUnownedHelper<const DeviceMgr> local_device_manager_;
581   // Maintain copy of all previously created local device managers.
582   std::vector<std::unique_ptr<const DeviceMgr>> old_local_device_managers_;
583 
584   // Unowned DynamicDeviceMgr is set on remote worker to allow running
585   // multi-device function on remote worker.
586   OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
587 
588   Device* host_cpu_device_;  // Owned by device_manager
589   mutable mutex device_type_list_mu_;
590   std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list_
591       TF_GUARDED_BY(device_type_list_mu_);
592   Rendezvous* rendezvous_;
593   std::function<Rendezvous*(const int64)> rendezvous_creator_;
594   CustomDeviceOpHandler custom_device_op_handler_;
595 
596   mutable mutex composite_devices_mu_;
597   // Maps from the fingerprint of a set of device names to a virtual
598   // CompositeDevice.
599   // TODO(b/145922293): Consider taking device names as keys.
600   absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
601       composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_);
602 
Global()603   FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
604 
605   std::unique_ptr<thread::ThreadPool> thread_pool_;
606 
607   // EagerContext owns the DistributedFunctionLibraryRuntime(
608   // EagerClusterFunctionLibraryRuntime) if using EagerService for remote
609   // function execution (lazy_copy_function_remote_inputs_=true).
610   OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
611   // One FunctionLibraryRuntime per device.
612   // func_libs[i] is the FunctionLibraryRuntime corresponding to
613   // session->devices[i].
614   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
615 
616   std::function<void(std::function<void()>)> runner_;
617 
618   mutex cache_mu_;
619   struct RegisteredFunction : public core::RefCounted {
~RegisteredFunctionRegisteredFunction620     ~RegisteredFunction() override {}
621 
622     std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys;
623   };
624   std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
625                      Fprint128Hasher>
626       kernel_cache_ TF_GUARDED_BY(cache_mu_);
627   std::unordered_map<string, RegisteredFunction*> registered_functions_
628       TF_GUARDED_BY(cache_mu_);
629 
630   // Whether we should compute RunMetadata.
631   std::atomic<bool> should_store_graphs_{false};
632   mutex metadata_mu_;
633   std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
634   GraphCollector graph_collector_;
635   std::atomic<bool> log_device_placement_;
636   std::atomic<bool> allow_soft_placement_;
637 
638   // Information related to step containers.
639   std::atomic<int> num_active_steps_;
640   std::unique_ptr<ScopedStepContainer> step_container_
641       TF_GUARDED_BY(metadata_mu_);
642 
643   EagerExecutor default_executor_;
644   mutable mutex executor_map_mu_;
645   // Not owned.
646   std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
647       TF_GUARDED_BY(executor_map_mu_);
648   std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
649       has_cleanup_ TF_GUARDED_BY(executor_map_mu_);
650 
651   const bool log_memory_;
652 
653   // Whether to use same rendezvous instance across function/eager executions.
654   bool reuse_rendezvous_for_functions_ = false;
655 
656   Env* const env_;
657 
658   OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_;
659 
660 #if !defined(IS_MOBILE_PLATFORM)
661   std::vector<string> GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_);
662   bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_);
663   void CloseAndClearAllRemoteContexts();
664   void CloseRemoteContexts(const std::vector<string>& remote_contexts,
665                            uint64 context_id, uint64 context_view_id);
666 
667   Status SetMasterContextState(
668       std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
669       std::shared_ptr<WorkerSession> worker_session,
670       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
671       std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
672       uint64 context_id, uint64 context_view_id, Rendezvous* r,
673       const DeviceMgr* local_device_mgr, int keep_alive_secs,
674       DistributedFunctionLibraryRuntime* cluster_flr,
675       std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
676           remote_mgr);
677 
678   // The server_ is not const since we release it when the context is destroyed.
679   // Therefore the server_ object is not marked as const (even though it should
680   // be).
681   std::unique_ptr<ServerInterface> server_;
682   WorkerEnv* worker_env_ = nullptr;
683   std::shared_ptr<WorkerSession> worker_session_;
684 
685   mutable mutex remote_state_mu_;
686 
687   uint64 context_id_ TF_GUARDED_BY(remote_state_mu_);
688   // The view id of an eager context should be set to 0 when context is created,
689   // and continuously incremented when context with the same context_id gets
690   // updated. The view id should be consistent between master and workers.
691   uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_);
692   std::vector<string> remote_contexts_ TF_GUARDED_BY(remote_state_mu_);
693   std::unique_ptr<eager::EagerClientCache> remote_eager_workers_
694       TF_GUARDED_BY(remote_state_mu_);
695 
696   int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_);
697   std::atomic<int> sleep_for_secs_;
698 
699   std::unique_ptr<Thread> keep_alive_thread_;
700   mutex keep_alive_thread_shutdown_mu_;
701   condition_variable keep_alive_thread_cv_;
702   bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;
703 
704   std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
705       remote_mgr_;
706   bool is_master_ TF_GUARDED_BY(remote_state_mu_);
707 
708   // Maps from a remote worker to a list of parsed device filters.
709   std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
710       cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_);
711 
712   // A distributed manager that helps setup, update, and check liveness of
713   // member tasks in the cluster.
714   std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;
715 
716 #endif  // IS_MOBILE_PLATFORM
717 
718   // For a multi device function, the target device of each input is unknown
719   // until the function is instantiated on the default function device.
720   // If false, eagerly copy all remote inputs to the default function device;
721   // if true, lazily copy remote inputs to their target devices to avoid
722   // redundant copies.
723   bool lazy_copy_function_remote_inputs_ = false;
724   bool use_send_tensor_rpc_;
725   const bool pin_small_ops_to_cpu_;
726 
727   // Function that will be invoked in destructor to deallocate resources related
728   // to this context.
729   std::function<void()> resource_deallocator_ = nullptr;
730 };
731 
ContextFromInterface(ImmediateExecutionContext * context)732 inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
733   return down_cast<EagerContext*>(context);
734 }
735 
736 namespace internal {
737 struct EagerContextDeleter {
operatorEagerContextDeleter738   void operator()(EagerContext* p) const {
739     if (p != nullptr) {
740       p->Release();
741     }
742   }
743 };
744 }  // namespace internal
745 
// Owning smart pointer for an EagerContext. On destruction it calls
// EagerContext::Release() (via internal::EagerContextDeleter) rather than
// deleting the context.
using EagerContextPtr =
    std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
748 
749 }  // namespace tensorflow
750 
751 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
752