1 /* Copyright 2020 Google LLC
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/distributed/client.h"
17 
18 #include <chrono>  // NOLINT
19 #include <random>
20 
21 #include "absl/time/time.h"
22 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
23 #include "tensorflow/compiler/xla/pjrt/distributed/util.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/core/platform/random.h"
27 
28 namespace xla {
29 
DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel,const Options & options)30 DistributedRuntimeClient::DistributedRuntimeClient(
31     std::shared_ptr<::grpc::Channel> channel, const Options& options)
32     : stub_(grpc::DistributedRuntimeService::NewStub(std::move(channel))),
33       options_(options) {}
34 
~DistributedRuntimeClient()35 DistributedRuntimeClient::~DistributedRuntimeClient() {
36   bool connected;
37   {
38     absl::MutexLock lock(&mu_);
39     connected = (state_ == State::kConnected);
40   }
41   if (connected) {
42     if (options_.shutdown_on_destruction) {
43       Status status = Shutdown();
44       if (!status.ok()) {
45         LOG(WARNING) << "PJRT shutdown failed: " << status;
46       }
47     } else {
48       if (!stop_heartbeats_.HasBeenNotified()) {
49         stop_heartbeats_.Notify();
50       }
51     }
52   }
53 }
54 
StateToString(State state)55 /*static*/ absl::string_view DistributedRuntimeClient::StateToString(
56     State state) {
57   switch (state) {
58     case State::kNotConnected:
59       return "kNotConnected";
60     case State::kConnected:
61       return "kConnected";
62     case State::kShuttingDown:
63       return "kShuttingDown";
64     case State::kClosed:
65       return "kClosed";
66   }
67 }
68 
Connect()69 xla::Status DistributedRuntimeClient::Connect() {
70   {
71     absl::MutexLock lock(&mu_);
72     if (state_ != State::kNotConnected) {
73       return xla::FailedPrecondition("Connect() called when client in state %s",
74                                      StateToString(state_));
75     }
76   }
77   ConnectRequest request;
78   request.set_protocol_version(kDistributedRuntimeProtocolVersion);
79   request.set_timeout_milliseconds(
80       absl::ToInt64Milliseconds(options_.rpc_timeout) / 2);
81   request.set_node_id(options_.node_id);
82   VLOG(10) << "Connect: " << request.DebugString();
83   ConnectResponse response;
84   ::grpc::Status status;
85   absl::Time deadline = absl::Now() + options_.init_timeout;
86   int attempt = 0;
87   std::default_random_engine generator;
88   std::uniform_real_distribution<double> distribution(0.0, 1.0);
89   do {
90     ::grpc::ClientContext ctx;
91     ctx.set_fail_fast(false);
92     ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
93     request.set_client_id(tensorflow::random::New64());
94     response.Clear();
95     status = stub_->Connect(&ctx, request, &response);
96     if (!status.ok()) {
97       VLOG(1) << "Connect failed() with status: " << FromGrpcStatus(status);
98       if (attempt % 10 == 0) {
99         LOG(INFO) << "Connect failed() with status: " << FromGrpcStatus(status);
100       }
101       // Exponential backoff with jitter. Note we will retry for `init_timeout`
102       // time in total; the `14` here corresponds to an ~16s maximum interval
103       // between connection attempts.
104       int backoff = 1 << std::min(14, attempt);
105       absl::SleepFor(absl::Milliseconds(backoff * distribution(generator)));
106     }
107     ++attempt;
108   } while (!status.ok() && absl::Now() < deadline);
109   if (!status.ok()) {
110     LOG(ERROR) << "Connect() failed after " << attempt << " retries in "
111                << options_.init_timeout
112                << "; most recent failure status: " << FromGrpcStatus(status);
113     return tensorflow::errors::DeadlineExceeded(
114         absl::StrFormat("Connect() timed out after %s with %d attempts. Most "
115                         "recent failure was: %s",
116                         absl::FormatDuration(options_.init_timeout), attempt,
117                         FromGrpcStatus(status).ToString()));
118   }
119   VLOG(10) << "Connect() response: " << response.DebugString();
120   {
121     absl::MutexLock lock(&mu_);
122     state_ = State::kConnected;
123   }
124   session_id_ = response.session_id();
125 
126   heartbeat_thread_.reset(options_.env->StartThread(
127       tensorflow::ThreadOptions(), "pjrt_distributed_heartbeat",
128       [this]() { HeartbeatLoop(); }));
129   LOG(INFO) << "Connected to distributed JAX controller";
130   return xla::Status::OK();
131 }
132 
EnumerateDevices(const LocalTopologyProto & local_topology,GlobalTopologyProto * global_topology)133 xla::Status DistributedRuntimeClient::EnumerateDevices(
134     const LocalTopologyProto& local_topology,
135     GlobalTopologyProto* global_topology) {
136   {
137     absl::MutexLock lock(&mu_);
138     if (state_ != State::kConnected) {
139       return xla::FailedPrecondition(
140           "EnumerateDevices() called when client not connected.");
141     }
142   }
143   ::grpc::ClientContext ctx;
144   ctx.set_fail_fast(false);
145   ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
146   EnumerateDevicesRequest request;
147   request.set_session_id(session_id_);
148   *request.mutable_local_topology() = local_topology;
149   request.mutable_local_topology()->set_node_id(options_.node_id);
150 
151   VLOG(10) << "EnumerateDevices: " << request.DebugString();
152   EnumerateDevicesResponse response;
153   ::grpc::Status status = stub_->EnumerateDevices(&ctx, request, &response);
154   if (!status.ok()) {
155     return FromGrpcStatus(status);
156   }
157   VLOG(10) << "EnumerateDevices() response: " << response.DebugString();
158   response.mutable_global_topology()->Swap(global_topology);
159   return xla::Status::OK();
160 }
161 
Shutdown()162 xla::Status DistributedRuntimeClient::Shutdown() {
163   LOG(INFO) << "Waiting for all distributed JAX tasks to shut down.";
164   ::grpc::ClientContext ctx;
165   {
166     absl::MutexLock lock(&mu_);
167     if (state_ != State::kConnected) {
168       return xla::FailedPrecondition(
169           "Shutdown() called when client not connected.");
170     }
171     state_ = State::kShuttingDown;
172   }
173   ctx.set_fail_fast(false);
174   ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.shutdown_timeout));
175   ShutdownRequest request;
176   request.set_session_id(session_id_);
177   VLOG(10) << "Shutdown: " << request.DebugString();
178   ShutdownResponse response;
179   ::grpc::Status status = stub_->Shutdown(&ctx, request, &response);
180   LOG(INFO) << "Distributed task shutdown result: " << FromGrpcStatus(status);
181   if (!status.ok()) {
182     return FromGrpcStatus(status);
183   }
184   if (!stop_heartbeats_.HasBeenNotified()) {
185     stop_heartbeats_.Notify();
186   }
187   VLOG(10) << "Shutdown() response: " << response.DebugString();
188   absl::MutexLock lock(&mu_);
189   state_ = State::kClosed;
190   return xla::Status::OK();
191 }
192 
BlockingKeyValueGet(std::string key,absl::Duration timeout)193 xla::StatusOr<std::string> DistributedRuntimeClient::BlockingKeyValueGet(
194     std::string key, absl::Duration timeout) {
195   {
196     absl::MutexLock lock(&mu_);
197     if (state_ != State::kConnected) {
198       return xla::FailedPrecondition(
199           "BlockingKeyValueGet() called when client not connected.");
200     }
201   }
202   ::grpc::ClientContext ctx;
203   ctx.set_fail_fast(false);
204   ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout));
205   KeyValueGetRequest request;
206   request.set_session_id(session_id_);
207   request.set_key(std::move(key));
208   timeout = std::min(timeout, absl::Minutes(10));  // Avoid overflow
209   request.set_timeout_milliseconds(timeout / absl::Milliseconds(1));
210   VLOG(10) << "BlockingKeyValueGet: " << request.DebugString();
211   KeyValueGetResponse response;
212   ::grpc::Status status = stub_->KeyValueGet(&ctx, request, &response);
213   if (!status.ok()) {
214     return FromGrpcStatus(status);
215   }
216   return response.value();
217 }
218 
KeyValueSet(std::string key,std::string value)219 xla::Status DistributedRuntimeClient::KeyValueSet(std::string key,
220                                                   std::string value) {
221   {
222     absl::MutexLock lock(&mu_);
223     if (state_ != State::kConnected) {
224       return xla::FailedPrecondition(
225           "KeyValueSet() called when client not connected.");
226     }
227   }
228   ::grpc::ClientContext ctx;
229   ctx.set_fail_fast(false);
230   ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
231   KeyValueSetRequest request;
232   request.set_session_id(session_id_);
233   request.set_key(std::move(key));
234   request.set_value(std::move(value));
235   VLOG(10) << "KeyValueSet: " << request.DebugString();
236   KeyValueSetResponse response;
237   ::grpc::Status status = stub_->KeyValueSet(&ctx, request, &response);
238   return FromGrpcStatus(status);
239 }
240 
HeartbeatLoop()241 void DistributedRuntimeClient::HeartbeatLoop() {
242   int num_missing_heartbeats = 0;
243   while (true) {
244     stop_heartbeats_.WaitForNotificationWithTimeout(
245         options_.heartbeat_interval);
246     if (stop_heartbeats_.HasBeenNotified()) {
247       return;
248     }
249 
250     ::grpc::ClientContext ctx;
251     ctx.set_fail_fast(false);
252     ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
253     HeartbeatRequest request;
254     request.set_session_id(session_id_);
255     request.set_node_id(options_.node_id);
256     VLOG(10) << "Heartbeat: " << request.DebugString();
257     HeartbeatResponse response;
258     ::grpc::Status status = stub_->Heartbeat(&ctx, request, &response);
259     if (status.ok()) {
260       num_missing_heartbeats = 0;
261     } else {
262       ++num_missing_heartbeats;
263       bool is_transient_error =
264           (status.error_code() == ::grpc::StatusCode::DEADLINE_EXCEEDED ||
265            status.error_code() == ::grpc::StatusCode::UNAVAILABLE);
266       if (!stop_heartbeats_.HasBeenNotified() &&
267           (!is_transient_error ||
268            num_missing_heartbeats > options_.max_missing_heartbeats)) {
269         // If we are shutting down, missed heartbeats are benign: they may
270         // simply mean that the server has shut down already before it saw
271         // the heartbeat request.
272         absl::MutexLock lock(&mu_);
273         if (state_ != State::kShuttingDown) {
274           options_.missed_heartbeat_callback(FromGrpcStatus(status),
275                                              !is_transient_error);
276         }
277         return;
278       }
279     }
280   }
281 }
282 
283 }  // namespace xla
284