1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/context.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/c/eager/immediate_execution_context.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
26 #include "tensorflow/core/lib/core/refcount.h"
27 #include "tensorflow/core/lib/gtl/map_util.h"
28 #include "tensorflow/core/nccl/collective_communicator.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/platform.h"
32 // clang-format on
33 
34 #include "tensorflow/c/tf_tensor.h"
35 #include "tensorflow/c/tf_tensor_internal.h"
36 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
37 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
38 #include "tensorflow/core/common_runtime/colocation_graph.h"
39 #include "tensorflow/core/common_runtime/device_resolver_local.h"
40 #include "tensorflow/core/common_runtime/device_set.h"
41 #include "tensorflow/core/common_runtime/process_util.h"
42 #include "tensorflow/core/framework/graph_def_util.h"
43 #include "tensorflow/core/framework/function.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/protobuf/config.pb.h"
46 #include "tensorflow/core/public/version.h"
47 #include "tensorflow/core/util/device_name_utils.h"
48 #if !defined(IS_MOBILE_PLATFORM)
49 #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
50 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
51 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
52 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
53 #endif  // !IS_MOBILE_PLATFORM
54 #include "tensorflow/core/framework/resource_mgr.h"
55 #include "tensorflow/core/lib/core/blocking_counter.h"
56 #include "tensorflow/core/lib/monitoring/gauge.h"
57 #include "tensorflow/core/util/env_var.h"
58 
59 namespace tensorflow {
60 namespace {
61 
ReadBoolFromEnvVar(StringPiece env_var_name,bool default_val)62 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
63   bool val;
64   if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
65     return val;
66   }
67   return default_val;
68 }
69 
70 auto* eager_context_created =
71     monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
72                                     "True if an eager context was created.");
73 
74 }  // namespace
75 
// Constructs an eager context on top of `device_mgr`.
//
// Besides initializing members, the constructor:
//  * builds the process function library runtime (ResetPFLR),
//  * flags the `eager_context_created` gauge,
//  * computes the prioritized device-type list,
//  * installs `runner_`, which schedules closures on `thread_pool_`,
//  * creates a local CollectiveExecutorMgr for "/job:localhost/replica:0/task:0".
//
// Parameters:
//  opts: session options; provides the config, env and thread-pool settings.
//  default_device_placement_policy: returned by GetDevicePlacementPolicy() for
//    threads without a thread-local override.
//  async: forwarded to the default executor's constructor.
//  device_mgr: local device manager. It is dereferenced immediately
//    (HostCPU()), so it must be non-null.
//  device_mgr_owned: passed alongside `device_mgr` to local_device_manager_.
//  rendezvous: stored in rendezvous_. The destructor calls Unref() on it when
//    it is non-null.
//  cluster_flr: distributed function library runtime handed to ResetPFLR.
EagerContext::EagerContext(
    const SessionOptions& opts,
    ContextDevicePlacementPolicy default_device_placement_policy, bool async,
    const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
    DistributedFunctionLibraryRuntime* cluster_flr)
    : ImmediateExecutionContext(kEager),
      opts_(opts),
      default_device_placement_policy_(default_device_placement_policy),
      local_device_manager_(device_mgr, device_mgr_owned),
      host_cpu_device_(device_mgr->HostCPU()),
      rendezvous_(rendezvous),
      thread_pool_(NewThreadPoolFromSessionOptions(opts)),
      cluster_flr_(cluster_flr),
      log_device_placement_(opts.config.log_device_placement()),
      allow_soft_placement_(opts.config.allow_soft_placement()),
      num_active_steps_(0),
      // Step-0 container; resource containers are cleared through
      // ClearResourceContainer when the container cleans them up.
      step_container_(std::make_unique<ScopedStepContainer>(
          0, [this](const string& name) { ClearResourceContainer(name); })),
      default_executor_(async),
      log_memory_(LogMemory::IsEnabled()),
      env_(opts.env),
      use_send_tensor_rpc_(false),
      // Opt-in via environment variable; off unless explicitly enabled.
      pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
          "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
  ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, opts.config.graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr);
  // Starts exporting metrics through a platform-specific monitoring API (if
  // provided). For builds using "tensorflow/core/platform/default", this is
  // currently a no-op.
  eager_context_created->GetCell()->Set(true);
  InitPrioritizedDeviceTypeList();
  // Default runner: hand every closure to this context's thread pool.
  runner_ = [this](std::function<void()> closure) {
    this->thread_pool_->Schedule(std::move(closure));
  };

  run_metadata_ = std::make_unique<RunMetadata>();

#if !defined(IS_MOBILE_PLATFORM)
  // No remote context yet: start with an invalid id and view id 0.
  context_id_ = kInvalidContextId;
  context_view_id_ = 0;
