1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// This proto file defines messages which represent the HLO module. This is a
// full-fidelity serialization of the C++ HLO constructs.
18//
19// Many of the protos below are simple 1-to-1 serializations of the
20// corresponding C++ classes, e.g., HloModule, HloComputation, and
21// HloInstruction.
22//
23// FIELD NAMES ARE IMPORTANT
24//
25// Unlike most protos, you can't safely change the names of fields, even if you
26// keep the numeric ids the same. This is because we sometimes serialize these
27// protos as JSON, which includes the field names in the serialization.
28
29syntax = "proto3";
30
31package xla;
32
33import "tensorflow/compiler/xla/xla_data.proto";
34
35option cc_enable_arenas = true;
36
// Serialization of HloInstruction.
//
// Next ID: 64.
// NOTE(review): field number 39 is neither used nor reserved in this message;
// confirm its history before assigning it to a new field.
message HloInstructionProto {
  // Removed fields. Both the numbers and the names are reserved: the numbers
  // protect the binary wire format, the names protect JSON serializations.
  reserved 10;
  reserved "parameter_name";
  reserved 12;
  reserved "fused_instructions_computation";
  reserved 4;
  reserved "operand_names";
  reserved 5;
  reserved "control_predecessor_names";
  reserved 6;
  reserved "called_computation_names";
  reserved 44;
  reserved "replica_group_ids";

  string name = 1;
  string opcode = 2;
  xla.ShapeProto shape = 3;

  xla.OpMetadata metadata = 7;

  // Literal, only present for kConstant.
  xla.LiteralProto literal = 8;

  // Parameter number, only present for kParameter.
  int64 parameter_number = 9;

  // Fusion state, only present for kFusion.
  string fusion_kind = 11;

  // Index for kGetTupleElement.
  int64 tuple_index = 13;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  repeated int64 dimensions = 14;

  // Describes the window in a windowed operation such as convolution.
  xla.Window window = 15;

  // Describes the dimension numbers used for a convolution.
  xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;

  // The number of feature groups. Used for a convolution. Must be a divisor of
  // the input feature dimension and output feature dimension. If not specified,
  // it will use a default value of 1.
  int64 feature_group_count = 50;

  // NOTE(review): presumably the convolution batch-group count, analogous to
  // feature_group_count above — confirm.
  int64 batch_group_count = 58;

  // Describes the [begin, end) index range and stride for slices.
  message SliceDimensions {
    int64 start = 1;
    int64 limit = 2;
    int64 stride = 3;
  }
  repeated SliceDimensions slice_dimensions = 17;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits = 18;
  int32 mantissa_bits = 19;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  repeated int64 dynamic_slice_sizes = 20;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  xla.PaddingConfig padding_config = 21;

  // Outfeed configuration information, only present for kOutfeed.
  bytes outfeed_config = 22;

  // The distribution requested for random number generation.
  // Only present for kRng.
  xla.RandomDistribution distribution = 23;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon = 24;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index = 25;

  // Represents a unique identifier for each Send/Recv instruction pair.
  // Only present for kSend or kRecv.
  int64 channel_id = 26;

  // The string representation of the infeed configuration.
  bytes infeed_config = 27;

  // Name of an external target (e.g., a global symbol) to call, only present
  // for kCustomCall.
  string custom_call_target = 28;

  // Opaque string, only present for kCustomCall.
  string custom_call_opaque = 53;

  // Shape of outfeed request.
  xla.ShapeProto outfeed_shape = 29;

  // Describes the dimension numbers used for a dot operation.
  xla.DotDimensionNumbers dot_dimension_numbers = 30;

  // FFT type (FFT, IFFT, etc).
  xla.FftType fft_type = 31;

  // FFT length.
  repeated int64 fft_length = 32;

  // Comparison direction, only used for kCompare.
  string comparison_direction = 63;

  // Gather dimension numbers.
  xla.GatherDimensionNumbers gather_dimension_numbers = 33;
  repeated int64 gather_slice_sizes = 34;

  // Compute-host fields: a channel name and a cost estimate.
  // NOTE(review): cost_estimate_ns is presumably in nanoseconds, per its
  // suffix — confirm.
  string channel_name = 41;
  int64 cost_estimate_ns = 42;

  // The id of this instruction.
  int64 id = 35;

  repeated int64 operand_ids = 36;
  repeated int64 control_predecessor_ids = 37;
  repeated int64 called_computation_ids = 38;

  xla.OpSharding sharding = 40;

  // Backend configuration for the instruction. Has backend-specific meaning.
  string backend_config = 43;

  // Cross replica op fields.
  repeated xla.ReplicaGroup replica_groups = 49;
  int64 all_reduce_id = 45;
  string all_reduce_barrier = 46;

  // Whether this Send/Recv instruction transfers data to/from the host. Only
  // present for Send and Recv instructions and their SendDone and RecvDone
  // partners.
  bool is_host_transfer = 47;

  // Whether this Sort instruction should be stable.
  bool is_stable = 60;

  xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;

  // Precision configuration for the instruction. Has backend-specific meaning.
  xla.PrecisionConfig precision_config = 51;

  // Collective permute field.
  repeated xla.SourceTarget source_target_pairs = 52;

  // Sharding for kDomain instructions.
  xla.OpSharding domain_entry_sharding = 54;
  xla.OpSharding domain_exit_sharding = 55;

  // For custom call, this indicates that the layouts are constrained. If
  // constrain_layout is true then the 'shape' field must contain a layout, and
  // 'operand_shapes_with_layout' must contain a shape with layout for each
  // operand.
  bool constrain_layout = 56;
  repeated xla.ShapeProto operand_shapes_with_layout = 57;

  // Options for TriangularSolve.
  xla.TriangularSolveOptions triangular_solve_options = 59;

  // Options for Cholesky.
  xla.CholeskyOptions cholesky_options = 62;

  // Describes how parameters behave with regard to replicas.
  xla.ParameterReplication parameter_replication = 61;
}
213
// Serialization of HloComputation.
message HloComputationProto {
  // Removed field; number and name are reserved so they are never reused.
  reserved 3;
  reserved "root_name";

  // The name of this computation.
  string name = 1;

  // The array of instructions is always in a valid dependency order, where
  // operands appear before their users.
  repeated HloInstructionProto instructions = 2;

  // The program shape (with layout) of this computation.
  xla.ProgramShapeProto program_shape = 4;

  // The id of this computation.
  int64 id = 5;

  // The id of the root of the computation.
  int64 root_id = 6;
}
235
// Serialization of an HLO schedule: for every non-fusion computation in the
// module, the schedule records a total order of that computation's
// instructions.
message HloScheduleProto {
  // The ordered instruction ids forming the schedule of a single computation.
  message InstructionSequence {
    repeated int64 instruction_ids = 1;
  }

  // Keyed by computation id; each value is that computation's sequence.
  map<int64, InstructionSequence> sequences = 1;
}
246
// Describes aliasing between entry-computation parameters (inputs) and the
// root instruction's output.
message HloInputOutputAliasProto {
  enum Kind {
    // Define an UNDEFINED_ALIAS value equal to zero to work around proto3's
    // default-0 behavior and its missing has_*() APIs.
    UNDEFINED_ALIAS = 0;
    // An alias set up by the user as a must-alias. A user setting USER_ALIAS
    // expects the designated output to be placed over (i.e. to alias) the
    // given input parameter number+index.
    USER_ALIAS = 1;
    // An alias set up by the compiler as part of its optimizations.
    SYSTEM_ALIAS = 2;
  }

  // The following proto describes an aliased pair: an input
  // (described by parameter number and a ShapeIndex of the parameter)
  // and an output (described by a ShapeIndex of the root
  // instruction). For example:
  //
  // entry = {
  //  output_shape_index={1},
  //  parameter_number=0,
  //  parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element is
  // aliased with the {1} element of the root instruction.
  message AliasEntryProto {
    // ShapeIndex of the root hlo.
    repeated int64 output_shape_index = 1;
    // Number of the parameter in entry computation.
    int64 parameter_number = 2;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 3;
    // The kind of alias to be set up.
    Kind kind = 4;
  }

  repeated AliasEntryProto entries = 1;
}
286
// Bindings that mark parameter dimensions as dynamic and name the parameter
// element holding each such dimension's real size.
message DynamicParameterBindingProto {
  // A list of bindings which indicates that the `target_param_dim_num` in
  // the subshape `target_param_index` of parameter `target_param_num`
  // is a dynamic dimension and its real dynamic size is represented
  // by `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // As an example, imagine we have a program:
  //
  // ENTRY main {
  //   a = f32[] parameter(0)
  //   b = f32[10] parameter(1)
  //   ROOT root = (f32[], f32[10]) tuple(%a, %b)
  // }
  //
  // Let's say 'b' (param index 1) is a dynamic shape whose input has
  // an upper bound of 10 and real size is determined at runtime. 'a'
  // (param index 0) represents the real size of b's first dimension.
  //
  // In this case, the fields are set in the following way:
  // dynamic_param_num = 0     ('a' holds the size)
  // dynamic_param_index = {}
  // target_param_num = 1      ('b' is the dynamic parameter)
  // target_param_index = {}
  // target_param_dim_num = 0  (b's first dimension)
  message Binding {
    int64 dynamic_param_num = 1;
    repeated int64 dynamic_param_index = 2;
    int64 target_param_num = 3;
    repeated int64 target_param_index = 4;
    int64 target_param_dim_num = 5;
  }

  repeated Binding entries = 1;
}
321
// Serialization of HloModule.
message HloModuleProto {
  // Name of this module.
  string name = 1;

  // Name and id of the entry computation.
  string entry_computation_name = 2;
  int64 entry_computation_id = 6;

  // Every computation of the module, listed in a valid dependency order:
  // a callee always precedes its callers.
  repeated HloComputationProto computations = 3;

  // Host program shape of the entry computation, including layout.
  xla.ProgramShapeProto host_program_shape = 4;

  // Id of this module.
  int64 id = 5;

  // Schedule of this module.
  HloScheduleProto schedule = 7;

  // Aliasing information between the module's inputs and outputs.
  HloInputOutputAliasProto input_output_alias = 8;

  // Dynamic-dimension bindings for parameters (see
  // DynamicParameterBindingProto).
  DynamicParameterBindingProto dynamic_parameter_binding = 9;
}
346
// Serialization of LogicalBuffer.
message LogicalBufferProto {
  // An instruction paired with a shape index; together they uniquely identify
  // a point where a buffer is needed.
  message Location {
    // There is deliberately no module_name: all LogicalBuffers belong to a
    // single HloModule.
    string computation_name = 1;
    string instruction_name = 2;
    repeated int64 shape_index = 3;
  }

  int64 id = 1;
  int64 size = 2;

  // Where this buffer is defined.
  Location defined_at = 3;

  int64 color = 4;
}
367
// Serialization of BufferAllocation.
//
// NOTE(review): field number 4 is neither used nor reserved in this message;
// confirm its history before assigning it to a new field.
message BufferAllocationProto {
  // One LogicalBuffer that has been assigned to this BufferAllocation.
  message Assigned {
    int64 logical_buffer_id = 1;
    int64 offset = 2;
    int64 size = 3;
  }

  int64 index = 1;
  int64 size = 2;
  bool is_thread_local = 3;
  bool is_tuple = 11;
  bool is_entry_computation_parameter = 5;
  bool is_constant = 12;
  int64 parameter_number = 6;
  repeated int64 parameter_shape_index = 10;
  bool maybe_live_out = 7;
  int64 color = 8;

  // The LogicalBuffers assigned to this allocation.
  repeated Assigned assigned = 9;
}
390
// A trace of a HeapSimulator run.
message HeapSimulatorTrace {
  // One action performed by the heap simulator; the trace is a list of these.
  message Event {
    enum Kind {
      // A memory region was allocated for the buffer.
      ALLOC = 0;
      // A memory region was freed for the buffer.
      FREE = 1;
      // The buffer was shared with another (canonical) buffer. This behaves
      // like ALLOC, except that the canonical buffer's memory region is
      // re-used directly instead of a new region being allocated. Several
      // buffers may share with the same canonical buffer; its lifetime is
      // extended to the union of all their lifetimes.
      SHARE_WITH = 2;
    }
    Kind kind = 1;

    // Id of the LogicalBuffer this event applies to.
    int64 buffer_id = 2;

    // The HloInstruction the simulation was processing when this event
    // occurred, identified by computation name and instruction name. For
    // example, buffers defined by instruction A are allocated while A is
    // being processed.
    string computation_name = 3;
    string instruction_name = 4;

    // Id of the canonical LogicalBuffer this buffer shares with. Set only for
    // SHARE_WITH events.
    int64 share_with_canonical_id = 5;
  }

  repeated Event events = 1;
  bool whole_module_simulation = 2;
}
425
// An abstraction representing a set of HLO modules built to run concurrently
// across different devices.
message HloModuleGroupProto {
  // Name of the module group.
  string name = 1;
  // The modules that make up the group.
  repeated HloModuleProto hlo_modules = 2;
}
432
// Serialization of BufferAssignment.
message BufferAssignmentProto {
  // Pairs a source LogicalBuffer with a buffer location that aliases it.
  message BufferAlias {
    int64 source_buffer_id = 1;
    LogicalBufferProto.Location location = 2;
  }

  repeated LogicalBufferProto logical_buffers = 1;
  repeated BufferAlias buffer_aliases = 2;
  repeated BufferAllocationProto buffer_allocations = 3;
  repeated HeapSimulatorTrace heap_simulator_traces = 4;
}
447
// Top-level grouping message: bundles an HLO module together with its buffer
// assignment, i.e. all of the information defined above.
message HloProto {
  // Removed field; number and name are reserved so they are never reused.
  reserved 2;
  reserved "hlo_ordering";

  HloModuleProto hlo_module = 1;
  BufferAssignmentProto buffer_assignment = 3;
}
456
// An HloProto packaged with the arguments it was run with, the result it
// produced, and the platform it executed on. Intended for uses such as
// analysis, replay, and file storage.
message HloSnapshot {
  // The HLO graph.
  HloProto hlo = 1;

  // Arguments passed to the graph.
  repeated xla.LiteralProto arguments = 2;

  // Result produced by the graph.
  xla.LiteralProto result = 3;

  // Name of the platform the graph was run on.
  string execution_platform = 4;
}
473