1 /* Copyright 2020 Google LLC
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "grpcpp/channel.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
29 
30 namespace xla {
31 
32 class DistributedRuntimeClient {
33  public:
34   struct Options {
35     // This node's global ID. Required.
36     int32 node_id = -1;
37 
38     // Environment used for starting threads.
39     tensorflow::Env* env = tensorflow::Env::Default();
40 
41     // RPC timeout used for RPC that don't have their own timeouts.
42     absl::Duration rpc_timeout = absl::Seconds(120);
43 
44     // Time period for which Connect() should be retried. The client will keep
45     // trying to open the initial connection for this period, even if any
46     // individual Connect() RPC fails. May be zero, in which case Connect() will
47     // only be attempted once.
48     absl::Duration init_timeout = absl::ZeroDuration();
49 
50     // How long to wait for all nodes to call Shutdown(). If the timeout
51     // expires, then shutdown() reports an error and returns control.
52     absl::Duration shutdown_timeout = absl::Seconds(60);
53 
54     // Interval at which the client should send heartbeat RPCs to the
55     // coordinator.
56     absl::Duration heartbeat_interval = absl::Seconds(10);
57 
58     // How many failed heartbeat RPCs may fail due to a possibly-ephemeral
59     // reason before we decide the coordinator has vanished and that we should
60     // shut down.
61     int max_missing_heartbeats = 10;
62 
63     // Callback invoked by the client when notification of a missing heartbeat
64     // is reported by the coordinator, or we have not heard from the coordinator
65     // recently. `coordinator_reported_failure` is true in the former case.
66     // Exposed so tests can override this behavior to something non-fatal.
67     std::function<void(xla::Status, bool coordinator_reported_failure)>
68         missed_heartbeat_callback =
69             [](xla::Status status, bool coordinator_reported_failure) {
70               if (coordinator_reported_failure) {
71                 LOG(QFATAL)
72                     << "Terminating process because the coordinator detected "
73                        "missing heartbeats. This most likely indicates that "
74                        "another task died; see the other task logs for more "
75                        "details. Status: "
76                     << status;
77               } else {
78                 LOG(QFATAL)
79                     << "Terminating process because of missing heartbeat "
80                        "response from the coordinator. This most likely "
81                        "indicates that the coordinator task died; see the "
82                        "coordinator's task logs for more details. Status: "
83                     << status;
84               }
85             };
86 
87     // For testing. Should the client explicitly Shutdown() on destruction?
88     bool shutdown_on_destruction = true;
89   };
90   DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel,
91                            const Options& options);
DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel)92   explicit DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel)
93       : DistributedRuntimeClient(channel, Options()) {}
94   ~DistributedRuntimeClient();
95 
96   // Connects to the master, and blocks until all clients have successfully
97   // connected.
98   // Not thread-safe, i.e., calls to Connect()/Shutdown()/EnumerateDevices()
99   // must be serialized by some other means.
100   xla::Status Connect();
101 
102   // Reports to the master that the client is ready to shutdown, and blocks
103   // until all clients are ready to shutdown or the shutdown timeout expires.
104   // Not thread-safe.
105   xla::Status Shutdown();
106 
107   // Blocking enumeration of global devices. Used by the GPU platform.
108   // Not thread-safe.
109   xla::Status EnumerateDevices(const LocalTopologyProto& local_topology,
110                                GlobalTopologyProto* global_topology);
111 
112   // The following APIs are thread-safe.
113   xla::StatusOr<std::string> BlockingKeyValueGet(std::string key,
114                                                  absl::Duration timeout);
115 
116   xla::Status KeyValueSet(std::string key, std::string value);
117 
118  private:
119   // Entry point for the heartbeat thread.
120   void HeartbeatLoop();
121 
122   const std::unique_ptr<grpc::DistributedRuntimeService::Stub> stub_;
123   const Options options_;
124 
125   // Possible states of the client.
126   // The only legal transitions are downwards in the order below. i.e., there is
127   // no way to reopen a closed client.
128   enum class State {
129     // The client has not yet connected to the server, i.e., had a Connect()
130     // RPC succeed.
131     kNotConnected,
132 
133     // The client is connected to the server and as far as we are aware the
134     // connection is healthy.
135     kConnected,
136 
137     // The client is in the process of shutting down, i.e., Shutdown() has been
138     // called.
139     kShuttingDown,
140 
141     // The client has shut down its server connection, either due to an error
142     // or due to an explicit shutdown.
143     kClosed,
144   };
145 
146   static absl::string_view StateToString(State state);
147 
148   // state_ is protected by a mutex because the heartbeat thread needs to look
149   // at it.
150   absl::Mutex mu_;
151   State state_ ABSL_GUARDED_BY(mu_) = State::kNotConnected;
152 
153   // A unique session ID, assigned by the server during Connect().
154   uint64 session_id_;
155 
156   // Notification that tells the heartbeat thread to stop running.
157   absl::Notification stop_heartbeats_;
158 
159   // Thread responsible for performing heartbeats.
160   std::unique_ptr<tensorflow::Thread> heartbeat_thread_;
161 };
162 
163 }  // namespace xla
164 
165 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
166