1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/distributed/service.h"
17 
18 #include "absl/time/time.h"
19 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
20 #include "tensorflow/compiler/xla/pjrt/distributed/util.h"
21 #include "tensorflow/compiler/xla/status.h"
22 #include "tensorflow/compiler/xla/util.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/platform/random.h"
25 
26 namespace xla {
27 
DistributedRuntimeServiceImpl(const Options & options)28 DistributedRuntimeServiceImpl::DistributedRuntimeServiceImpl(
29     const Options& options)
30     : options_(options), session_id_(tensorflow::random::New64()) {
31   nodes_.resize(options.num_nodes);
32   local_topologies_.resize(options.num_nodes);
33 }
34 
~DistributedRuntimeServiceImpl()35 DistributedRuntimeServiceImpl::~DistributedRuntimeServiceImpl() {
36   {
37     absl::MutexLock lock(&mu_);
38     state_ = State::kClosed;
39     service_status_ =
40         tensorflow::errors::FailedPrecondition("Service shutting down.");
41     if (!stop_heartbeat_thread_.HasBeenNotified()) {
42       stop_heartbeat_thread_.Notify();
43     }
44   }
45 }
46 
47 // Steals the contents of `local_topologies`.
BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,GlobalTopologyProto * global_topology)48 void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
49                          GlobalTopologyProto* global_topology) {
50   int next_global_device_id = 0;
51   for (LocalTopologyProto& local : local_topologies) {
52     for (DeviceProto& device : *local.mutable_devices()) {
53       device.set_global_device_id(next_global_device_id++);
54     }
55     global_topology->add_nodes()->Swap(&local);
56   }
57 }
58 
ValidateNodeId(int node_id)59 xla::Status DistributedRuntimeServiceImpl::ValidateNodeId(int node_id) {
60   if (node_id < 0) {
61     return xla::InvalidArgument("Invalid node ID %d, must be non-negative",
62                                 node_id);
63   }
64   if (node_id >= options_.num_nodes) {
65     return xla::FailedPrecondition(
66         "Invalid node ID %d, must be in the range [0, %d)", node_id,
67         options_.num_nodes);
68   }
69   return xla::Status::OK();
70 }
71 
ValidateSessionId(uint64 session_id)72 xla::Status DistributedRuntimeServiceImpl::ValidateSessionId(
73     uint64 session_id) {
74   if (session_id != session_id_) {
75     return xla::FailedPrecondition(
76         "Session ID of request %llu does not match active session ID %llu",
77         session_id, session_id_);
78   }
79   return xla::Status::OK();
80 }
81 
Connect(::grpc::ServerContext * context,const ConnectRequest * request,ConnectResponse * response)82 ::grpc::Status DistributedRuntimeServiceImpl::Connect(
83     ::grpc::ServerContext* context, const ConnectRequest* request,
84     ConnectResponse* response) {
85   VLOG(10) << "Connect " << request->DebugString();
86   if (request->protocol_version() != kDistributedRuntimeProtocolVersion) {
87     return ToGrpcStatus(xla::InvalidArgument("Invalid protocol version %d",
88                                              request->protocol_version()));
89   }
90   absl::MutexLock lock(&mu_);
91   if (state_ != State::kInitializing) {
92     // This most likely indicates that a client task was restarted but the
93     // old master is still up. Clients should retry on failure.
94     return ToGrpcStatus(tensorflow::errors::Aborted(
95         "Connect() called when system is not initializing."));
96   }
97   int node_id = request->node_id();
98   xla::Status status = ValidateNodeId(node_id);
99   if (!status.ok()) {
100     return ToGrpcStatus(status);
101   }
102   if (!nodes_[node_id].present) {
103     nodes_[node_id].present = true;
104     ++num_nodes_present_;
105   }
106   nodes_[node_id].client_id = request->client_id();
107 
108   auto all_nodes_present_or_duplicate_request = [&]() {
109     mu_.AssertHeld();
110     return num_nodes_present_ == nodes_.size() ||
111            nodes_[node_id].client_id != request->client_id();
112   };
113   auto connect_timeout = absl::Milliseconds(request->timeout_milliseconds());
114   if (!mu_.AwaitWithTimeout(
115           absl::Condition(&all_nodes_present_or_duplicate_request),
116           connect_timeout)) {
117     nodes_[node_id].present = false;
118     --num_nodes_present_;
119     return ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
120         "Timed out after ", absl::FormatDuration(connect_timeout),
121         " waiting for all nodes to call Connect()"));
122   }
123 
124   if (nodes_[node_id].client_id != request->client_id()) {
125     // This might happen either if two nodes are erroneously configured with the
126     // same ID number, or it might happen if a task fails and is restarted
127     // while we are waiting for nodes to connect. To elaborate on the second
128     // scenario, it would look like this:
129     // * a task calls Connect() with a particular node_id and client_id.
130     // * the task is killed and restarted, or alternatively the client's RPC
131     //   times out and it decides to retry.
132     // * the task calls Connect() again with the same node_id and a different
133     //   client_id.
134     // In this scenario we take whichever client showed up most recently and
135     // evict the client with an out-of-date client ID.
136     return ToGrpcStatus(
137         tensorflow::errors::Aborted("Duplicate node ID ", node_id));
138   }
139 
140   if (node_id == 0) {
141     state_ = State::kRunning;
142     heartbeat_thread_.reset(options_.env->StartThread(
143         tensorflow::ThreadOptions(), "pjrt_service_heartbeat",
144         [this]() { HeartbeatLoop(); }));
145   } else {
146     auto running = [&]() {
147       mu_.AssertHeld();
148       return state_ == State::kRunning;
149     };
150     mu_.Await(absl::Condition(&running));
151   }
152   nodes_[node_id].last_heartbeat = absl::Now();
153   response->set_session_id(session_id_);
154   return ::grpc::Status::OK;
155 }
156 
Shutdown(::grpc::ServerContext * context,const ShutdownRequest * request,ShutdownResponse * response)157 ::grpc::Status DistributedRuntimeServiceImpl::Shutdown(
158     ::grpc::ServerContext* context, const ShutdownRequest* request,
159     ShutdownResponse* response) {
160   VLOG(10) << "Shutdown " << request->DebugString();
161   xla::Status status = ValidateSessionId(request->session_id());
162   if (!status.ok()) {
163     return ToGrpcStatus(status);
164   }
165   absl::MutexLock lock(&mu_);
166   if (state_ != State::kRunning) {
167     if (!service_status_.ok()) {
168       return ToGrpcStatus(service_status_);
169     }
170     return ToGrpcStatus(xla::FailedPrecondition(
171         "Shutdown() called when system is not running."));
172   }
173   int node_id = request->node_id();
174   status = ValidateNodeId(node_id);
175   if (!status.ok()) {
176     return ToGrpcStatus(status);
177   }
178   ++num_nodes_shutting_down_;
179 
180   auto all_nodes_shutting_down = [&]() {
181     mu_.AssertHeld();
182     return num_nodes_shutting_down_ == nodes_.size() || !service_status_.ok();
183   };
184   if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_shutting_down),
185                             options_.shutdown_timeout)) {
186     state_ = State::kClosed;
187     return ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
188         "Timed out after ", absl::FormatDuration(options_.shutdown_timeout),
189         " waiting for all nodes to call Shutdown()"));
190   }
191   state_ = State::kClosed;
192   if (!stop_heartbeat_thread_.HasBeenNotified()) {
193     stop_heartbeat_thread_.Notify();
194   }
195   if (!service_status_.ok()) {
196     return ToGrpcStatus(service_status_);
197   }
198   return ::grpc::Status::OK;
199 }
200 
EnumerateDevices(::grpc::ServerContext * context,const EnumerateDevicesRequest * request,EnumerateDevicesResponse * response)201 ::grpc::Status DistributedRuntimeServiceImpl::EnumerateDevices(
202     ::grpc::ServerContext* context, const EnumerateDevicesRequest* request,
203     EnumerateDevicesResponse* response) {
204   VLOG(10) << "EnumerateDevices " << request->DebugString();
205   xla::Status status = ValidateSessionId(request->session_id());
206   if (!status.ok()) {
207     return ToGrpcStatus(status);
208   }
209   absl::MutexLock lock(&mu_);
210   if (state_ != State::kRunning) {
211     if (!service_status_.ok()) {
212       return ToGrpcStatus(service_status_);
213     }
214     return ToGrpcStatus(xla::FailedPrecondition(
215         "EnumerateDevices() called when system is not running."));
216   }
217   int node_id = request->local_topology().node_id();
218   status = ValidateNodeId(node_id);
219   if (!status.ok()) {
220     return ToGrpcStatus(status);
221   }
222   local_topologies_[node_id] = request->local_topology();
223   ++num_topologies_present_;
224 
225   auto all_topologies_present = [&]() {
226     mu_.AssertHeld();
227     return num_topologies_present_ == nodes_.size() || !service_status_.ok();
228   };
229   if (!mu_.AwaitWithTimeout(absl::Condition(&all_topologies_present),
230                             options_.enumerate_devices_timeout)) {
231     return ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
232         "Timed out after ",
233         absl::FormatDuration(options_.enumerate_devices_timeout),
234         " waiting for all nodes to call EnumerateDevices()"));
235   }
236   if (!service_status_.ok()) {
237     return ToGrpcStatus(service_status_);
238   }
239 
240   if (node_id == 0) {
241     topology_.emplace();
242     BuildGlobalTopology(absl::Span<LocalTopologyProto>(local_topologies_),
243                         &*topology_);
244     local_topologies_.clear();
245   } else {
246     auto topology_ready = [&]() -> bool {
247       mu_.AssertHeld();
248       return topology_.has_value();
249     };
250     mu_.Await(absl::Condition(&topology_ready));
251   }
252   *response->mutable_global_topology() = *topology_;
253   return ::grpc::Status::OK;
254 }
255 
Heartbeat(::grpc::ServerContext * context,const HeartbeatRequest * request,HeartbeatResponse * response)256 ::grpc::Status DistributedRuntimeServiceImpl::Heartbeat(
257     ::grpc::ServerContext* context, const HeartbeatRequest* request,
258     HeartbeatResponse* response) {
259   VLOG(10) << "Heartbeat " << request->DebugString();
260   xla::Status status = ValidateSessionId(request->session_id());
261   if (!status.ok()) {
262     return ToGrpcStatus(status);
263   }
264   absl::MutexLock lock(&mu_);
265   if (state_ != State::kRunning) {
266     if (!service_status_.ok()) {
267       return ToGrpcStatus(service_status_);
268     }
269     return ToGrpcStatus(xla::FailedPrecondition(
270         "Heartbeat() called when system is not running."));
271   }
272   int node_id = request->node_id();
273   status = ValidateNodeId(node_id);
274   if (!status.ok()) {
275     return ToGrpcStatus(status);
276   }
277   nodes_[node_id].last_heartbeat = absl::Now();
278   return ::grpc::Status::OK;
279 }
280 
HeartbeatLoop()281 void DistributedRuntimeServiceImpl::HeartbeatLoop() {
282   while (true) {
283     stop_heartbeat_thread_.WaitForNotificationWithTimeout(
284         options_.heartbeat_interval);
285     VLOG(10) << "Checking heartbeats";
286     if (stop_heartbeat_thread_.HasBeenNotified()) {
287       VLOG(10) << "Heartbeat checking stopped.";
288       return;
289     }
290     absl::Time now = absl::Now();
291     absl::MutexLock lock(&mu_);
292     for (size_t i = 0; i < nodes_.size(); ++i) {
293       // If we haven't heard from the node for a number of heartbeat intervals,
294       // declare that we are unhealthy.
295       VLOG(10) << "Node " << i
296                << " last heartbeat: " << nodes_[i].last_heartbeat;
297       if (nodes_[i].last_heartbeat +
298               options_.max_missing_heartbeats * options_.heartbeat_interval <
299           now) {
300         LOG(INFO) << "Missed heartbeats from node " << i << ". Shutting down.";
301         state_ = State::kClosed;
302         service_status_ = tensorflow::errors::Aborted(
303             "Shutting down due to missed heartbeat from task ", i);
304         return;
305       }
306     }
307   }
308 }
309 
KeyValueGet(::grpc::ServerContext * context,const KeyValueGetRequest * request,KeyValueGetResponse * response)310 ::grpc::Status DistributedRuntimeServiceImpl::KeyValueGet(
311     ::grpc::ServerContext* context, const KeyValueGetRequest* request,
312     KeyValueGetResponse* response) {
313   VLOG(10) << "KeyValueGet " << request->DebugString();
314   xla::Status status = ValidateSessionId(request->session_id());
315   if (!status.ok()) {
316     return ToGrpcStatus(status);
317   }
318   {
319     absl::MutexLock lock(&mu_);
320     if (state_ != State::kRunning) {
321       if (!service_status_.ok()) {
322         return ToGrpcStatus(service_status_);
323       }
324       return ToGrpcStatus(xla::FailedPrecondition(
325           "KeyValueGet() called when system is not running."));
326     }
327   }
328   return key_value_store_.Get(
329       request->key(), absl::Milliseconds(request->timeout_milliseconds()),
330       response->mutable_value());
331 }
332 
KeyValueSet(::grpc::ServerContext * context,const KeyValueSetRequest * request,KeyValueSetResponse * response)333 ::grpc::Status DistributedRuntimeServiceImpl::KeyValueSet(
334     ::grpc::ServerContext* context, const KeyValueSetRequest* request,
335     KeyValueSetResponse* response) {
336   VLOG(10) << "KeyValueSet " << request->DebugString();
337   xla::Status status = ValidateSessionId(request->session_id());
338   if (!status.ok()) {
339     return ToGrpcStatus(status);
340   }
341   {
342     absl::MutexLock lock(&mu_);
343     if (state_ != State::kRunning) {
344       if (!service_status_.ok()) {
345         return ToGrpcStatus(service_status_);
346       }
347       return ToGrpcStatus(xla::FailedPrecondition(
348           "KeyValueSet() called when system is not running; clients must call "
349           "Connect() first"));
350     }
351   }
352   return key_value_store_.Set(request->key(), request->value());
353 }
354 
355 xla::StatusOr<std::unique_ptr<DistributedRuntimeService>>
Get(const std::string & address,std::shared_ptr<::grpc::ServerCredentials> credentials,const DistributedRuntimeServiceImpl::Options & options)356 DistributedRuntimeService::Get(
357     const std::string& address,
358     std::shared_ptr<::grpc::ServerCredentials> credentials,
359     const DistributedRuntimeServiceImpl::Options& options) {
360   auto service = absl::make_unique<DistributedRuntimeService>(options);
361   ::grpc::ServerBuilder builder;
362   builder.AddListeningPort(address, credentials);
363   VLOG(1) << "Distributed runtime service address " << address;
364   builder.RegisterService(&service->impl_);
365   service->server_ = builder.BuildAndStart();
366   if (!service->server_) {
367     return xla::Unknown("Failed to start RPC server");
368   }
369   LOG(INFO) << "Jax service listening on " << address;
370   return service;
371 }
372 
DistributedRuntimeService(const DistributedRuntimeServiceImpl::Options & options)373 DistributedRuntimeService::DistributedRuntimeService(
374     const DistributedRuntimeServiceImpl::Options& options)
375     : impl_(options) {}
376 
~DistributedRuntimeService()377 DistributedRuntimeService::~DistributedRuntimeService() {
378   if (server_) {
379     LOG(INFO) << "Jax service shutting down";
380     server_->Shutdown();
381     server_->Wait();
382   }
383 }
384 
385 }  // namespace xla
386