syntax = "proto3";

package xrt;

import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla.proto";
import "tensorflow/compiler/xla/xla_data.proto";

message DeviceAssignment {
  message ComputationDevice {
    message DeviceMeshCoordinates {
      // The mesh coordinates for the device. Usually (X, Y, Z, Core), in the
      // order in which they are returned in the TopologyProto.
      //  X    = value(0)
      //  Y    = value(1)
      //  Z    = value(2)
      //  Core = value(3)
      repeated int32 value = 1;
    }
    // As many replicas as there are in the replicated computation.
    repeated DeviceMeshCoordinates replica_devices = 1;
  }
  // There are as many ComputationDevice entries as there are computations
  // (i.e., the number of cores per replica).
  repeated ComputationDevice computation_devices = 1;
}
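
// For illustration only (the coordinate values are hypothetical): a
// DeviceAssignment for a single-core computation replicated over two devices,
// placing replica 0 on mesh coordinate (0, 0, 0, 0) and replica 1 on
// (1, 0, 0, 0), could look like this in textproto form:
//
//   computation_devices {
//     replica_devices { value: [0, 0, 0, 0] }
//     replica_devices { value: [1, 0, 0, 0] }
//   }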

// Options for an XLA compilation.
message XLAComputationConfig {
  // The number of replicas the computation will be run on. If this is
  // default (0) it is interpreted as 1.
  int32 num_replicas = 1;
  // The number of "model-parallel" cores per replica. If this is
  // default (0) it is interpreted as 1.
  int32 num_cores_per_replica = 2;
  // Optional metadata about host sends and recvs.
  tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;

  // The arg/result shapes for the whole computation.
  xla.ProgramShapeProto program_shape = 4;
  // The arg/result shapes for each core of a model-parallel
  // computation. per_core_program_shape is optional for a
  // single-core computation.
  repeated xla.ProgramShapeProto per_core_program_shape = 5;
  // Describes how replicated computation instances should be assigned to
  // devices. There are num_cores_per_replica computations, and each one will
  // be sent to, and executed on, the set of replica devices described in the
  // DeviceAssignment proto.
  DeviceAssignment device_assignment = 6;
  // The debugging options to be passed to the XLA compilation process.
  xla.DebugOptions debug_options = 7;

  // Everything inside Experimental is subject to change and is not subject
  // to API stability guarantees in
  // https://www.tensorflow.org/guide/version_compat.
  message Experimental {
    message UpdateIndexPair {
      int32 index = 1;
      bool updated = 2;
    }

    // stateful_input_indices is only useful when using XRT-compiled
    // programs together with standard TensorFlow TPU execution ops, so should
    // be ignored by most clients.
    //
    // Optionally the client can pass information about which inputs
    // to the computation are updates to "stateful" quantities. Each
    // element of stateful_input_indices includes an index indicating
    // which input argument it corresponds to, and a bool indicating
    // whether the value is updated or not. If the XRT computation is
    // going to be used with a TensorFlow TPU execution op then an
    // input index must be present for each input that will correspond
    // to a resource variable in the execution op, and may not be
    // present for any other input.
    repeated UpdateIndexPair stateful_input_indices = 1;
  }

  Experimental experimental = 8;
}
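
// For illustration only (values are hypothetical): a minimal
// XLAComputationConfig for a computation replicated over two single-core
// devices, reusing the DeviceAssignment sketched above, could look like this
// in textproto form:
//
//   num_replicas: 2
//   num_cores_per_replica: 1
//   device_assignment {
//     computation_devices {
//       replica_devices { value: [0, 0, 0, 0] }
//       replica_devices { value: [1, 0, 0, 0] }
//     }
//   }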

// Options and XLA computation for a compilation.
message XLAComputation {
  XLAComputationConfig config = 1;
  xla.HloSnapshot hlo_snapshot = 2;
}

// Literal to allocate space for, and transfer to, device memory.
message XLAAllocation {
  reserved 1;
  xla.LiteralProto value = 2;
}

// Node in a tree describing a tuple constructed from input handles. A
// node is an internal node if tuples is non-empty, in which case
// input_index and release_input_handle are ignored. Otherwise a node
// is a leaf node. Each leaf XLATupleNode is the index of an input
// which corresponds to a handle that will be grafted onto the output
// tuple at that location. If release_input_handle is true that input
// handle will be released and become invalid. Inputs may be repeated,
// in which case leaves of the output tuple will alias. If an input is
// repeated, release_input_handle must be false for every leaf where
// that input appears.
//
// For example, if input 0 has shape {} and input 1 has shape {2,3}
// then the XLATupleNode with structure {1,{0,1}} corresponds to a
// tuple with shape {{2,3},{{},{2,3}}}.
message XLATupleNode {
  int32 input_index = 1;
  bool release_input_handle = 2;
  repeated XLATupleNode tuples = 3;
}
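
// For illustration only: the {1,{0,1}} structure from the example above,
// rendered as an XLATupleNode in textproto form. The root is an internal node
// with two children; the second child is itself an internal node with two
// leaves:
//
//   tuples { input_index: 1 }
//   tuples {
//     tuples { input_index: 0 }
//     tuples { input_index: 1 }
//   }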

message CommonExecutionConfig {
  // The replica index this execution is driving.
  int32 replica_id = 1;
  // Mapping local device ordinals to global replica IDs.
  // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID
  repeated int32 local_replica_mapping = 2;
  // The execution run ID used to correlate different XRT execute operations
  // happening in parallel from different threads.
  int64 run_id = 3;
}
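
// For illustration only (values are hypothetical): a host whose local devices
// 0 and 1 serve global replicas 4 and 5 would describe the mapping as:
//
//   local_replica_mapping: [4, 5]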

// Options for an XLA execution.
message XRTExecutionConfig {
  // Local device to run on. This is present because the execute Op
  // may be placed on a device such as CPU or TPU_SYSTEM that
  // logically manages multiple cores.
  int32 device_ordinal = 1;
  // Which model-parallel computation to run from the compiled bundle.
  int32 core_index_in_replica = 2;
  // Optional key to disambiguate between executions. This is only
  // needed if multiple host send/recvs may be outstanding
  // concurrently with executions.
  string execution_instance_key = 3;
  // If non-zero, rng_seed to reset the core with.
  uint32 rng_seed = 4;
  // If true, release allocation handles on the inputs after running.
  bool release_input_handles = 5;
  // If true, release the handle to the computation after running.
  bool release_compilation_handle = 6;
  // If set to true, and the result shape is a tuple, then instead of returning
  // a single tuple allocation the execution will return a vector of
  // allocations, one for each of the first-level elements of the result tuple.
  bool return_exploded_tuple = 7;
  reserved 8;
  // The common configuration for XRT execute operations.
  CommonExecutionConfig common_config = 9;
}
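
// For illustration only (values are hypothetical): an execution on local
// device 0 that releases its input handles after running and returns a tuple
// result as separate per-element allocations:
//
//   device_ordinal: 0
//   release_input_handles: true
//   return_exploded_tuple: true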

message XRTChainedExecuteConfig {
  // If non-zero, rng_seed to reset the core with.
  uint32 rng_seed = 1;
  // Which model-parallel computation to run from the compiled bundle.
  int32 core_index_in_replica = 2;
  // Optional key to disambiguate between executions. This is only needed if
  // multiple host send/recvs may be outstanding concurrently with executions.
  string execution_instance_key = 3;
  reserved 4;
  // The common configuration for XRT execute operations.
  CommonExecutionConfig common_config = 5;
}

// A single chained execute operation. An operation can either be a device data
// load, or the execution of an existing (as in, previously compiled and
// accessible via its int64 handle) XLA computation.
message XRTChainedExecuteOp {
  // Represents an input for this operation.
  message Input {
    // The index within the XRTChainedExecutePlan.ops post-order of the source
    // operation for this input.
    int64 op_index = 1;
    // The output index of the value generated by the operation at op_index.
    // Zero (the default value) means no index ({}), while if an indexing is
    // required, output_index must be set to index + 1.
    // Thanks proto3!
    int64 output_index = 2;
  }
  // Represents an output of the XRTChainedExecute operation, which originates
  // from the output of this operation.
  message Output {
    // The index within the value generated by this operation which should be
    // forwarded as an XRTChainedExecute output. If output_index is zero (the
    // default value) the whole output will be used as the result. This means
    // that if the output shape is a tuple, the result will be the full tuple.
    // Otherwise the real sub-tuple index will be output_index - 1.
    int64 output_index = 1;
    // The index within the vector of results returned by the XRTChainedExecute
    // operation, where this output should be forwarded.
    int64 result_index = 2;
  }

  oneof op_oneof {
    // The handle to an existing XRT device data.
    int64 data_handle = 1;
    // The handle to an existing XRT compiled computation.
    int64 computation_handle = 2;
  }
  // The outputs of this XRTChainedExecuteOp operation.
  repeated Output outputs = 3;
  // The inputs of this XRTChainedExecuteOp operation. If data_handle is set,
  // there are no inputs.
  repeated Input inputs = 4;
}
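
// For illustration only (the handle value is hypothetical): an op that
// executes an already-compiled computation, feeding it sub-tuple index 1 of
// the value produced by the op at post-order position 0 (so output_index is
// 1 + 1 = 2), and forwarding its whole result (output_index zero) as result 0
// of the XRTChainedExecute operation:
//
//   computation_handle: 42
//   inputs { op_index: 0 output_index: 2 }
//   outputs { output_index: 0 result_index: 0 }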

// Execution plan for the XRTChainedExecute operation.
message XRTChainedExecutePlan {
  // The post-order list of the XRT computations to be executed.
  repeated XRTChainedExecuteOp ops = 1;
}

// The message used to encode the options for the XRTMetricsCollect operation.
message XRTMetricsCollect {
  // A list of regular expressions to match the metric names. Empty means to
  // return all the metrics reported by the collection registry.
  repeated string metrics_regex = 1;
}
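
// For illustration only (the pattern is hypothetical, actual metric names
// depend on the collection registry): collecting only the metrics whose names
// mention compilation:
//
//   metrics_regex: ".*compile.*"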

message Percentiles {
  message Point {
    // In the [0, 100] range.
    double percentile = 1;
    double value = 2;
  }

  // The time (in nanoseconds) of the first sample within the samples buffer.
  uint64 start_nstime = 1;
  // The time (in nanoseconds) of the last sample within the samples buffer.
  uint64 end_nstime = 2;
  // The minimum value of the samples within the samples buffer.
  double min_value = 3;
  // The maximum value of the samples within the samples buffer.
  double max_value = 4;
  // The mean value of the samples within the samples buffer.
  double mean = 5;
  // The standard deviation of the samples within the samples buffer.
  double stddev = 6;
  // The number of samples within the samples buffer.
  uint64 num_samples = 7;
  // The total number of times a value has been posted to this metric.
  uint64 total_samples = 8;
  // The sum of all the posted values.
  double accumulator = 9;
  // The percentile points reported by the metric.
  repeated Point points = 10;
}

message MetricValues {
  enum UnitOfMeasure {
    INVALID = 0;
    NUMBER = 1;
    TIME = 2;
    BYTES = 3;
  }

  // The metric name.
  string name = 1;

  oneof values_oneof {
    Percentiles percentiles_value = 2;
    int64 int64_value = 3;
  }

  UnitOfMeasure unit_of_measure = 4;
}

message MetricsReport {
  repeated MetricValues metrics = 1;
}

message MemoryInfo {
  // The total memory on a device, in KB.
  int64 kb_total = 1;
  // The free memory on a device, in KB.
  int64 kb_free = 2;
}