1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/local_master.h"
17 
18 #include <unordered_map>
19 
20 #include "tensorflow/core/distributed_runtime/master.h"
21 #include "tensorflow/core/platform/mutex.h"
22 
23 namespace tensorflow {
24 
25 namespace {
WaitForNotification(CallOptions * call_options,const int64 default_timeout_in_ms,Notification * n)26 Status WaitForNotification(CallOptions* call_options,
27                            const int64 default_timeout_in_ms, Notification* n) {
28   int64 timeout_in_ms = call_options->GetTimeout();
29   if (timeout_in_ms == 0) {
30     timeout_in_ms = default_timeout_in_ms;
31   }
32   if (timeout_in_ms > 0) {
33     int64 timeout_in_us = timeout_in_ms * 1000;
34     bool notified = WaitForNotificationWithTimeout(n, timeout_in_us);
35     if (!notified) {
36       call_options->StartCancel();
37       // The call has borrowed pointers to the request and response
38       // messages, so we must still wait for the call to complete.
39       n->WaitForNotification();
40       return errors::DeadlineExceeded("Operation timed out.");
41     }
42   } else {
43     n->WaitForNotification();
44   }
45   return Status::OK();
46 }
47 }  // namespace
48 
LocalMaster(Master * master_impl,const int64 default_timeout_in_ms)49 LocalMaster::LocalMaster(Master* master_impl, const int64 default_timeout_in_ms)
50     : master_impl_(master_impl),
51       default_timeout_in_ms_(default_timeout_in_ms) {}
52 
CreateSession(CallOptions * call_options,const CreateSessionRequest * request,CreateSessionResponse * response)53 Status LocalMaster::CreateSession(CallOptions* call_options,
54                                   const CreateSessionRequest* request,
55                                   CreateSessionResponse* response) {
56   Notification n;
57   Status ret;
58   master_impl_->CreateSession(request, response, [&n, &ret](const Status& s) {
59     ret.Update(s);
60     n.Notify();
61   });
62   TF_RETURN_IF_ERROR(
63       WaitForNotification(call_options, default_timeout_in_ms_, &n));
64   return ret;
65 }
66 
ExtendSession(CallOptions * call_options,const ExtendSessionRequest * request,ExtendSessionResponse * response)67 Status LocalMaster::ExtendSession(CallOptions* call_options,
68                                   const ExtendSessionRequest* request,
69                                   ExtendSessionResponse* response) {
70   Notification n;
71   Status ret;
72   master_impl_->ExtendSession(request, response, [&n, &ret](const Status& s) {
73     ret.Update(s);
74     n.Notify();
75   });
76   TF_RETURN_IF_ERROR(
77       WaitForNotification(call_options, default_timeout_in_ms_, &n));
78   return ret;
79 }
80 
PartialRunSetup(CallOptions * call_options,const PartialRunSetupRequest * request,PartialRunSetupResponse * response)81 Status LocalMaster::PartialRunSetup(CallOptions* call_options,
82                                     const PartialRunSetupRequest* request,
83                                     PartialRunSetupResponse* response) {
84   Notification n;
85   Status ret;
86   master_impl_->PartialRunSetup(request, response, [&n, &ret](const Status& s) {
87     ret.Update(s);
88     n.Notify();
89   });
90   TF_RETURN_IF_ERROR(
91       WaitForNotification(call_options, default_timeout_in_ms_, &n));
92   return ret;
93 }
94 
RunStep(CallOptions * call_options,RunStepRequestWrapper * request,MutableRunStepResponseWrapper * response)95 Status LocalMaster::RunStep(CallOptions* call_options,
96                             RunStepRequestWrapper* request,
97                             MutableRunStepResponseWrapper* response) {
98   Notification n;
99   Status ret;
100   master_impl_->RunStep(call_options, request, response,
101                         [&n, &ret](const Status& s) {
102                           ret.Update(s);
103                           n.Notify();
104                         });
105   TF_RETURN_IF_ERROR(
106       WaitForNotification(call_options, default_timeout_in_ms_, &n));
107   return ret;
108 }
109 
CreateRunStepRequest()110 MutableRunStepRequestWrapper* LocalMaster::CreateRunStepRequest() {
111   return new InMemoryRunStepRequest;
112 }
113 
CreateRunStepResponse()114 MutableRunStepResponseWrapper* LocalMaster::CreateRunStepResponse() {
115   return new InMemoryRunStepResponse;
116 }
117 
CloseSession(CallOptions * call_options,const CloseSessionRequest * request,CloseSessionResponse * response)118 Status LocalMaster::CloseSession(CallOptions* call_options,
119                                  const CloseSessionRequest* request,
120                                  CloseSessionResponse* response) {
121   Notification n;
122   Status ret;
123   master_impl_->CloseSession(request, response, [&n, &ret](const Status& s) {
124     ret.Update(s);
125     n.Notify();
126   });
127   TF_RETURN_IF_ERROR(
128       WaitForNotification(call_options, default_timeout_in_ms_, &n));
129   return ret;
130 }
131 
ListDevices(CallOptions * call_options,const ListDevicesRequest * request,ListDevicesResponse * response)132 Status LocalMaster::ListDevices(CallOptions* call_options,
133                                 const ListDevicesRequest* request,
134                                 ListDevicesResponse* response) {
135   Notification n;
136   Status ret;
137   master_impl_->ListDevices(request, response, [&n, &ret](const Status& s) {
138     ret.Update(s);
139     n.Notify();
140   });
141   TF_RETURN_IF_ERROR(
142       WaitForNotification(call_options, default_timeout_in_ms_, &n));
143   return ret;
144 }
145 
Reset(CallOptions * call_options,const ResetRequest * request,ResetResponse * response)146 Status LocalMaster::Reset(CallOptions* call_options,
147                           const ResetRequest* request,
148                           ResetResponse* response) {
149   Notification n;
150   Status ret;
151   master_impl_->Reset(request, response, [&n, &ret](const Status& s) {
152     ret.Update(s);
153     n.Notify();
154   });
155   TF_RETURN_IF_ERROR(
156       WaitForNotification(call_options, default_timeout_in_ms_, &n));
157   return ret;
158 }
159 
MakeCallable(CallOptions * call_options,const MakeCallableRequest * request,MakeCallableResponse * response)160 Status LocalMaster::MakeCallable(CallOptions* call_options,
161                                  const MakeCallableRequest* request,
162                                  MakeCallableResponse* response) {
163   Notification n;
164   Status ret;
165   master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) {
166     ret.Update(s);
167     n.Notify();
168   });
169   TF_RETURN_IF_ERROR(
170       WaitForNotification(call_options, default_timeout_in_ms_, &n));
171   return ret;
172 }
RunCallable(CallOptions * call_options,const RunCallableRequest * request,RunCallableResponse * response)173 Status LocalMaster::RunCallable(CallOptions* call_options,
174                                 const RunCallableRequest* request,
175                                 RunCallableResponse* response) {
176   Notification n;
177   Status ret;
178   master_impl_->RunCallable(call_options, request, response,
179                             [&n, &ret](const Status& s) {
180                               ret.Update(s);
181                               n.Notify();
182                             });
183   TF_RETURN_IF_ERROR(
184       WaitForNotification(call_options, default_timeout_in_ms_, &n));
185   return ret;
186 }
ReleaseCallable(CallOptions * call_options,const ReleaseCallableRequest * request,ReleaseCallableResponse * response)187 Status LocalMaster::ReleaseCallable(CallOptions* call_options,
188                                     const ReleaseCallableRequest* request,
189                                     ReleaseCallableResponse* response) {
190   Notification n;
191   Status ret;
192   master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) {
193     ret.Update(s);
194     n.Notify();
195   });
196   TF_RETURN_IF_ERROR(
197       WaitForNotification(call_options, default_timeout_in_ms_, &n));
198   return ret;
199 }
200 
201 namespace {
get_local_master_registry_lock()202 mutex* get_local_master_registry_lock() {
203   static mutex local_master_registry_lock(LINKER_INITIALIZED);
204   return &local_master_registry_lock;
205 }
206 
207 struct MasterInfo {
208   Master* master;
209   const int64 default_timeout_in_ms;
210 
MasterInfotensorflow::__anon9c61ccac0c11::MasterInfo211   MasterInfo(Master* master, const int64 default_timeout_in_ms)
212       : master(master), default_timeout_in_ms(default_timeout_in_ms) {}
213 };
214 
215 typedef std::unordered_map<string, MasterInfo> LocalMasterRegistry;
local_master_registry()216 LocalMasterRegistry* local_master_registry() {
217   static LocalMasterRegistry* local_master_registry_ = new LocalMasterRegistry;
218   return local_master_registry_;
219 }
220 }  // namespace
221 
222 /* static */
Register(const string & target,Master * master,int64 default_timeout_in_ms)223 void LocalMaster::Register(const string& target, Master* master,
224                            int64 default_timeout_in_ms) {
225   mutex_lock l(*get_local_master_registry_lock());
226   local_master_registry()->insert(
227       {target, MasterInfo(master, default_timeout_in_ms)});
228 }
229 
230 /* static */
Lookup(const string & target)231 std::unique_ptr<LocalMaster> LocalMaster::Lookup(const string& target) {
232   std::unique_ptr<LocalMaster> ret;
233   mutex_lock l(*get_local_master_registry_lock());
234   auto iter = local_master_registry()->find(target);
235   if (iter != local_master_registry()->end()) {
236     ret.reset(new LocalMaster(iter->second.master,
237                               iter->second.default_timeout_in_ms));
238   }
239   return ret;
240 }
241 
242 }  // namespace tensorflow
243