#endif  // !IS_MOBILE_PLATFORM

  // Local-only collective machinery. Ownership of the resolvers moves into the
  // CollectiveExecutorMgr, which this context owns.
  std::unique_ptr<DeviceResolverInterface> drl(
      new DeviceResolverLocal(local_device_mgr()));
  std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
      opts.config, local_device_mgr(), drl.get(),
      "/job:localhost/replica:0/task:0"));
  collective_executor_mgr_.Reset(
      new CollectiveExecutorMgr(opts.config, local_device_mgr(), std::move(drl),
                                std::move(cprl), MaybeCreateNcclCommunicator()),
      /*owned=*/true);
}
129 
CreateInt64Scalar(int64 value)130 AbstractTensorInterface* EagerContext::CreateInt64Scalar(int64 value) {
131   return new TensorInterface(Tensor(value));
132 }
133 
CreateUint64Scalar(uint64 value)134 AbstractTensorInterface* EagerContext::CreateUint64Scalar(uint64 value) {
135   return new TensorInterface(Tensor(value));
136 }
137 
CreateInt32Scalar(int32 value)138 AbstractTensorInterface* EagerContext::CreateInt32Scalar(int32 value) {
139   return new TensorInterface(Tensor(value));
140 }
141 
CreateFloatScalar(float value)142 AbstractTensorInterface* EagerContext::CreateFloatScalar(float value) {
143   return new TensorInterface(Tensor(value));
144 }
145 
CreateDoubleScalar(double value)146 AbstractTensorInterface* EagerContext::CreateDoubleScalar(double value) {
147   return new TensorInterface(Tensor(value));
148 }
149 
CreateHalfScalar(Eigen::half value)150 AbstractTensorInterface* EagerContext::CreateHalfScalar(Eigen::half value) {
151   return new TensorInterface(Tensor(value));
152 }
153 
CreateStringScalar(tstring value)154 AbstractTensorInterface* EagerContext::CreateStringScalar(tstring value) {
155   return new TensorInterface(Tensor(value));
156 }
157 
CreateComplex128Scalar(complex128 value)158 AbstractTensorInterface* EagerContext::CreateComplex128Scalar(
159     complex128 value) {
160   return new TensorInterface(Tensor(value));
161 }
162 
CreateBoolScalar(bool value)163 AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
164   return new TensorInterface(Tensor(value));
165 }
166 
CreateTensor(DataType dtype,absl::Span<const int64> dim_sizes)167 AbstractTensorInterface* EagerContext::CreateTensor(
168     DataType dtype, absl::Span<const int64> dim_sizes) {
169   return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
170 }
171 
CreateTensor(DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,MemoryReleaser memory_releaser,void * memory_releaser_arg)172 AbstractTensorInterface* EagerContext::CreateTensor(
173     DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
174     MemoryReleaser memory_releaser, void* memory_releaser_arg) {
175   TF_Tensor* tensor_wrapper =
176       TF_NewTensor(static_cast<TF_DataType>(dtype), dims, num_dims, data, len,
177                    memory_releaser, memory_releaser_arg);
178 
179   AbstractTensorInterface* result = nullptr;
180   std::swap(result, tensor_wrapper->tensor);
181   TF_DeleteTensor(tensor_wrapper);
182   return result;
183 }
184 
ResetPFLR(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * thread_pool,DistributedFunctionLibraryRuntime * cluster_flr)185 void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
186                              const ConfigProto* config, int graph_def_version,
187                              const FunctionLibraryDefinition* lib_def,
188                              const OptimizerOptions& optimizer_options,
189                              thread::ThreadPool* thread_pool,
190                              DistributedFunctionLibraryRuntime* cluster_flr) {
191   Rendezvous::Factory rendezvous_factory{
192       [this](const int64 step_id, const DeviceMgr*, Rendezvous** r) {
193         *r = CreateRendezvous(step_id);
194         return Status::OK();
195       }};
196   pflr_.reset(new ProcessFunctionLibraryRuntime(
197       device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
198       thread_pool, cluster_flr,
199       /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
200 }
201 
InitPrioritizedDeviceTypeList()202 void EagerContext::InitPrioritizedDeviceTypeList() {
203   DeviceSet ds;
204   for (Device* d : local_device_mgr()->ListDevices()) {
205     ds.AddDevice(d);
206   }
207   auto remote_device_manager = remote_device_mgr();
208   if (remote_device_manager != nullptr) {
209     for (Device* d : remote_device_manager->ListDevices()) {
210       ds.AddDevice(d);
211     }
212   }
213   mutex_lock l(device_type_list_mu_);
214   prioritized_device_type_list_ =
215       std::make_shared<std::vector<DeviceType>>(ds.PrioritizedDeviceTypeList());
216 }
217 
218 namespace {
219 // Using absl::StrJoin with lambda does not work in tf-lite builds.
220 // TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
DevicesToString(const PrioritizedDeviceVector & devices)221 std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
222   std::vector<string> v;
223   v.reserve(devices.size());
224   for (const auto& p : devices) {
225     v.push_back(p.first->name());
226   }
227   return v;
228 }
229 
DeviceTypesToString(const PrioritizedDeviceTypeVector & types)230 std::vector<string> DeviceTypesToString(
231     const PrioritizedDeviceTypeVector& types) {
232   std::vector<string> v;
233   v.reserve(types.size());
234   for (const auto& p : types) {
235     v.push_back(p.first.type_string());
236   }
237   return v;
238 }
239 
240 // Selects the "best" device that both exists and is supported.
241 //
242 // The `existing` argument specifies the available devices in the system, in
243 // priority order. The `supported` argument specifies the supported device types
244 // and their priorities, lower index types having higher priority.
245 // Currently the type priority defined by the `supported` parameter takes
246 // precedence over system device priorities from `existing`.
247 //
248 // TODO(b/148213212): Allow setting default device in eager context.
SelectBestMatchingDevice(const DeviceNameUtils::ParsedName & pattern,const PrioritizedDeviceVector & existing,const PrioritizedDeviceTypeVector & supported)249 Device* SelectBestMatchingDevice(const DeviceNameUtils::ParsedName& pattern,
250                                  const PrioritizedDeviceVector& existing,
251                                  const PrioritizedDeviceTypeVector& supported) {
252   for (const std::pair<DeviceType, int32>& prioritized_type : supported) {
253     for (const std::pair<Device*, int32>& prioritized_device : existing) {
254       Device* dev = prioritized_device.first;
255       if (DeviceType(dev->attributes().device_type()) ==
256               prioritized_type.first &&
257           DeviceNameUtils::IsCompleteSpecification(pattern,
258                                                    dev->parsed_name())) {
259         return dev;
260       }
261     }
262   }
263   return nullptr;
264 }
265 
266 }  // namespace
267 
SelectDevice(DeviceNameUtils::ParsedName preferred,const NodeDef & ndef,Device ** out) const268 Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
269                                   const NodeDef& ndef, Device** out) const {
270   DCHECK(out != nullptr);
271 
272   PrioritizedDeviceTypeVector supported_devs;
273   auto device_type_list = prioritized_device_type_list();
274   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
275       *device_type_list, ndef, &supported_devs, &HostCPU()->parsed_name()));
276   if (supported_devs.empty()) {
277     return errors::NotFound("Could not find device for node: ",
278                             errors::FormatNodeNameForError(ndef.name()), " = ",
279                             ndef.op(), "[", SummarizeAttrs(ndef), "]",
280                             "\nAll kernels registered for op ", ndef.op(),
281                             ":\n", KernelsRegisteredForOp(ndef.op()));
282   }
283 
284   // Select the first matching registered device from the supported device
285   // list. If nothing matches and soft placement is enabled, pick a suitable
286   // device from the available ones.
287   const auto pflr_device_set = pflr()->device_set();
288   const PrioritizedDeviceVector& existing =
289       pflr_device_set->prioritized_devices();
290   *out = SelectBestMatchingDevice(preferred, existing, supported_devs);
291   if (*out != nullptr) {
292     return Status::OK();
293   }
294 
295   if (AllowSoftPlacement()) {
296     DeviceNameUtils::ParsedName soft_device_name = preferred;
297     soft_device_name.type.clear();
298     soft_device_name.has_type = false;
299     soft_device_name.has_id = false;
300     // TODO(b/148213746): Soft placement logic picks up another task if the
301     // requested does not exist.
302     *out = SelectBestMatchingDevice(soft_device_name, existing, supported_devs);
303     if (*out != nullptr) {
304       return Status::OK();
305     }
306   }
307 
308   if (DeviceNameUtils::HasSomeDetails(preferred)) {
309     return errors::InvalidArgument(
310         "Could not satisfy device specification '", preferred,
311         "'. enable_soft_placement=", AllowSoftPlacement(),
312         ". Supported device types [",
313         absl::StrJoin(DeviceTypesToString(supported_devs), ", "),
314         "]. All available devices [",
315         absl::StrJoin(DevicesToString(existing), ", "), "].");
316   }
317   return errors::InvalidArgument(
318       "No supported device found in available devices [",
319       absl::StrJoin(DevicesToString(existing), ", "),
320       "]. enable_soft_placement=", AllowSoftPlacement(),
321       ". Supported devices types [",
322       absl::StrJoin(DeviceTypesToString(supported_devs), ", "), "].");
323 }
324 
ResetClusterFLR(DistributedFunctionLibraryRuntime * cluster_flr)325 void EagerContext::ResetClusterFLR(
326     DistributedFunctionLibraryRuntime* cluster_flr) {
327   cluster_flr_.Reset(cluster_flr, /*owned=*/true);
328 }
329 
Executor()330 EagerExecutor& EagerContext::Executor() {
331   tf_shared_lock l(executor_map_mu_);
332   return *gtl::FindWithDefault(thread_local_executor_,
333                                std::this_thread::get_id(), &default_executor_);
334 }
335 
SetExecutorForThread(EagerExecutor * executor)336 void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
337   tensorflow::mutex_lock l(executor_map_mu_);
338   if (executor == &default_executor_) {
339     thread_local_executor_.erase(std::this_thread::get_id());
340   } else {
341     auto thread_id = std::this_thread::get_id();
342     thread_local_executor_[thread_id] = executor;
343     auto& executors_with_cleanups = has_cleanup_[thread_id];
344     if (executors_with_cleanups.find(executor) ==
345         executors_with_cleanups.end()) {
346       executors_with_cleanups.insert(executor);
347       // If the executor is deleted before this context, we need to remove it
348       // from the map to avoid attempting to sync it in our destructor.
349       std::function<void()> cleanup([this, thread_id, executor]() {
350         {
351           tensorflow::mutex_lock l(executor_map_mu_);
352           auto existing = thread_local_executor_.find(thread_id);
353           if (existing != thread_local_executor_.end() &&
354               existing->second == executor) {
355             thread_local_executor_.erase(thread_id);
356           }
357           has_cleanup_[thread_id].erase(executor);
358         }
359       });
360       executor->AddCleanup(reinterpret_cast<intptr_t>(this),
361                            std::move(cleanup));
362     }
363   }
364 }
365 
ClearCachesAndThreadExecutors()366 void EagerContext::ClearCachesAndThreadExecutors() {
367   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
368   {
369     mutex_lock l(executor_map_mu_);
370     executors_copy = thread_local_executor_;
371   }
372   for (const auto& entry : executors_copy) {
373     entry.second->WaitForAllPendingNodes().IgnoreError();
374   }
375   ClearCachesAndDefaultExecutor();
376 }
377 
ClearCachesAndDefaultExecutor()378 void EagerContext::ClearCachesAndDefaultExecutor() {
379   // The executor stores pointers to kernels, so we need to make sure that no
380   // async eager ops are still executing. We lock the cache during this time
381   // as well.
382   mutex_lock ml(cache_mu_);
383   default_executor_.WaitForAllPendingNodes().IgnoreError();
384   kernel_cache_.clear();
385   for (auto& entry : registered_functions_) {
386     entry.second->cached_kernel_keys->clear();
387   }
388   {
389     mutex_lock ml(metadata_mu_);
390     step_container_.reset(new ScopedStepContainer(
391         0, [this](const string& name) { ClearResourceContainer(name); }));
392   }
393 }
394 
SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy)395 void EagerContext::SetThreadLocalDevicePlacementPolicy(
396     ContextDevicePlacementPolicy policy) {
397   mutex_lock ml(policy_map_mu_);
398   device_placement_policy_[std::this_thread::get_id()] = policy;
399 }
400 
GetDevicePlacementPolicy() const401 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
402   tf_shared_lock l(policy_map_mu_);
403   auto policy_map_it =
404       device_placement_policy_.find(std::this_thread::get_id());
405   if (policy_map_it != device_placement_policy_.end()) {
406     return policy_map_it->second;
407   }
408   return default_device_placement_policy_;
409 }
410 
411 #if !defined(IS_MOBILE_PLATFORM)
GetRemoteContexts()412 std::vector<string> EagerContext::GetRemoteContexts() {
413   tf_shared_lock l(remote_state_mu_);
414   return remote_contexts_;
415 }
416 
IsRemoteContextsEmpty()417 bool EagerContext::IsRemoteContextsEmpty() {
418   tf_shared_lock l(remote_state_mu_);
419   return remote_contexts_.empty();
420 }
421 
CloseAndClearAllRemoteContexts()422 void EagerContext::CloseAndClearAllRemoteContexts() {
423   uint64 context_id;
424   uint64 context_view_id;
425   std::vector<string> remote_contexts_copy;
426   {
427     mutex_lock l(remote_state_mu_);
428     if (!is_master_) return;
429     context_id = context_id_;
430     context_view_id = context_view_id_;
431     context_id_ = kInvalidContextId;
432     // Forget the current view id and reset to the starting value 0.
433     context_view_id_ = 0;
434 
435     // Make a copy of remote targets to avoid holding the lock when sending
436     // close context requests.
437     remote_contexts_copy = remote_contexts_;
438     remote_contexts_.clear();
439   }
440   CloseRemoteContexts(remote_contexts_copy, context_id, context_view_id);
441 }
442 
CloseRemoteContexts(const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id)443 void EagerContext::CloseRemoteContexts(
444     const std::vector<string>& remote_contexts, uint64 context_id,
445     uint64 context_view_id) {
446   // Close all remote contexts.
447   eager::CloseContextRequest request;
448   request.set_context_id(context_id);
449   request.set_context_view_id(context_view_id);
450   // Setting context_id to a new value can avoid us issuing DestroyTensorHandle
451   // request to closed remote workers.
452   std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
453   BlockingCounter counter(static_cast<int>(remote_contexts.size()));
454 
455   int i = 0;
456   for (const auto& worker : remote_contexts) {
457     core::RefCountPtr<eager::EagerClient> client;
458     Status s = GetClient(worker, &client);
459 
460     client->CloseContextAsync(
461         &request, &responses[i],
462         [&worker, &counter, context_id](const Status& s) {
463           if (!s.ok()) {
464             LOG(ERROR) << "Unable to close remote context with ID "
465                        << context_id << " for worker: " << worker << " due to "
466                        << s.error_message();
467           }
468           counter.DecrementCount();
469         });
470     i++;
471   }
472 
473   counter.Wait();
474 }
475 
476 #endif  // !IS_MOBILE_PLATFORM
477 
// Drains local work, then (on non-mobile builds) tears down all remote state.
//
// Order of operations:
//  1. Wait for pending nodes on every thread-local and the default executor,
//     and clear caches.
//  2. Signal the keep-alive thread to stop and destroy it.
//  3. Close any remaining remote contexts.
//  4. Under remote_state_mu_, shut down the default and thread-local
//     executors, then drop remote_eager_workers_.
void EagerContext::WaitForAndCloseRemoteContexts() {
  ClearCachesAndThreadExecutors();

#if !defined(IS_MOBILE_PLATFORM)
  {
    // Wake the keep-alive thread so it observes shutting_down_.
    mutex_lock l(keep_alive_thread_shutdown_mu_);
    shutting_down_ = true;
    keep_alive_thread_cv_.notify_all();
  }
  keep_alive_thread_.reset();

  if (!IsRemoteContextsEmpty()) {
    CloseAndClearAllRemoteContexts();
  }

  {
    mutex_lock l(remote_state_mu_);

    default_executor_.ShutDown().IgnoreError();
    // Snapshot the thread-local executors so ShutDown() is not called while
    // executor_map_mu_ is held. NOTE(review): the inner lock name `l`
    // shadows the outer one.
    std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
    {
      mutex_lock l(executor_map_mu_);
      executors_copy = thread_local_executor_;
    }
    for (const auto& it : executors_copy) {
      it.second->ShutDown().IgnoreError();
    }

    // This shuts down the completion queue and joins the thread polling it.
    // The thread exits only after the completion queue has been drained of all
    // the events. These events' completion should invoke all remaining RPC
    // callbacks.
    // This also deletes all EagerClient instances. There should not be any
    // references to EagerClients left after all RPCs and async ops have been
    // finished.
    remote_eager_workers_ = nullptr;
  }
#endif  // !IS_MOBILE_PLATFORM
}
517 
// Tears the context down: drains and closes remote state, clears custom
// devices and caches, detaches executor cleanup hooks, drops all registered
// function references, then releases the rendezvous and runs the optional
// resource deallocator.
EagerContext::~EagerContext() {
  // TODO(iga): Add a separate API method to shutdown EagerContext so that we
  // don't send RPCs and block in destructor.
  WaitForAndCloseRemoteContexts();

  // Custom devices may have obtained references to various context components
  // (executors, thread pool). It's safer to run their destructors early.
  custom_device_op_handler_.Clear();

  ClearCachesAndThreadExecutors();
  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
  {
    mutex_lock l(executor_map_mu_);
    executors_copy = thread_local_executor_;
  }
  for (const auto& entry : executors_copy) {
    // Let the executor know that its cleanup closure is no longer valid.
    entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
  }
  // Drop every outstanding reference. NOTE(review): presumably Unref() returns
  // true once the last reference is released and the object is deleted.
  for (auto& entry : registered_functions_) {
    while (!entry.second->Unref()) {
      // remove all references.
    }
  }
  registered_functions_.clear();

#if !defined(IS_MOBILE_PLATFORM)
  if (server_) {
    // The server is intentionally leaked rather than destroyed.
    // TODO(b/136478427): Fix this.
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }

  // NOTE(review): WaitForAndCloseRemoteContexts() above already stops the
  // keep-alive thread and closes remote contexts, so this second pass looks
  // redundant. It also reads remote_contexts_ without remote_state_mu_,
  // unlike IsRemoteContextsEmpty().
  {
    mutex_lock l(keep_alive_thread_shutdown_mu_);
    shutting_down_ = true;
    keep_alive_thread_cv_.notify_all();
  }
  keep_alive_thread_.reset();
  if (!remote_contexts_.empty()) {
    CloseAndClearAllRemoteContexts();
  }
#endif  // !IS_MOBILE_PLATFORM

  if (rendezvous_) {
    rendezvous_->Unref();
  }
  if (resource_deallocator_ != nullptr) {
    resource_deallocator_();
  }
}
570 
FindFunctionByName(const string & name) const571 bool EagerContext::FindFunctionByName(const string& name) const {
572   return func_lib_def_.Find(name) != nullptr;
573 }
574 
FindFunctionOpData(const string & name,const tensorflow::OpRegistrationData ** op_data)575 Status EagerContext::FindFunctionOpData(
576     const string& name, const tensorflow::OpRegistrationData** op_data) {
577   return func_lib_def_.LookUp(name, op_data);
578 }
579 
FindFunctionDef(const string & name) const580 const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
581   return func_lib_def_.Find(name);
582 }
583 
ExportRunMetadata()584 std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
585   mutex_lock ml(metadata_mu_);
586   auto result = std::make_unique<RunMetadata>();
587   run_metadata_.swap(result);
588   return result;
589 }
590 
UsesTFRT()591 bool EagerContext::UsesTFRT() { return false; }
592 
ListDevices(std::vector<tensorflow::DeviceAttributes> * devices)593 void EagerContext::ListDevices(
594     std::vector<tensorflow::DeviceAttributes>* devices) {
595   local_device_mgr()->ListDeviceAttributes(devices);
596   if (remote_device_mgr()) {
597     remote_device_mgr()->ListDeviceAttributes(devices);
598   }
599 }
600 
StartStep()601 void EagerContext::StartStep() {
602   mutex_lock ml(metadata_mu_);
603   num_active_steps_++;
604 }
605 
EndStep()606 void EagerContext::EndStep() {
607   mutex_lock ml(metadata_mu_);
608   num_active_steps_--;
609   if (num_active_steps_ == 0) {
610     // TODO(b/139809335): This does not properly clean up remote resources
611     // Clean up the previous step container and create a new one.
612     step_container_.reset(new ScopedStepContainer(
613         0, [this](const string& name) { ClearResourceContainer(name); }));
614   }
615 }
616 
StepContainer()617 ScopedStepContainer* EagerContext::StepContainer() {
618   mutex_lock ml(metadata_mu_);
619   return step_container_.get();
620 }
621 
MaybeRegisterFunctionRemotely(const FunctionDef & fdef)622 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
623   // Only client context can register function on remote worker context.
624   if (!remote_device_manager_.Owned()) return Status::OK();
625 #if !defined(IS_MOBILE_PLATFORM)
626   std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
627   request->set_context_id(GetContextId());
628 
629   eager::RegisterFunctionOp* register_function =
630       request->add_queue()->mutable_register_function();
631   *register_function->mutable_function_def() = fdef;
632   StripDefaultAttributes(
633       *OpRegistry::Global(),
634       register_function->mutable_function_def()->mutable_node_def());
635 
636   auto remote_contexts = GetRemoteContexts();
637   for (const auto& target : remote_contexts) {
638     core::RefCountPtr<eager::EagerClient> eager_client;
639     TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
640 
641     eager::EnqueueResponse* response = new eager::EnqueueResponse();
642     eager_client->StreamingEnqueueAsync(
643         /*call_opts=*/nullptr, request.get(), response,
644         [request, response](const Status& status) {
645           if (!status.ok()) {
646             LOG(ERROR) << "Failed to register function remotely due to "
647                        << status.error_message()
648                        << "\nThis could happen if the remote target has been "
649                           "disconnected from the client.";
650           }
651           delete response;
652         });
653   }
654 #endif  // !IS_MOBILE_PLATFORM
655   return Status::OK();
656 }
657 
RegisterExistingFunctionsOnRemoteWorkers(const std::vector<string> & remote_workers)658 Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
659     const std::vector<string>& remote_workers) {
660 #if !defined(IS_MOBILE_PLATFORM)
661   // Register multiple functions on selected remote workers.
662   uint64 context_id = GetContextId();
663   FunctionDefLibrary function_defs = func_lib_def_.ToProto();
664   std::vector<std::shared_ptr<eager::EnqueueRequest>> requests(
665       function_defs.function_size());
666   for (int i = 0; i < function_defs.function_size(); i++) {
667     requests[i] = std::make_shared<eager::EnqueueRequest>();
668     requests[i]->set_context_id(context_id);
669     eager::RegisterFunctionOp* register_function =
670         requests[i]->add_queue()->mutable_register_function();
671     *register_function->mutable_function_def() =
672         std::move(*function_defs.mutable_function(i));
673     StripDefaultAttributes(
674         *OpRegistry::Global(),
675         register_function->mutable_function_def()->mutable_node_def());
676   }
677 
678   for (auto& remote_worker : remote_workers) {
679     core::RefCountPtr<eager::EagerClient> eager_client;
680     Status s = GetClient(remote_worker, &eager_client);
681     if (!s.ok()) {
682       continue;
683     }
684     for (int i = 0; i < requests.size(); i++) {
685       auto response = std::make_shared<eager::EnqueueResponse>();
686       eager_client->StreamingEnqueueAsync(
687           /*call_opts=*/nullptr, requests[i].get(), response.get(),
688           [request = requests[i], response](const Status& s) {
689             if (!s.ok()) {
690               LOG(ERROR) << "Failed to register function remotely due to "
691                          << s.error_message()
692                          << "\nThis could happen if the remote target has been "
693                             "disconnected from the client.";
694             }
695           });
696     }
697   }
698 #endif  // !IS_MOBILE_PLATFORM
699   return Status::OK();
700 }
701 
AddFunctionDefWithStackTraces(const FunctionDef & fdef,const StackTracesMap & stack_traces)702 Status EagerContext::AddFunctionDefWithStackTraces(
703     const FunctionDef& fdef, const StackTracesMap& stack_traces) {
704   return AddFunctionDef(fdef, FunctionDefLibrary(),
705                         /* add_to_local_only=*/false, stack_traces);
706 }
707 
AddFunctionDef(const FunctionDef & fdef)708 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
709   return AddFunctionDef(fdef, FunctionDefLibrary(),
710                         /* add_to_local_only=*/false);
711 }
712 
// Registers `fdef` (plus `library`) with this context, reference-counted by
// function name.
//
// * First registration of a name: creates a RegisteredFunction entry, adds
//   `fdef` (with `stack_traces`) and `library` to func_lib_def_, and, unless
//   `add_to_local_only`, forwards it to MaybeRegisterFunctionRemotely.
// * Re-registration: takes another reference if `fdef` equals the stored
//   definition. Returns InvalidArgument if they differ, and Internal if the
//   name is cached but missing from the library.
//
// NOTE(review): the library insertion happens after cache_mu_ is released.
// A concurrent call for the same name can therefore observe the cache entry
// before the FunctionDef is in func_lib_def_ and get the Internal error.
// NOTE(review): if func_lib_def_.AddFunctionDef/AddLibrary fails, the newly
// created cache entry is not rolled back. Later calls will then hit the
// "in the cache but not in the library" path.
Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
                                    const FunctionDefLibrary& library,
                                    const bool add_to_local_only,
                                    const StackTracesMap& stack_traces) {
  bool is_first_ref = false;
  {
    mutex_lock l(cache_mu_);
    auto* registered_function =
        gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
    if (registered_function == nullptr) {
      // NOTE(review): presumably RegisteredFunction starts with a refcount of
      // one (core::RefCounted), making is_first_ref true below.
      registered_function = new RegisteredFunction;
      registered_function->cached_kernel_keys =
          absl::make_unique<std::vector<Fprint128>>();
      gtl::InsertOrUpdate(&registered_functions_, fdef.signature().name(),
                          registered_function);
    } else {
      // The function has been registered before. If the function is the same,
      // then we take a Ref() otherwise we error out.
      const FunctionDef* prev_fdef =
          func_lib_def_.Find(fdef.signature().name());
      if (prev_fdef == nullptr) {
        return errors::Internal("Function: ", fdef.signature().name(),
                                " is in the cache but not in the library");
      }
      if (!FunctionDefsEqual(fdef, *prev_fdef)) {
        return errors::InvalidArgument(
            "Attempting to add a duplicate function with name: ",
            fdef.signature().name(), " where the previous and current ",
            "definitions differ. Previous definition: ",
            prev_fdef->DebugString(),
            " and current definition: ", fdef.DebugString());
      }
      registered_function->Ref();
    }
    is_first_ref = registered_function->RefCountIsOne();
  }
  // Only the first registration touches the library and the remote workers.
  if (is_first_ref) {
    TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
    TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
    if (!add_to_local_only) {
      return MaybeRegisterFunctionRemotely(fdef);
    }
  }
  return Status::OK();
}
758 
GetFunctionDef(const string & function_name)759 const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
760   return func_lib_def_.Find(function_name);
761 }
762 
ListFunctionNames()763 std::vector<string> EagerContext::ListFunctionNames() {
764   return func_lib_def_.ListFunctionNames();
765 }
766 
RemoveFunction(const string & func)767 Status EagerContext::RemoveFunction(const string& func) {
768   bool is_last_ref = false;
769   {
770     mutex_lock l(cache_mu_);
771     auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
772     if (registered_function == nullptr) {
773       return errors::InvalidArgument("Tried to remove non-existent function '",
774                                      func, "'.");
775     }
776     is_last_ref = registered_function->RefCountIsOne();
777     if (is_last_ref) {
778       for (auto& key : *registered_function->cached_kernel_keys) {
779         kernel_cache_.erase(key);
780       }
781       registered_functions_.erase(func);
782     }
783     registered_function->Unref();
784   }
785   if (is_last_ref) {
786     // TODO(fishx): Remove remote function as well.
787     return func_lib_def_.RemoveFunction(func);
788   }
789   return Status::OK();
790 }
791 
SyncExecutors()792 Status EagerContext::SyncExecutors() {
793   StatusGroup sg;
794   // Synchronize on context default executor
795   sg.Update(default_executor_.WaitForAllPendingNodes());
796   default_executor_.ClearError();
797 
798   // Synchronize thread local executors on client
799   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
800   {
801     mutex_lock l(executor_map_mu_);
802     executors_copy = thread_local_executor_;
803   }
804   for (const auto& entry : executors_copy) {
805     sg.Update(entry.second->WaitForAllPendingNodes());
806     entry.second->ClearError();
807   }
808 
809 #if !defined(IS_MOBILE_PLATFORM)
810   auto remote_contexts = GetRemoteContexts();
811   // Synchronize executors on remote workers
812   eager::EnqueueRequest request;
813   request.set_context_id(GetContextId());
814   request.add_queue()->mutable_sync_remote_executor_for_stream();
815   BlockingCounter counter(static_cast<int>(remote_contexts.size()));
816   std::vector<Status> statuses(remote_contexts.size());
817 
818   for (int i = 0; i < remote_contexts.size(); i++) {
819     const auto& target = remote_contexts[i];
820     core::RefCountPtr<eager::EagerClient> eager_client;
821     TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
822 
823     eager::EnqueueResponse* response = new eager::EnqueueResponse();
824     eager_client->StreamingEnqueueAsync(
825         /*call_opts=*/nullptr, &request, response,
826         [response, target, &counter, &s = statuses[i]](const Status& status) {
827           s = status;
828           delete response;
829           counter.DecrementCount();
830         });
831   }
832   counter.Wait();
833   for (const Status& s : statuses) {
834     sg.Update(s);
835   }
836 #endif  // !IS_MOBILE_PLATFORM
837   return sg.as_summary_status();
838 }
839 
GetCachedKernel(Fprint128 cache_key)840 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
841     Fprint128 cache_key) {
842   tf_shared_lock l(cache_mu_);
843   auto iter = kernel_cache_.find(cache_key);
844   if (iter == kernel_cache_.end()) {
845     return nullptr;
846   }
847   core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
848   new_ref->Ref();
849   return new_ref;
850 }
851 
AddKernelToCache(Fprint128 cache_key,KernelAndDevice * kernel)852 void EagerContext::AddKernelToCache(Fprint128 cache_key,
853                                     KernelAndDevice* kernel) {
854   mutex_lock ml(cache_mu_);
855   core::RefCountPtr<KernelAndDevice> new_ref(kernel);
856   new_ref->Ref();
857   kernel_cache_[cache_key] = std::move(new_ref);
858   auto* registered_function =
859       gtl::FindPtrOrNull(registered_functions_, kernel->name());
860   // The kernel name can be either a primitive op or a function.
861   if (registered_function != nullptr) {
862     registered_function->cached_kernel_keys->emplace_back(cache_key);
863   }
864 }
865 
ShouldStoreGraphs()866 bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }
867 
SetShouldStoreGraphs(bool value)868 void EagerContext::SetShouldStoreGraphs(bool value) {
869   mutex_lock ml(metadata_mu_);
870   should_store_graphs_.store(value);
871   if (!value) {
872     run_metadata_.reset(new RunMetadata);
873   }
874 }
875 
FindDeviceFromName(const char * device_name,Device ** device) const876 Status EagerContext::FindDeviceFromName(const char* device_name,
877                                         Device** device) const {
878   *device = HostCPU();
879   if (device_name == nullptr || strlen(device_name) == 0) {
880     return Status::OK();
881   }
882 
883   auto status = local_device_mgr()->LookupDevice(device_name, device);
884   if (status.ok()) {
885     return status;
886   }
887 
888   if (remote_device_mgr() != nullptr) {
889     return remote_device_mgr()->LookupDevice(device_name, device);
890   }
891 
892   return status;
893 }
894 
FindCompositeDeviceFromName(StringPiece device_name,CompositeDevice ** device) const895 Status EagerContext::FindCompositeDeviceFromName(
896     StringPiece device_name, CompositeDevice** device) const {
897   tf_shared_lock l(composite_devices_mu_);
898   for (const auto& d : composite_devices_) {
899     if (d.second->name() == device_name) {
900       *device = d.second.get();
901       return Status::OK();
902     }
903   }
904   return errors::NotFound("Unknown composite device: ", device_name);
905 }
906 
RegisterCustomDevice(const string & device_name,std::unique_ptr<CustomDevice> device)907 Status EagerContext::RegisterCustomDevice(
908     const string& device_name, std::unique_ptr<CustomDevice> device) {
909   Device* existing_physical_device = nullptr;
910   if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
911     return errors::AlreadyExists(device_name,
912                                  " already registered as a physical device.");
913   }
914   return custom_device_op_handler_.RegisterCustomDevice(device_name,
915                                                         std::move(device));
916 }
917 
FindOrCreateCompositeDevice(const std::vector<string> & underlying_devices,const string & device_name,CompositeDevice ** composite_device)918 Status EagerContext::FindOrCreateCompositeDevice(
919     const std::vector<string>& underlying_devices, const string& device_name,
920     CompositeDevice** composite_device) {
921   if (!device_name.empty() &&
922       FindCompositeDeviceFromName(device_name, composite_device).ok()) {
923     return Status::OK();
924   }
925 
926   const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
927 
928   mutex_lock l(composite_devices_mu_);
929   auto iter = composite_devices_.find(hash_key);
930   if (iter != composite_devices_.end()) {
931     *composite_device = iter->second.get();
932     return Status::OK();
933   }
934 
935   Status s;
936   std::unique_ptr<CompositeDevice> device;
937   if (device_name.empty()) {
938     // Create a CompositeDevice on the same task as the host CPU, in order to
939     // trigger packed TensorHandle copy from a client to a remote worker.
940     device = CompositeDevice::MakeDevice(underlying_devices,
941                                          composite_devices_.size(),
942                                          HostCPU()->parsed_name(), &s);
943   } else {
944     device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
945   }
946   TF_RETURN_IF_ERROR(s);
947   *composite_device = device.get();
948   pflr_->AddCompositeDevice(*composite_device);
949   composite_devices_.emplace(hash_key, std::move(device));
950   return Status::OK();
951 }
952 
OnSameTask(const Device * first,const Device * second) const953 bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
954   if (first == nullptr) first = HostCPU();
955   if (second == nullptr) second = HostCPU();
956   return first->parsed_name().job == second->parsed_name().job &&
957          first->parsed_name().replica == second->parsed_name().replica &&
958          first->parsed_name().task == second->parsed_name().task;
959 }
960 
961 // Gets the CPU device on the task of device.
CPUDeviceOnTask(const Device * device,Device ** cpu_device) const962 Status EagerContext::CPUDeviceOnTask(const Device* device,
963                                      Device** cpu_device) const {
964   string cpu_device_name;
965   TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
966       device->name(), &cpu_device_name));
967 
968   return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
969 }
970 
ClearResourceContainer(const string & name)971 void EagerContext::ClearResourceContainer(const string& name) {
972   // TODO(b/139809335): This does not properly clean up remote resources
973   auto local_devices = local_device_mgr()->ListDevices();
974   for (Device* device : local_devices) {
975     // Only ignore container not found errors.
976     device->resource_manager()->Cleanup(name).IgnoreError();
977   }
978 }
979 
980 namespace {
GetTaskName(Device * d,string * task_name)981 Status GetTaskName(Device* d, string* task_name) {
982   string ignored;
983   if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
984     return errors::InvalidArgument("Unable to parse device name: ", d->name());
985   }
986 
987   return Status::OK();
988 }
989 }  // namespace
990 
991 #if !defined(IS_MOBILE_PLATFORM)
GetClient(Device * device,core::RefCountPtr<eager::EagerClient> * client)992 Status EagerContext::GetClient(Device* device,
993                                core::RefCountPtr<eager::EagerClient>* client) {
994   return GetClient(device->parsed_name(), client);
995 }
996 
GetClient(const DeviceNameUtils::ParsedName & device_name,core::RefCountPtr<eager::EagerClient> * client)997 Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
998                                core::RefCountPtr<eager::EagerClient>* client) {
999   string device_task_name;
1000   if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
1001     return errors::InvalidArgument(
1002         "Task is not fully specified in device name: ",
1003         DeviceNameUtils::ParsedNameToString(device_name));
1004   }
1005 
1006   {
1007     tf_shared_lock l(remote_state_mu_);
1008     if (remote_eager_workers_ == nullptr) {
1009       return errors::Internal(
1010           "Haven't set up remote eager worker in this eager context yet.");
1011     }
1012     TF_RETURN_IF_ERROR(
1013         remote_eager_workers_->GetClient(device_task_name, client));
1014 
1015     if (*client == nullptr) {
1016       return errors::InvalidArgument(
1017           "Unable to find eager client corresponding to device ",
1018           DeviceNameUtils::ParsedNameToString(device_name));
1019     }
1020     if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
1021                   device_task_name) == remote_contexts_.end()) {
1022       return errors::Internal("Unable to find a context for handle on task: ",
1023                               device_task_name, ". This should not happen.");
1024     }
1025   }
1026 
1027   return Status::OK();
1028 }
1029 
GetClient(const string & remote_task,core::RefCountPtr<eager::EagerClient> * client)1030 Status EagerContext::GetClient(const string& remote_task,
1031                                core::RefCountPtr<eager::EagerClient>* client) {
1032   {
1033     tf_shared_lock l(remote_state_mu_);
1034     if (remote_eager_workers_ == nullptr) {
1035       return errors::Internal(
1036           "Haven't set up remote eager worker in this eager context yet.");
1037     }
1038     TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
1039   }
1040 
1041   if (*client == nullptr) {
1042     return errors::InvalidArgument(
1043         "Unable to find eager client corresponding to target ", remote_task);
1044   }
1045   return Status::OK();
1046 }
1047 
GetContextId() const1048 uint64 EagerContext::GetContextId() const {
1049   tf_shared_lock l(remote_state_mu_);
1050   return context_id_;
1051 }
1052 
GetContextViewId() const1053 uint64 EagerContext::GetContextViewId() const {
1054   tf_shared_lock l(remote_state_mu_);
1055   return context_view_id_;
1056 }
1057 
IncrementContextViewId()1058 void EagerContext::IncrementContextViewId() {
1059   mutex_lock l(remote_state_mu_);
1060   context_view_id_ += 1;
1061 }
1062 
1063 // Set collective ops related state in the context. Passing nullptr to
1064 // `new_server` will reuse the existing GRPC server in context.
StoreCollectiveOpsServer(std::unique_ptr<ServerInterface> new_server,const DeviceMgr * device_mgr,CollectiveExecutorMgrInterface * rpc_collective_executor_mgr)1065 Status EagerContext::StoreCollectiveOpsServer(
1066     std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr,
1067     CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
1068   collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
1069 
1070   if (device_mgr != local_device_manager_.Get()) {
1071     if (local_device_manager_.Owned()) {
1072       old_local_device_managers_.push_back(
1073           std::move(local_device_manager_.owned_object));
1074     }
1075     local_device_manager_.Reset(device_mgr);
1076   }
1077   host_cpu_device_ = local_device_manager_.Get()->HostCPU();
1078 
1079   if (reuse_rendezvous_for_functions_) {
1080     // If reuse_rendezvous_for_functions_ is true, CreateRendezvous is
1081     // idempotent and ignores its step_id argument. Create a rendezvous now to
1082     // replace the old one, preventing the old one from getting used.
1083     if (rendezvous_ != nullptr) rendezvous_->Unref();
1084     rendezvous_ = CreateRendezvous(/*step_id=*/-1);
1085     return errors::Aborted("Cannot create a valid rendezvous.");
1086   }
1087 
1088   InitPrioritizedDeviceTypeList();
1089   ClearCachesAndThreadExecutors();
1090   default_executor_.ClearError();
1091   {
1092     tensorflow::mutex_lock l(executor_map_mu_);
1093     for (auto& entry : thread_local_executor_) {
1094       entry.second->ClearError();
1095     }
1096   }
1097 
1098   const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
1099   ResetPFLR(
1100       local_device_manager_.Get(), env_, /*config=*/config,
1101       TF_GRAPH_DEF_VERSION, &func_lib_def_,
1102       /*optimizer_options=*/
1103       config ? config->graph_options().optimizer_options() : OptimizerOptions(),
1104       thread_pool_.get());
1105 
1106   if (new_server != nullptr) {
1107     // Memory leak!
1108     if (server_ != nullptr) {
1109       LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
1110                       "Servers don't support clean shutdown.";
1111       server_.release();
1112     }
1113     server_ = std::move(new_server);
1114   }
1115   DCHECK(server_ != nullptr);
1116 
1117   return Status::OK();
1118 }
1119 
SetRemoteDeviceFilters(const string & remote_worker,const std::vector<string> & device_filters)1120 Status EagerContext::SetRemoteDeviceFilters(
1121     const string& remote_worker, const std::vector<string>& device_filters) {
1122   // Get fully specified task name for remote worker
1123   string remote_worker_task_name;
1124   DeviceNameUtils::ParsedName pw;
1125   if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
1126     return tensorflow::errors::InvalidArgument(
1127         "Remote worker task name is invalid ", remote_worker);
1128   }
1129   // Force set a replica as the key in cluster device filters map. I.e., if the
1130   // remote worker is `/job:worker/task:0` it then becomes
1131   // `/job:worker/replica:0/task:0`.
1132   pw.has_replica = true;
1133   if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
1134     return tensorflow::errors::InvalidArgument(
1135         "Job name and task index must be specified for worker ", remote_worker);
1136   }
1137 
1138   std::vector<DeviceNameUtils::ParsedName> parsed_filters;
1139   for (auto& filter : device_filters) {
1140     DeviceNameUtils::ParsedName parsed_filter;
1141     if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
1142       parsed_filters.emplace_back(parsed_filter);
1143     } else {
1144       return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
1145     }
1146   }
1147 
1148   if (VLOG_IS_ON(1)) {
1149     VLOG(1) << "Setting device filters for " << remote_worker << ":";
1150     for (auto& filter : device_filters) {
1151       VLOG(1) << "  " << filter;
1152     }
1153   }
1154   mutex_lock l(remote_state_mu_);
1155   cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
1156   return Status::OK();
1157 }
1158 
FilterDevicesForRemoteWorkers(const string & remote_worker,const protobuf::RepeatedPtrField<DeviceAttributes> & device_attrs,std::vector<bool> * filtered_device_mask)1159 void EagerContext::FilterDevicesForRemoteWorkers(
1160     const string& remote_worker,
1161     const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
1162     std::vector<bool>* filtered_device_mask) {
1163   filtered_device_mask->resize(device_attrs.size());
1164   std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
1165 
1166   tf_shared_lock l(remote_state_mu_);
1167   auto it = cluster_device_filters_.find(remote_worker);
1168   // If no filters were specified, all devices should be visible to the worker
1169   if (it == cluster_device_filters_.end() || it->second.empty()) {
1170     std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
1171     return;
1172   }
1173 
1174   const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
1175   DeviceNameUtils::ParsedName parsed_remote_worker;
1176   DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
1177   for (int i = 0; i < device_attrs.size(); i++) {
1178     DeviceNameUtils::ParsedName pn;
1179     DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
1180     if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
1181       // If this device is on the remote worker itself, it should be visible
1182       // regardless of device filters
1183       filtered_device_mask->at(i) = true;
1184       continue;
1185     }
1186     for (const auto& pf : parsed_filters) {
1187       if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
1188           (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
1189           (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
1190           (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
1191           (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
1192         // Found a match, make it visible, stop processing more device filters
1193         filtered_device_mask->at(i) = true;
1194         break;
1195       }
1196     }
1197   }
1198 }
1199 
InitializeRemoteMaster(std::unique_ptr<ServerInterface> server,WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DynamicDeviceMgr> remote_device_manager,const std::vector<string> & remote_contexts,uint64 context_id,Rendezvous * r,const DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr)1200 Status EagerContext::InitializeRemoteMaster(
1201     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
1202     std::shared_ptr<WorkerSession> worker_session,
1203     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1204     std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
1205     const std::vector<string>& remote_contexts, uint64 context_id,
1206     Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs,
1207     DistributedFunctionLibraryRuntime* cluster_flr,
1208     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1209         remote_mgr) {
1210   if (context_id == kInvalidContextId) {
1211     return errors::InvalidArgument(
1212         "Failed to initialize remote for master context due to invalid ",
1213         "context id");
1214   }
1215 
1216   if (!IsRemoteContextsEmpty()) {
1217     CloseAndClearAllRemoteContexts();
1218   }
1219   {
1220     mutex_lock l(remote_state_mu_);
1221     remote_contexts_ = remote_contexts;
1222   }
1223 
1224   return SetMasterContextState(
1225       std::move(server), worker_env, std::move(worker_session),
1226       std::move(remote_eager_workers), std::move(remote_device_manager),
1227       context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
1228       std::move(remote_mgr));
1229 }
1230 
UpdateRemoteMaster(uint64 context_id,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & add_remote_contexts,const std::vector<string> & remove_remote_contexts)1231 Status EagerContext::UpdateRemoteMaster(
1232     uint64 context_id,
1233     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1234     const std::vector<string>& add_remote_contexts,
1235     const std::vector<string>& remove_remote_contexts) {
1236   {
1237     tf_shared_lock l(remote_state_mu_);
1238     if (context_id != context_id_) {
1239       return errors::InvalidArgument(
1240           "Failed to update remote master context due to invalid context id. ",
1241           "Request id = ", context_id, " but current id = ", context_id_);
1242     }
1243   }
1244 
1245   if (!remove_remote_contexts.empty()) {
1246     // N.B. remove_remote_contexts include both removed and replaced workers.
1247     // In the case where a worker is replaced by one that resolves to the same
1248     // `hostname:port`, it is safe to close context with the current view id,
1249     // since the newly created context on the remote worker will be holding
1250     // a larger view id and ignores this request.
1251     CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
1252     mutex_lock l(remote_state_mu_);
1253     for (const string& remote_context : remove_remote_contexts) {
1254       remote_contexts_.erase(
1255           std::remove(remote_contexts_.begin(), remote_contexts_.end(),
1256                       remote_context),
1257           remote_contexts_.end());
1258     }
1259   }
1260   if (!add_remote_contexts.empty()) {
1261     mutex_lock l(remote_state_mu_);
1262     remote_contexts_.insert(std::end(remote_contexts_),
1263                             std::begin(add_remote_contexts),
1264                             std::end(add_remote_contexts));
1265   }
1266 
1267   {
1268     mutex_lock l(remote_state_mu_);
1269     context_view_id_++;
1270 
1271     remote_eager_workers_ = std::move(remote_eager_workers);
1272     pflr_->InitializeDeviceAndFlr();
1273     InitPrioritizedDeviceTypeList();
1274 
1275     default_executor_.ClearError();
1276     {
1277       tensorflow::mutex_lock l(executor_map_mu_);
1278       for (auto& entry : thread_local_executor_) {
1279         entry.second->ClearError();
1280       }
1281     }
1282   }
1283 
1284   // Register existing functions to the newly added remote workers. Note that
1285   // this should happen only after updating `remote_contexts_` because new
1286   // functions might be registered while we update the context. When that
1287   // happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
1288   // register the new functions on all remote workers (including the newly added
1289   // ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take care of
1290   // registering existing functions, where duplicate registrations will be
1291   // ignored by the remote workers.
1292   TF_RETURN_IF_ERROR(
1293       RegisterExistingFunctionsOnRemoteWorkers(add_remote_contexts));
1294   return Status::OK();
1295 }
1296 
// Set distributed execution related state in the master context.
//
// Runs entirely under an exclusive lock on `remote_state_mu_`. It marks the
// context as master and records `context_id` / `context_view_id`. It reads
// TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC (default true) and swaps in:
//   * the local device manager and rendezvous;
//   * the server, remote manager, worker env/session and remote eager clients;
//   * the remote device manager and cluster FLR.
// It then clears caches and executor errors, rebuilds the PFLR, and starts the
// keep-alive thread at most once per context.
//
// Ownership notes:
//   * the previous `rendezvous_` is Unref()'d before `r` replaces it;
//   * a previously installed server is intentionally leaked via release()
//     rather than destroyed.
// Always returns OK.
Status EagerContext::SetMasterContextState(
    std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
    std::shared_ptr<WorkerSession> worker_session,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
    uint64 context_view_id, Rendezvous* r, const DeviceMgr* local_device_mgr,
    int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr) {
  mutex_lock l(remote_state_mu_);
  is_master_ = true;
  context_id_ = context_id;
  context_view_id_ = context_view_id;

  use_send_tensor_rpc_ =
      ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);

  // A previously owned device manager is parked in
  // `old_local_device_managers_` instead of being destroyed.
  if (local_device_mgr != local_device_manager_.Get()) {
    if (local_device_manager_.Owned()) {
      old_local_device_managers_.push_back(
          std::move(local_device_manager_.owned_object));
    }
    local_device_manager_.Reset(local_device_mgr);
  }
  host_cpu_device_ = local_device_manager_.Get()->HostCPU();

  if (rendezvous_ != nullptr) rendezvous_->Unref();
  rendezvous_ = r;

  // Memory leak!
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }
  server_ = std::move(server);

  remote_mgr_ = std::move(remote_mgr);
  worker_env_ = worker_env;
  worker_session_ = std::move(worker_session);
  remote_eager_workers_ = std::move(remote_eager_workers);

  remote_device_manager_.Reset(std::move(remote_device_manager));
  ResetClusterFLR(cluster_flr);

  InitPrioritizedDeviceTypeList();

  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }
  // NOTE(review): unlike StoreCollectiveOpsServer, this dereferences `pflr_`
  // and `config` without null checks; it assumes a PFLR already exists.
  const auto* config = pflr_->config();
  ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, config->graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr_.Get());

  // Keep-alive interval: half of `keep_alive_secs`, but never below 1 second.
  keep_alive_secs_ = keep_alive_secs;
  sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
  // Only schedule a single closure.
  // The thread below wakes every `sleep_for_secs_` seconds (or earlier when
  // `keep_alive_thread_cv_` is signalled) and exits once `shutting_down_` is
  // set. While `keep_alive_secs_` > 0 it sends an async KeepAlive RPC carrying
  // `context_id_` to every entry of `remote_contexts_`. Targets without a
  // client are logged and skipped; request/response are freed in the RPC
  // callback.
  // NOTE(review): `sleep_for_secs_` is written here under `remote_state_mu_`
  // but read by the thread under `keep_alive_thread_shutdown_mu_`; confirm
  // this is not a data race. The thread also holds `remote_state_mu_`
  // exclusively while issuing RPCs, although it only reads remote state.
  if (keep_alive_thread_ == nullptr) {
    keep_alive_thread_.reset(
        env_->StartThread({}, "EagerKeepAliveThread", [this]() {
          while (true) {
            {
              {
                mutex_lock l(keep_alive_thread_shutdown_mu_);

                if (shutting_down_) {
                  return;
                }

                keep_alive_thread_cv_.wait_for(
                    l, std::chrono::seconds(sleep_for_secs_));

                if (shutting_down_) {
                  return;
                }
              }
              {
                mutex_lock l(remote_state_mu_);
                if (keep_alive_secs_ > 0) {
                  {
                    for (const auto& worker : remote_contexts_) {
                      core::RefCountPtr<eager::EagerClient> client;
                      Status s =
                          remote_eager_workers_->GetClient(worker, &client);

                      if (!s.ok()) {
                        LOG(WARNING) << "Keep-alive thread was unable to find "
                                        "a client for target "
                                     << worker << ". Got error: " << s;
                        continue;
                      }

                      eager::KeepAliveRequest* request =
                          new eager::KeepAliveRequest;
                      eager::KeepAliveResponse* response =
                          new eager::KeepAliveResponse;

                      request->set_context_id(context_id_);
                      client->KeepAliveAsync(
                          request, response,
                          [request, response](const Status& s) {
                            delete request;
                            delete response;
                          });
                    }
                  }
                }
              }
            }
          }
        }));
  }
  return Status::OK();
}
1418 
InitializeRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,DynamicDeviceMgr * remote_device_mgr,const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id,std::function<Rendezvous * (const int64)> rendezvous_creator,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr,std::function<void ()> resource_deallocator)1419 Status EagerContext::InitializeRemoteWorker(
1420     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1421     DynamicDeviceMgr* remote_device_mgr,
1422     const std::vector<string>& remote_contexts, uint64 context_id,
1423     uint64 context_view_id,
1424     std::function<Rendezvous*(const int64)> rendezvous_creator,
1425     DistributedFunctionLibraryRuntime* cluster_flr,
1426     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1427         remote_mgr,
1428     std::function<void()> resource_deallocator) {
1429   if (context_id == kInvalidContextId) {
1430     return errors::InvalidArgument(
1431         "Failed to initialize remote for worker context due to invalid ",
1432         "context id");
1433   }
1434   mutex_lock l(remote_state_mu_);
1435 
1436   if (remote_device_manager_.Owned() || server_ != nullptr ||
1437       keep_alive_thread_ != nullptr) {
1438     return errors::FailedPrecondition(
1439         "EagerContext::InitializeRemoteWorker Failed. ",
1440         "Already initialized remote as a master context.");
1441   }
1442   is_master_ = false;
1443 
1444   remote_contexts_ = remote_contexts;
1445   context_id_ = context_id;
1446   context_view_id_ = context_view_id;
1447 
1448   rendezvous_creator_ = std::move(rendezvous_creator);
1449   remote_eager_workers_ = std::move(remote_eager_workers);
1450   remote_mgr_ = std::move(remote_mgr);
1451   ResetClusterFLR(cluster_flr);
1452 
1453   remote_device_manager_.Reset(remote_device_mgr);
1454 
1455   const auto* config = pflr_->config();
1456   ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
1457             &func_lib_def_, config->graph_options().optimizer_options(),
1458             thread_pool_.get(), cluster_flr_.Get());
1459   InitPrioritizedDeviceTypeList();
1460 
1461   ClearCachesAndThreadExecutors();
1462   default_executor_.ClearError();
1463   {
1464     tensorflow::mutex_lock l(executor_map_mu_);
1465     for (auto& entry : thread_local_executor_) {
1466       entry.second->ClearError();
1467     }
1468   }
1469 
1470   resource_deallocator_ = std::move(resource_deallocator);
1471 
1472   return Status::OK();
1473 }
1474 
UpdateRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & remote_contexts,uint64 context_id)1475 Status EagerContext::UpdateRemoteWorker(
1476     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1477     const std::vector<string>& remote_contexts, uint64 context_id) {
1478   {
1479     mutex_lock l(remote_state_mu_);
1480     if (context_id != context_id_) {
1481       return errors::InvalidArgument(
1482           "Failed to update remote for worker context due to invalid ",
1483           "context id. Request id = ", context_id,
1484           " but current id = ", context_id_);
1485     }
1486     context_view_id_++;
1487 
1488     remote_contexts_ = remote_contexts;
1489     remote_eager_workers_ = std::move(remote_eager_workers);
1490     InitPrioritizedDeviceTypeList();
1491     pflr_->InitializeDeviceAndFlr();
1492   }
1493 
1494   // No need to update remote_device_manager_ since it's not owned for remote
1495   // worker context (owned by the corresponding worker session).
1496   if (remote_device_manager_.Owned()) {
1497     return errors::FailedPrecondition(
1498         "EagerContext::UpdateRemoteWorker failed because the context was "
1499         "initialized as a master context.");
1500   }
1501 
1502   ClearCachesAndThreadExecutors();
1503   default_executor_.ClearError();
1504   {
1505     tensorflow::mutex_lock l(executor_map_mu_);
1506     for (auto& entry : thread_local_executor_) {
1507       entry.second->ClearError();
1508     }
1509   }
1510   return Status::OK();
1511 }
1512 #endif  // !IS_MOBILE_PLATFORM
1513 
1514 }  // namespace tensorflow
1515