// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================
//
// Distributed XLA service protocol.
//
// This is a minimal distributed protocol intended for a small set of purposes:
// * barriers to wait for all clients to start up or shut down
// * health checking to detect when clients vanish
// * sharing GPU topology and NCCL communicator state between distributed
//   hosts.
//
// The intention is that a service is started during cluster initialization and
// persists for the lifetime of the cluster.

syntax = "proto3";

package xla;

// Describes a device local to a host.
message DeviceProto {
  // Ordinal of the device on its own host.
  int32 local_device_ordinal = 1;

  // Device name.
  string name = 2;

  // Device vendor.
  string vendor = 3;

  // The following fields are present in the GlobalTopologyProto message
  // returned by EnumerateDevices() but not in the LocalTopologyProto messages
  // passed to EnumerateDevices(). In other words, the coordinator node
  // determines the global device IDs during EnumerateDevices().

  // Globally unique ID number.
  int32 global_device_id = 4;
}

// The devices attached to a single node.
message LocalTopologyProto {
  // ID of the node that owns `devices`.
  int32 node_id = 1;

  // Devices local to that node.
  repeated DeviceProto devices = 2;
}

// The per-node topologies of the whole cluster.
message GlobalTopologyProto {
  // One entry per node.
  repeated LocalTopologyProto nodes = 1;
}

// Request message for DistributedRuntimeService.Connect.
message ConnectRequest {
  // Version of this protocol spoken by the client.
  int32 protocol_version = 1;

  // Timeout, in milliseconds.
  // NOTE(review): presumably bounds how long Connect() may block; confirm
  // against the service implementation.
  int32 timeout_milliseconds = 2;

  // We assume that each node knows its globally-unique node ID, provided by
  // whatever mechanism launches the tasks. Node IDs should form a dense range
  // of integers [0, num_nodes).
  int32 node_id = 3;

  // A unique ID number for the client.
  uint64 client_id = 4;
}

// Response message for DistributedRuntimeService.Connect.
message ConnectResponse {
  // Session identifier.
  // NOTE(review): presumably the value clients echo back in the `session_id`
  // field of every subsequent request; confirm.
  uint64 session_id = 1;
}

// Request message for DistributedRuntimeService.EnumerateDevices.
message EnumerateDevicesRequest {
  // Session identifier.
  uint64 session_id = 1;

  // NOTE(review): field number 2 is skipped in this message. If it was ever
  // assigned in a released version of the protocol, add `reserved 2;` so the
  // number cannot be reused; confirm against the file's history.

  // The calling client's local device topology.
  LocalTopologyProto local_topology = 3;
}

// Response message for DistributedRuntimeService.EnumerateDevices.
message EnumerateDevicesResponse {
  // The global topology sent back to the client.
  GlobalTopologyProto global_topology = 1;
}

// Request message for DistributedRuntimeService.KeyValueGet.
message KeyValueGetRequest {
  // Session identifier.
  uint64 session_id = 1;

  // Key to look up.
  bytes key = 2;

  // Timeout, in milliseconds, after which a blocked lookup gives up.
  int32 timeout_milliseconds = 3;
}

// Response message for DistributedRuntimeService.KeyValueGet.
message KeyValueGetResponse {
  // NOTE(review): presumably true if the key was present before the timeout
  // expired; confirm.
  bool found = 1;

  // NOTE(review): presumably the value associated with the key; likely
  // meaningful only when `found` is true. Confirm.
  bytes value = 2;
}

// Request message for DistributedRuntimeService.KeyValueSet.
message KeyValueSetRequest {
  // Session identifier.
  uint64 session_id = 1;

  // Key to update.
  bytes key = 2;

  // Value to associate with `key`.
  bytes value = 3;
}

// Response message for DistributedRuntimeService.KeyValueSet. Carries no data.
message KeyValueSetResponse {}

// Request message for DistributedRuntimeService.Heartbeat.
message HeartbeatRequest {
  // Session identifier.
  uint64 session_id = 1;

  // ID of the node sending the heartbeat.
  int32 node_id = 2;
}

// Response message for DistributedRuntimeService.Heartbeat. Carries no data.
message HeartbeatResponse {}

// Request message for DistributedRuntimeService.Shutdown.
message ShutdownRequest {
  // Session identifier.
  uint64 session_id = 1;

  // ID of the node that is ready to shut down.
  int32 node_id = 2;
}

// Response message for DistributedRuntimeService.Shutdown. Carries no data.
message ShutdownResponse {}

// Coordination service for the distributed XLA runtime.
service DistributedRuntimeService {
  // Connects a node to the distributed coordinator node. Blocks until all tasks
  // have connected. The service receives the number of nodes to expect as an
  // option passed to its constructor.
  rpc Connect(ConnectRequest) returns (ConnectResponse) {}

  // Blocking enumeration of devices, used by the GPU backend only.
  // In parallel, all clients call EnumerateDevices() with their local device
  // topology, and receive back a global topology in response.
  rpc EnumerateDevices(EnumerateDevicesRequest)
      returns (EnumerateDevicesResponse) {}

  // Health-checking RPC. Workers send heartbeats to the coordinator at regular
  // intervals. If the worker does not hear from the coordinator or the
  // coordinator does not hear from the tasks, the tasks abort.
  rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse) {}

  // Shutdown RPC. Workers send this RPC when they are ready to shut down; the
  // RPC blocks until all tasks have indicated they are ready to shut down,
  // or a timeout is reached.
  rpc Shutdown(ShutdownRequest) returns (ShutdownResponse) {}

  // Simple key-value store used for sharing configuration data.
  // For example, when using NCCL to communicate between multiple GPUs,
  // the NCCL communicator IDs are stored here.

  // Looks up a key in the key-value service. Blocks until the key is present
  // or until `timeout_milliseconds` expires.
  rpc KeyValueGet(KeyValueGetRequest) returns (KeyValueGetResponse) {}

  // Updates the value associated with a key.
  rpc KeyValueSet(KeyValueSetRequest) returns (KeyValueSetResponse) {}
}