1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16 
17 namespace tensorflow {
18 
19 namespace {
20 
21 // A private cache that wraps worker_cache and allows reuse of
22 // WorkerInterface objects.
23 class WorkerFreeListCache : public WorkerCacheInterface {
24  public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)25   explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
26       : wrapped_(std::move(w)) {}
27 
~WorkerFreeListCache()28   ~WorkerFreeListCache() final {
29     for (auto& p : workers_) {
30       wrapped_->ReleaseWorker(p.first, p.second.worker);
31     }
32   }
33 
ListWorkers(std::vector<string> * workers) const34   void ListWorkers(std::vector<string>* workers) const override {
35     wrapped_->ListWorkers(workers);
36   }
37 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const38   void ListWorkersInJob(const string& job_name,
39                         std::vector<string>* workers) const override {
40     wrapped_->ListWorkersInJob(job_name, workers);
41   }
42 
CreateWorker(const string & target)43   WorkerInterface* CreateWorker(const string& target) override {
44     mutex_lock l(mu_);
45     auto p = workers_.find(target);
46     if (p != workers_.end()) {
47       return p->second.worker;
48     }
49     WorkerState state;
50     state.worker = wrapped_->CreateWorker(target);
51     if (state.worker != nullptr) {
52       workers_.insert(std::make_pair(target, state));
53     }
54     return state.worker;
55   }
56 
ReleaseWorker(const string & target,WorkerInterface * worker)57   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
58     // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
59   }
60 
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)61   bool GetDeviceLocalityNonBlocking(const string& device,
62                                     DeviceLocality* locality) override {
63     return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
64   }
65 
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)66   void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
67                               StatusCallback done) override {
68     wrapped_->GetDeviceLocalityAsync(device, locality, done);
69   }
70 
SetLogging(bool active)71   void SetLogging(bool active) override { wrapped_->SetLogging(active); }
72 
ClearLogs()73   void ClearLogs() override { wrapped_->ClearLogs(); }
74 
RetrieveLogs(int64 step_id,StepStats * ss)75   bool RetrieveLogs(int64 step_id, StepStats* ss) override {
76     return wrapped_->RetrieveLogs(step_id, ss);
77   }
78 
79  private:
80   std::unique_ptr<WorkerCacheInterface> wrapped_;
81 
82   // Information kept per created WorkerInterface.
83   struct WorkerState {
84     WorkerInterface* worker;
85     // TODO(jeff,sanjay): Add reference count if we support eviction.
86   };
87 
88   // TODO(jeff,sanjay): Eviction when the map becomes too big.
89   mutex mu_;
90   std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
91 };
92 
93 }  // namespace
94 
// Creates a session that owns its DeviceMgr: `device_mgr` is moved into
// `device_mgr_` and `borrowed_device_mgr_` is left null.
//
// The supplied `worker_cache` is wrapped in a WorkerFreeListCache, so
// WorkerInterface objects created through this session are reused per target.
WorkerSession::WorkerSession(const string& session_name,
                             const string& worker_name,
                             std::unique_ptr<WorkerCacheInterface> worker_cache,
                             std::unique_ptr<DeviceMgr> device_mgr,
                             std::unique_ptr<GraphMgr> graph_mgr)
    : session_name(session_name),
      worker_name(worker_name),
      // Inside this initializer list, unqualified names such as
      // `worker_cache` and `session_name` refer to the constructor
      // parameters, which shadow the members of the same name.
      worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
      graph_mgr(std::move(graph_mgr)),
      // NOTE(review): `this` escapes while the session is still under
      // construction; ClusterFunctionLibraryRuntime presumably must not use
      // it until construction finishes. The second argument is true iff a
      // non-empty session name was supplied; its meaning is defined by
      // ClusterFunctionLibraryRuntime (TODO confirm).
      cluster_flr(
          new ClusterFunctionLibraryRuntime(this, !session_name.empty())),
      device_mgr_(std::move(device_mgr)),
      borrowed_device_mgr_(nullptr) {}
108 
109 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr)110 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
111     const string& session_name, const string& worker_name,
112     std::unique_ptr<WorkerCacheInterface> worker_cache,
113     DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr) {
114   return std::shared_ptr<WorkerSession>(
115       new WorkerSession(session_name, worker_name, std::move(worker_cache),
116                         borrowed_device_mgr, std::move(graph_mgr)));
117 }
118 
// Creates a session that uses a DeviceMgr it does not own.
// `borrowed_device_mgr` is stored as a raw pointer in `borrowed_device_mgr_`
// and `device_mgr_` is left null. Nothing in this file deletes it, so the
// caller presumably must keep it alive for the session's lifetime (TODO
// confirm).
//
// The supplied `worker_cache` is wrapped in a WorkerFreeListCache, so
// WorkerInterface objects created through this session are reused per target.
WorkerSession::WorkerSession(const string& session_name,
                             const string& worker_name,
                             std::unique_ptr<WorkerCacheInterface> worker_cache,
                             DeviceMgr* borrowed_device_mgr,
                             std::unique_ptr<GraphMgr> graph_mgr)
    : session_name(session_name),
      worker_name(worker_name),
      // Inside this initializer list, unqualified names such as
      // `worker_cache` and `session_name` refer to the constructor
      // parameters, which shadow the members of the same name.
      worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
      graph_mgr(std::move(graph_mgr)),
      // NOTE(review): `this` escapes while the session is still under
      // construction; ClusterFunctionLibraryRuntime presumably must not use
      // it until construction finishes. The second argument is true iff a
      // non-empty session name was supplied; its meaning is defined by
      // ClusterFunctionLibraryRuntime (TODO confirm).
      cluster_flr(
          new ClusterFunctionLibraryRuntime(this, !session_name.empty())),
      device_mgr_(nullptr),
      borrowed_device_mgr_(borrowed_device_mgr) {}
132 
~WorkerSession()133 WorkerSession::~WorkerSession() {
134   if (graph_mgr) {
135     Status s = graph_mgr->DeregisterAll();
136     if (!s.ok()) {
137       LOG(WARNING) << "Error during worker session deletion: " << s;
138     }
139   }
140 }
141 
142 }  // namespace tensorflow
143