// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================
//
// Distributed XLA service protocol.
//
// This is a minimal distributed protocol intended for a small set of purposes:
// * barriers to wait for all clients to start up or shut down,
// * health checking to detect when clients vanish, and
// * sharing GPU topology and NCCL communicator state between distributed
//   hosts.
//
// The intention is that a service is started during cluster initialization and
// persists for the lifetime of the cluster.

syntax = "proto3";

package xla;

// Describes a device local to a host.
message DeviceProto {
  // Ordinal of the device on its host, going by the field name.
  // NOTE(review): the numbering scheme is not visible in this file — confirm.
  int32 local_device_ordinal = 1;

  // Device name. NOTE(review): the format is not specified in this file.
  string name = 2;

  // Device vendor. NOTE(review): the format is not specified in this file.
  string vendor = 3;

  // The following fields are present in the GlobalTopologyProto message
  // returned by EnumerateDevices() but not in the LocalTopologyProto messages
  // passed to EnumerateDevices(). In other words, the coordinator node
  // determines the global device IDs during EnumerateDevices().
  int32 global_device_id = 4;  // Globally unique ID number.
}

// The devices that belong to a single node.
message LocalTopologyProto {
  // ID of the node that owns `devices`.
  // NOTE(review): presumably the same ID space as ConnectRequest.node_id
  // (a dense range [0, num_nodes)) — confirm.
  int32 node_id = 1;

  // The devices local to this node.
  repeated DeviceProto devices = 2;
}

// The topology of the whole cluster, expressed as a list of per-node
// topologies.
message GlobalTopologyProto {
  // Per-node topologies.
  // NOTE(review): ordering of entries is not specified in this file.
  repeated LocalTopologyProto nodes = 1;
}

// Request message for DistributedRuntimeService.Connect().
message ConnectRequest {
  // Version of the protocol spoken by the client.
  // NOTE(review): how version mismatches are handled is not visible in this
  // file — confirm against the service implementation.
  int32 protocol_version = 1;

  // Timeout, in milliseconds.
  // NOTE(review): presumably bounds how long Connect() may block while
  // waiting for all nodes to connect — confirm.
  int32 timeout_milliseconds = 2;

  // We assume that each node knows its globally-unique node ID, provided by
  // whatever mechanism launches the tasks. Node IDs should form a dense range
  // of integers [0, num_nodes).
  int32 node_id = 3;

  // A unique ID number for the client.
  uint64 client_id = 4;
}

// Response message for DistributedRuntimeService.Connect().
message ConnectResponse {
  // Session identifier issued by the service.
  // NOTE(review): presumably the value clients pass back in the `session_id`
  // field of every subsequent request — confirm.
  uint64 session_id = 1;
}

// Request message for DistributedRuntimeService.EnumerateDevices().
message EnumerateDevicesRequest {
  // Session identifier.
  // NOTE(review): presumably the value returned in ConnectResponse — confirm.
  uint64 session_id = 1;

  // The calling node's local device topology.
  // NOTE(review): field number 2 is skipped in this message. If it belonged to
  // a removed field, add `reserved 2;` so the number is never reused.
  LocalTopologyProto local_topology = 3;
}

// Response message for DistributedRuntimeService.EnumerateDevices().
message EnumerateDevicesResponse {
  // The global topology returned to the caller in exchange for its local
  // topology.
  GlobalTopologyProto global_topology = 1;
}

// Request message for DistributedRuntimeService.KeyValueGet().
message KeyValueGetRequest {
  // Session identifier.
  // NOTE(review): presumably the value returned in ConnectResponse — confirm.
  uint64 session_id = 1;

  // The key to look up.
  bytes key = 2;

  // How long, in milliseconds, the lookup may block waiting for the key to
  // become present.
  int32 timeout_milliseconds = 3;
}

// Response message for DistributedRuntimeService.KeyValueGet().
message KeyValueGetResponse {
  // Whether the key was found.
  // NOTE(review): presumably false when the request timed out before the key
  // was set — confirm against the service implementation.
  bool found = 1;

  // The value associated with the key.
  // NOTE(review): presumably only meaningful when `found` is true — confirm.
  bytes value = 2;
}

// Request message for DistributedRuntimeService.KeyValueSet().
message KeyValueSetRequest {
  // Session identifier.
  // NOTE(review): presumably the value returned in ConnectResponse — confirm.
  uint64 session_id = 1;

  // The key to update.
  bytes key = 2;

  // The value to associate with `key`.
  bytes value = 3;
}

// Response message for DistributedRuntimeService.KeyValueSet(). Currently
// carries no fields.
message KeyValueSetResponse {}

// Request message for DistributedRuntimeService.Heartbeat().
message HeartbeatRequest {
  // Session identifier.
  // NOTE(review): presumably the value returned in ConnectResponse — confirm.
  uint64 session_id = 1;

  // ID of the node sending the heartbeat.
  // NOTE(review): presumably the same ID space as ConnectRequest.node_id —
  // confirm.
  int32 node_id = 2;
}
// Response message for DistributedRuntimeService.Heartbeat(). Currently
// carries no fields.
message HeartbeatResponse {}

// Request message for DistributedRuntimeService.Shutdown().
message ShutdownRequest {
  // Session identifier.
  // NOTE(review): presumably the value returned in ConnectResponse — confirm.
  uint64 session_id = 1;

  // ID of the node that is ready to shut down.
  // NOTE(review): presumably the same ID space as ConnectRequest.node_id —
  // confirm.
  int32 node_id = 2;
}
// Response message for DistributedRuntimeService.Shutdown(). Currently carries
// no fields.
message ShutdownResponse {}

// Service exposed by the distributed coordinator node. It provides
// connect/shutdown barriers, device enumeration, health checking, and a simple
// key-value store.
service DistributedRuntimeService {
  // Connects a node to the distributed coordinator node. Blocks until all tasks
  // have connected. The service receives the number of nodes to expect as an
  // option passed to its constructor.
  rpc Connect(ConnectRequest) returns (ConnectResponse) {}

  // Blocking enumeration of devices, used by the GPU backend only.
  // In parallel, all clients call EnumerateDevices() with their local device
  // topology, and receive back a global topology in response.
  rpc EnumerateDevices(EnumerateDevicesRequest)
      returns (EnumerateDevicesResponse) {}

  // Health-checking RPC. Workers send heartbeats to the coordinator at regular
  // intervals. If the worker does not hear from the coordinator or the
  // coordinator does not hear from the tasks, the tasks abort.
  rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse) {}

  // Shutdown RPC. Workers send this RPC when they are ready to shut down; the
  // RPC blocks until all tasks have indicated they are ready to shut down,
  // or a timeout is reached.
  rpc Shutdown(ShutdownRequest) returns (ShutdownResponse) {}

  // Simple key-value store used for sharing configuration data.
  // For example, when using NCCL to communicate between multiple GPUs,
  // the NCCL communicator IDs are stored here.

  // Looks up a key in the key-value service. Blocks until the key is present
  // or until `timeout_milliseconds` expires.
  rpc KeyValueGet(KeyValueGetRequest) returns (KeyValueGetResponse) {}

  // Updates the value associated with a key.
  rpc KeyValueSet(KeyValueSetRequest) returns (KeyValueSetResponse) {}
}