1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/context.h"
17 
18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
20 #include "tensorflow/core/common_runtime/device_resolver_local.h"
21 #include "tensorflow/core/common_runtime/device_set.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #ifndef __ANDROID__
25 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
26 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
27 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
28 #endif
29 #include "tensorflow/core/framework/resource_mgr.h"
30 #include "tensorflow/core/lib/core/blocking_counter.h"
31 #include "tensorflow/core/util/env_var.h"
32 
33 namespace tensorflow {
34 namespace {
35 
ReadBoolFromEnvVar(StringPiece env_var_name,bool default_val)36 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
37   bool val;
38   if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
39     return val;
40   }
41   return default_val;
42 }
43 
44 }  // namespace
45 
EagerContext(const SessionOptions & opts,ContextDevicePlacementPolicy default_policy,bool async,std::unique_ptr<const DeviceMgr> device_mgr,Rendezvous * rendezvous)46 EagerContext::EagerContext(const SessionOptions& opts,
47                            ContextDevicePlacementPolicy default_policy,
48                            bool async,
49                            std::unique_ptr<const DeviceMgr> device_mgr,
50                            Rendezvous* rendezvous)
51     : EagerContext(opts, default_policy, async, device_mgr.release(),
52                    /*device_mgr_owned*/ true, rendezvous) {}
53 
EagerContext(const SessionOptions & opts,ContextDevicePlacementPolicy default_policy,bool async,const DeviceMgr * device_mgr,bool device_mgr_owned,Rendezvous * rendezvous)54 EagerContext::EagerContext(const SessionOptions& opts,
55                            ContextDevicePlacementPolicy default_policy,
56                            bool async, const DeviceMgr* device_mgr,
57                            bool device_mgr_owned, Rendezvous* rendezvous)
58     : policy_(default_policy),
59       devices_(device_mgr->ListDevices()),
60       rendezvous_(rendezvous),
61       thread_pool_(NewThreadPoolFromSessionOptions(opts)),
62       pflr_(new ProcessFunctionLibraryRuntime(
63           device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_,
64           opts.config.graph_options().optimizer_options(), thread_pool_.get())),
65       log_device_placement_(opts.config.log_device_placement()),
66       num_active_steps_(0),
67       async_default_(async),
68       log_memory_(LogMemory::IsEnabled()),
69       env_(opts.env),
70       use_send_tensor_rpc_(false),
71       pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
72           "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
73   if (device_mgr_owned) {
74     local_device_manager_.reset(device_mgr);
75     local_unowned_device_manager_ = nullptr;
76   } else {
77     local_unowned_device_manager_ = device_mgr;
78   }
79   InitDeviceMapAndAsync();
80   runner_ = [this](std::function<void()> closure) {
81     this->thread_pool_->Schedule(std::move(closure));
82   };
83 
84   std::unique_ptr<DeviceResolverInterface> drl(
85       new DeviceResolverLocal(local_device_mgr()));
86   std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
87       opts.config, local_device_mgr(), drl.get(),
88       "/job:localhost/replica:0/task:0"));
89   collective_executor_mgr_.reset(new CollectiveExecutorMgr(
90       opts.config, local_device_mgr(), std::move(drl), std::move(cprl)));
91 }
92 
InitDeviceMapAndAsync()93 void EagerContext::InitDeviceMapAndAsync() {
94   if (async_default_) {
95     executor_.EnableAsync();
96   }
97 
98   for (auto* device : devices_) {
99     devices_map_[device->name()] = device;
100   }
101 
102   if (remote_device_manager_ != nullptr) {
103     for (auto* device : remote_device_manager_->ListDevices()) {
104       if (devices_map_.find(device->name()) == devices_map_.end()) {
105         devices_map_[device->name()] = device;
106         devices_.push_back(device);
107       }
108     }
109   }
110 
111   DeviceSet ds;
112   for (Device* d : devices_) {
113     ds.AddDevice(d);
114   }
115   prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
116 }
117 
Async() const118 bool EagerContext::Async() const {
119   mutex_lock l(async_map_mu_);
120   return gtl::FindWithDefault(thread_local_async_, std::this_thread::get_id(),
121                               async_default_);
122 }
123 
SetAsyncForThread(bool async)124 Status EagerContext::SetAsyncForThread(bool async) {
125   {
126     tensorflow::mutex_lock l(async_map_mu_);
127     thread_local_async_[std::this_thread::get_id()] = async;
128   }
129   if (async) {
130     executor_.EnableAsync();
131   } else {
132     // TODO(agarwal): Currently we add a wait here to handle cases where a
133     // sync op has a control dependency on an async op, and the latter has not
134     // executed yet. This wait can be removed by storing all the control
135     // inputs and waiting for them when executing ops.
136     return executor_.WaitForAllPendingNodes();
137   }
138   return Status::OK();
139 }
140 
ClearCaches()141 Status EagerContext::ClearCaches() {
142   // The executor stores pointers to kernels, so we need to make sure that no
143   // async eager ops are still executing. We lock the cache during this time as
144   // well.
145   mutex_lock ml(cache_mu_);
146   TF_RETURN_IF_ERROR(executor_.WaitForAllPendingNodes());
147   gtl::STLDeleteValues(&kernel_cache_);
148 
149   return Status::OK();
150 }
151 
SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy)152 void EagerContext::SetThreadLocalDevicePlacementPolicy(
153     ContextDevicePlacementPolicy policy) {
154   mutex_lock ml(policy_map_mu_);
155   thread_local_policies_[std::this_thread::get_id()] = policy;
156 }
157 
GetDevicePlacementPolicy()158 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
159   mutex_lock ml(policy_map_mu_);
160   auto policy_map_it = thread_local_policies_.find(std::this_thread::get_id());
161   if (policy_map_it != thread_local_policies_.end()) {
162     return policy_map_it->second;
163   }
164   return policy_;
165 }
166 
167 #ifndef __ANDROID__
CloseRemoteContexts()168 void EagerContext::CloseRemoteContexts() {
169   // Close all remote contexts.
170   std::vector<eager::CloseContextRequest> requests(remote_contexts_.size());
171   std::vector<eager::CloseContextResponse> responses(remote_contexts_.size());
172   BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
173 
174   int i = 0;
175   for (const auto& worker_and_context_id : remote_contexts_) {
176     auto* client =
177         remote_eager_workers_->GetClient(worker_and_context_id.first);
178 
179     requests[i].set_context_id(worker_and_context_id.second);
180     client->CloseContextAsync(
181         &requests[i], &responses[i],
182         [&worker_and_context_id, &counter](const Status& s) {
183           if (!s.ok()) {
184             LOG(ERROR) << "Unable to close remote context with ID "
185                        << worker_and_context_id.second
186                        << " for worker: " << worker_and_context_id.first
187                        << " due to " << s.error_message();
188           }
189           counter.DecrementCount();
190         });
191     i++;
192   }
193 
194   counter.Wait();
195 }
196 #endif
197 
~EagerContext()198 EagerContext::~EagerContext() {
199 #ifndef __ANDROID__
200   if (server_) {
201     // TODO(nareshmodi): Fix this.
202     LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
203                     "Servers don't support clean shutdown.";
204     server_.release();
205   }
206 
207   {
208     mutex_lock l(keep_alive_thread_shutdown_mu_);
209     shutting_down_ = true;
210     keep_alive_thread_cv_.notify_all();
211   }
212   keep_alive_thread_.reset();
213 
214   CloseRemoteContexts();
215 #endif
216 
217   executor_.WaitForAllPendingNodes().IgnoreError();
218   ClearCaches().IgnoreError();
219   rendezvous_->Unref();
220 
221   for (auto& thread : child_threads_) {
222     thread.reset();
223   }
224 }
225 
AddChildThread(std::unique_ptr<Thread> thread)226 void EagerContext::AddChildThread(std::unique_ptr<Thread> thread) {
227   child_threads_.push_back(std::move(thread));
228 }
229 
FindFunctionByName(const string & name)230 bool EagerContext::FindFunctionByName(const string& name) {
231   mutex_lock l(functions_mu_);
232   return func_lib_def_.Find(name) != nullptr;
233 }
234 
FindFunctionOpData(const string & name,const tensorflow::OpRegistrationData ** op_data)235 Status EagerContext::FindFunctionOpData(
236     const string& name, const tensorflow::OpRegistrationData** op_data) {
237   mutex_lock l(functions_mu_);
238   return func_lib_def_.LookUp(name, op_data);
239 }
240 
FindFunctionDef(const string & name)241 const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
242   mutex_lock l(functions_mu_);
243   return func_lib_def_.Find(name);
244 }
245 
FindDeviceByName(const string & name,Device ** result)246 Status EagerContext::FindDeviceByName(const string& name, Device** result) {
247   auto it = devices_map_.find(name);
248   if (it == devices_map_.end()) {
249     return errors::InvalidArgument(name, " unknown device.");
250   }
251   *result = it->second;
252   return Status::OK();
253 }
254 
ClearRunMetadata()255 void EagerContext::ClearRunMetadata() {
256   if (metadata_listener_ != nullptr) {
257     metadata_listener_->BeforeClearRunMetadata();
258   }
259   run_metadata_.Clear();
260 }
261 
RegisterRunMetadataListener(RunMetadataListener * listener)262 Status EagerContext::RegisterRunMetadataListener(
263     RunMetadataListener* listener) {
264   mutex_lock l(metadata_mu_);
265   if (metadata_listener_ != nullptr) {
266     return Status(error::Code::INVALID_ARGUMENT,
267                   "Cannot run two eager profiler at the same time");
268   }
269   metadata_listener_ = listener;
270   return Status::OK();
271 }
272 
ClearRunMetadataListener()273 void EagerContext::ClearRunMetadataListener() {
274   mutex_lock l(metadata_mu_);
275   metadata_listener_ = nullptr;
276 }
277 
StartStep()278 void EagerContext::StartStep() {
279   mutex_lock ml(metadata_mu_);
280   num_active_steps_++;
281   if (step_container_ == nullptr) {
282     step_container_.reset(
283         new ScopedStepContainer(0, [this](const string& name) {
284           for (Device* device : devices_) {
285             device->resource_manager()->Cleanup(name).IgnoreError();
286           }
287         }));
288   }
289 }
290 
EndStep()291 void EagerContext::EndStep() {
292   mutex_lock ml(metadata_mu_);
293   num_active_steps_--;
294   if (num_active_steps_ == 0) {
295     step_container_.reset();
296   }
297 }
298 
StepContainer()299 ScopedStepContainer* EagerContext::StepContainer() {
300   if (num_active_steps_.load() == 0) {
301     return nullptr;
302   }
303   mutex_lock ml(metadata_mu_);
304   return step_container_.get();
305 }
306 
MaybeRegisterFunctionRemotely(const FunctionDef & fdef)307 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
308   if (remote_device_manager_ == nullptr) return Status::OK();
309 #ifndef __ANDROID__
310   BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size()));
311 
312   std::vector<eager::RegisterFunctionRequest> requests(remote_contexts_.size());
313   std::vector<eager::RegisterFunctionResponse> responses(
314       remote_contexts_.size());
315   std::vector<Status> statuses(remote_contexts_.size());
316 
317   int i = 0;
318   for (const auto& target_and_context_id : remote_contexts_) {
319     requests[i].set_context_id(target_and_context_id.second);
320     *requests[i].mutable_function_def() = fdef;
321 
322     auto* eager_client =
323         remote_eager_workers_->GetClient(target_and_context_id.first);
324 
325     eager_client->RegisterFunctionAsync(
326         &requests[i], &responses[i],
327         [i, &statuses, &blocking_counter](const Status& status) {
328           statuses[i] = status;
329           blocking_counter.DecrementCount();
330         });
331 
332     i++;
333   }
334   blocking_counter.Wait();
335 
336   for (int i = 0; i < remote_contexts_.size(); i++) {
337     TF_RETURN_IF_ERROR(statuses[i]);
338   }
339 #endif
340   return Status::OK();
341 }
342 
AddFunctionDef(const FunctionDef & fdef)343 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
344   mutex_lock l(functions_mu_);
345   TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
346 
347   return MaybeRegisterFunctionRemotely(fdef);
348 }
349 
GetCachedKernel(Fprint128 cache_key)350 KernelAndDevice* EagerContext::GetCachedKernel(Fprint128 cache_key) {
351   tf_shared_lock l(cache_mu_);
352   return gtl::FindPtrOrNull(kernel_cache_, cache_key);
353 }
354 
AddKernelToCache(Fprint128 cache_key,KernelAndDevice * kernel)355 void EagerContext::AddKernelToCache(Fprint128 cache_key,
356                                     KernelAndDevice* kernel) {
357   mutex_lock ml(cache_mu_);
358   gtl::InsertOrUpdate(&kernel_cache_, cache_key, kernel);
359 }
360 
ShouldStoreGraphs()361 bool EagerContext::ShouldStoreGraphs() {
362   mutex_lock ml(metadata_mu_);
363   return should_store_graphs_.load() || metadata_listener_ != nullptr;
364 }
365 
ShouldStoreStepStats()366 bool EagerContext::ShouldStoreStepStats() {
367   mutex_lock ml(metadata_mu_);
368   return should_store_step_stats_.load() || metadata_listener_ != nullptr;
369 }
370 
SetShouldStoreGraphs(bool value)371 void EagerContext::SetShouldStoreGraphs(bool value) {
372   mutex_lock ml(metadata_mu_);
373   should_store_graphs_.store(value);
374   if (!value || metadata_listener_ != nullptr) {
375     run_metadata_.Clear();
376   }
377 }
378 
SetShouldStoreStepStats(bool value)379 void EagerContext::SetShouldStoreStepStats(bool value) {
380   mutex_lock ml(metadata_mu_);
381   should_store_step_stats_.store(value);
382   if (!value || metadata_listener_ != nullptr) {
383     run_metadata_.Clear();
384   }
385 }
386 
387 namespace {
GetTaskName(Device * d,string * task_name)388 Status GetTaskName(Device* d, string* task_name) {
389   string ignored;
390   if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
391     return errors::InvalidArgument("Unable to parse device name: ", d->name());
392   }
393 
394   return Status::OK();
395 }
396 }  // namespace
397 
398 #ifndef __ANDROID__
GetClientAndContextID(Device * device,eager::EagerClient ** client,uint64 * context_id)399 Status EagerContext::GetClientAndContextID(Device* device,
400                                            eager::EagerClient** client,
401                                            uint64* context_id) {
402   auto it = device_to_client_cache_.find(device);
403   if (it != device_to_client_cache_.end()) {
404     *client = it->second.first;
405     *context_id = it->second.second;
406   }
407   string device_task_name;
408   TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name));
409 
410   *client = remote_eager_workers_->GetClient(device_task_name);
411 
412   if (*client == nullptr) {
413     return errors::InvalidArgument(
414         "Unable to find eager client corresponding to device ", device->name());
415   }
416 
417   auto context_iterator = remote_contexts_.find(device_task_name);
418   if (context_iterator == remote_contexts_.end()) {
419     return errors::Internal("Unable to find a context for handle on task: ",
420                             device_task_name, ". This should not be possible");
421   }
422   *context_id = context_iterator->second;
423 
424   device_to_client_cache_.insert({device, {*client, *context_id}});
425 
426   return Status::OK();
427 }
428 
// Re-points the context at an externally provided device manager and
// collective executor manager, and takes ownership of `server`.
//
// `device_mgr` and `rpc_collective_executor_mgr` are stored in the "unowned"
// members; the previously owned counterparts are reset. The device list, the
// device map, the kernel cache and the function library runtime are rebuilt
// for the new device manager.
//
// Returns the error from ClearCaches() if clearing fails; in that case pflr_
// and server_ are left unchanged. Returns OK otherwise.
Status EagerContext::StoreCollectiveOpsServer(
    std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
    CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
  // Drop the owned collective executor manager in favor of the provided one.
  collective_executor_mgr_.reset(nullptr);
  unowned_collective_executor_mgr_ = rpc_collective_executor_mgr;

  // Likewise replace the owned local device manager with the provided one.
  local_device_manager_.reset(nullptr);
  local_unowned_device_manager_ = device_mgr;

  // Rebuild the device list and the name -> device map from scratch.
  devices_ = local_unowned_device_manager_->ListDevices();
  devices_map_.clear();

  InitDeviceMapAndAsync();
  TF_RETURN_IF_ERROR(ClearCaches());

  // Recreate the function library runtime against the new device manager.
  // `{}` is passed where the constructor passes the session's optimizer
  // options.
  pflr_.reset(new ProcessFunctionLibraryRuntime(
      local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
      {}, thread_pool_.get()));

  // Memory leak!
  // A previously stored server is released (never destroyed) before the new
  // one is installed.
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }
  server_ = std::move(server);

  return Status::OK();
}
458 
InitializeRemote(std::unique_ptr<ServerInterface> server,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DeviceMgr> remote_device_manager,const gtl::FlatMap<string,uint64> & remote_contexts,Rendezvous * r,DeviceMgr * local_device_mgr,int keep_alive_secs)459 Status EagerContext::InitializeRemote(
460     std::unique_ptr<ServerInterface> server,
461     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
462     std::unique_ptr<DeviceMgr> remote_device_manager,
463     const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
464     DeviceMgr* local_device_mgr, int keep_alive_secs) {
465   mutex_lock l(remote_state_mu_);
466 
467   if (!remote_contexts_.empty()) {
468     CloseRemoteContexts();
469   }
470   remote_contexts_ = remote_contexts;
471 
472   use_send_tensor_rpc_ =
473       ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false);
474 
475   local_unowned_device_manager_ = local_device_mgr;
476   local_device_manager_ = nullptr;
477   pflr_.reset(new ProcessFunctionLibraryRuntime(
478       local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
479       {}, thread_pool_.get()));
480 
481   devices_ = local_unowned_device_manager_->ListDevices();
482   devices_map_.clear();
483 
484   if (rendezvous_ != nullptr) rendezvous_->Unref();
485   rendezvous_ = r;
486 
487   // Memory leak!
488   if (server_ != nullptr) {
489     LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
490                     "Servers don't support clean shutdown.";
491     server_.release();
492   }
493 
494   server_ = std::move(server);
495   remote_eager_workers_ = std::move(remote_eager_workers);
496 
497   active_remote_contexts_.clear();
498   for (const auto& remote_context : remote_contexts_) {
499     active_remote_contexts_.insert(remote_context.second);
500   }
501 
502   device_to_client_cache_.clear();
503   remote_device_manager_ = std::move(remote_device_manager);
504 
505   InitDeviceMapAndAsync();
506 
507   TF_RETURN_IF_ERROR(ClearCaches());
508 
509   keep_alive_secs_ = keep_alive_secs;
510 
511   sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
512 
513   // Only schedule a single closure.
514   if (keep_alive_thread_ == nullptr) {
515     keep_alive_thread_.reset(
516         env_->StartThread({}, "EagerKeepAliveThread", [this]() {
517           while (true) {
518             {
519               {
520                 mutex_lock l(keep_alive_thread_shutdown_mu_);
521                 keep_alive_thread_cv_.wait_for(
522                     l, std::chrono::seconds(sleep_for_secs_));
523 
524                 if (shutting_down_) {
525                   return;
526                 }
527               }
528               {
529                 mutex_lock l(remote_state_mu_);
530                 if (keep_alive_secs_ > 0) {
531                   {
532                     for (const auto& worker_and_context_id : remote_contexts_) {
533                       auto* client = remote_eager_workers_->GetClient(
534                           worker_and_context_id.first);
535 
536                       eager::KeepAliveRequest* request =
537                           new eager::KeepAliveRequest;
538                       eager::KeepAliveResponse* response =
539                           new eager::KeepAliveResponse;
540 
541                       request->set_context_id(worker_and_context_id.second);
542                       client->KeepAliveAsync(
543                           request, response,
544                           [request, response](const Status& s) {
545                             delete request;
546                             delete response;
547                           });
548                     }
549                   }
550                 }
551               }
552             }
553           }
554         }));
555   }
556   return Status::OK();
557 }
558 #endif
559 
560 }  // namespace tensorflow
561