1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// This proto file defines messages which represent the HLO module. This is a
// full-fidelity serialization of the C++ HLO constructs.
18//
19// Many of the protos below are simple 1-to-1 serializations of the
20// corresponding C++ classes, e.g., HloModule, HloComputation, and
21// HloInstruction.
22//
23// FIELD NAMES ARE IMPORTANT
24//
25// Unlike most protos, you can't safely change the names of fields, even if you
26// keep the numeric ids the same. This is because we sometimes serialize these
27// protos as JSON, which includes the field names in the serialization.
28
29syntax = "proto3";
30
31package xla;
32
33import "tensorflow/compiler/xla/xla_data.proto";
34
35option cc_enable_arenas = true;
36
// Serialization of HloInstruction.
// Next ID: 76
//
// NOTE(review): field number 39 is currently neither used nor reserved in this
// message. Before assigning it to a new field, confirm it was never used in a
// released version; if it was, reserve it instead.
message HloInstructionProto {
  // Numbers and names of fields that are no longer in use. They stay reserved
  // so they can never be reused: reusing a number would misinterpret old
  // serialized data, and reusing a name would collide in the JSON form.
  reserved 10;
  reserved "parameter_name";
  reserved 12;
  reserved "fused_instructions_computation";
  reserved 4;
  reserved "operand_names";
  reserved 5;
  reserved "control_predecessor_names";
  reserved 6;
  reserved "called_computation_names";
  reserved 44;
  reserved "replica_group_ids";
  // Use backend_config instead for custom_call_opaque.
  reserved 53;
  reserved "custom_call_opaque";
  // Use backend_config instead for all_reduce_barrier.
  reserved 46;
  reserved "all_reduce_barrier";

  // Name of the instruction.
  string name = 1;
  // Opcode of the instruction, as a string.
  string opcode = 2;
  // Shape of the instruction.
  xla.ShapeProto shape = 3;

  // Op metadata attached to the instruction.
  xla.OpMetadata metadata = 7;

  // Literal, only present for kConstant.
  xla.LiteralProto literal = 8;

  // Parameter number is only present for kParameter.
  int64 parameter_number = 9;

  // Fusion state, only present for kFusion.
  string fusion_kind = 11;

  // Index for kGetTupleElement.
  int64 tuple_index = 13;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  repeated int64 dimensions = 14;

  // Describes the window in a windowed operation such as convolution.
  xla.Window window = 15;

  // Describes the dimension numbers used for a convolution.
  xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;

  // The number of feature groups. Used for a convolution. Must be a divisor of
  // the input feature dimension and output feature dimension. If not specified,
  // it will use a default value of 1.
  int64 feature_group_count = 50;

  // The number of batch groups. NOTE(review): presumably used for convolution
  // in the same way as feature_group_count — confirm semantics and default.
  int64 batch_group_count = 58;

  // Describes the [begin, end) index range and stride for slices.
  message SliceDimensions {
    // Start of the range (inclusive).
    int64 start = 1;
    // End of the range (exclusive).
    int64 limit = 2;
    // Step between selected indices.
    int64 stride = 3;
  }
  repeated SliceDimensions slice_dimensions = 17;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits = 18;
  int32 mantissa_bits = 19;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  repeated int64 dynamic_slice_sizes = 20;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  xla.PaddingConfig padding_config = 21;

  // Outfeed configuration information, only present for kOutfeed.
  bytes outfeed_config = 22;

  // The distribution requested for random number generation.
  // Only present for kRng.
  xla.RandomDistribution distribution = 23;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon = 24;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index = 25;

  // Represents a unique identifier for each Send/Recv instruction pair or
  // optionally for collective instructions (AllReduce, CollectivePermute,
  // AllToAll). Non-positive channel_id is equivalent to no channel id.
  int64 channel_id = 26;

  // The string representation of the infeed configuration.
  bytes infeed_config = 27;

  // Name of an external target (e.g., a global symbol) to call, only present
  // for kCustomCall.
  string custom_call_target = 28;

  // Shape of outfeed request.
  xla.ShapeProto outfeed_shape = 29;

  // Describes the dimension numbers used for a dot operation.
  xla.DotDimensionNumbers dot_dimension_numbers = 30;

  // FFT type (FFT, IFFT, etc).
  xla.FftType fft_type = 31;

  // FFT length.
  repeated int64 fft_length = 32;

  // Comparison direction only used for kCompare.
  string comparison_direction = 63;

  // Gather dimension numbers.
  xla.GatherDimensionNumbers gather_dimension_numbers = 33;
  repeated int64 gather_slice_sizes = 34;

  // Compute Host.
  // NOTE(review): the exact meaning of channel_name and cost_estimate_ns is
  // not documented here; the _ns suffix suggests nanoseconds — confirm.
  string channel_name = 41;
  int64 cost_estimate_ns = 42;

  // The id of this instruction.
  int64 id = 35;

  // Ids of this instruction's operands, control predecessors, and the
  // computations it calls.
  repeated int64 operand_ids = 36;
  repeated int64 control_predecessor_ids = 37;
  repeated int64 called_computation_ids = 38;

  // Sharding of the instruction.
  xla.OpSharding sharding = 40;

  // Backend configuration for the instruction. Has backend-specific meaning.
  bytes backend_config = 43;

  // Cross replica op fields.
  repeated xla.ReplicaGroup replica_groups = 49;
  // Deprecated, but keeping it for backward compatibility. Use channel_id.
  // Non-positive all_reduce_id is equivalent to no all_reduce_id.
  int64 all_reduce_id = 45 [deprecated = true];

  // If true, interprets ids in ReplicaGroup as global device ids, which is
  // a linearized id of `replica_id * partition_count + partition_id`.
  bool use_global_device_ids = 71;

  // Whether this Send/Recv instruction transfers data to/from the host. Only
  // present for Send and Recv instructions and their SendDone and RecvDone
  // partners.
  bool is_host_transfer = 47;

  // Whether this Sort instruction should be stable.
  bool is_stable = 60;

  // Scatter dimension numbers.
  xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;

  // Precision configuration for the instruction. Has backend-specific meaning.
  xla.PrecisionConfig precision_config = 51;

  // Collective permute field.
  repeated xla.SourceTarget source_target_pairs = 52;

  // Sharding for kDomain instructions.
  xla.OpSharding domain_entry_sharding = 54;
  xla.OpSharding domain_exit_sharding = 55;

  // For custom call this indicates that the layouts are constrained. If
  // constrain_layout is true then the 'shape' field must contain a layout, and
  // 'operand_shapes_with_layout' must contain a shape with layout for each
  // operand.
  bool constrain_layout = 56;
  repeated xla.ShapeProto operand_shapes_with_layout = 57;

  // Options for TriangularSolve.
  xla.TriangularSolveOptions triangular_solve_options = 59;

  // Options for Cholesky.
  xla.CholeskyOptions cholesky_options = 62;

  // Describes how parameters behave with regard to replicas.
  xla.ParameterReplication parameter_replication = 61;

  // If set, the given instruction is run in parallel on e.g. multiple CPU
  // cores.  The outermost dimension gets split up into
  // outer_dimension_partitions[0] pieces, the next-outermost dim gets split
  // into outer_dimension_partitions[1] pieces, etc.
  //
  // It's illegal to partition a dimension into more shards than there are
  // elements in that dimension.
  repeated int64 outer_dimension_partitions = 64;

  // Whether the kCustomCall instruction has side-effects, only present for
  // kCustomCall.
  bool custom_call_has_side_effect = 65;

  // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
  // buffers between output and operands for kCustomCall.
  repeated xla.CustomCallOutputOperandAliasing
      custom_call_output_operand_aliasing = 74;

  // The delta value for kRngGetAndUpdateState.
  int64 delta = 66;

  // Specifies if the gather/scatter indices are guaranteed to be sorted by the
  // caller.
  bool indices_are_sorted = 67;

  // Frontend attributes to pass to the XLA backend.
  xla.FrontendAttributes frontend_attributes = 68;

  // Specifies if all elements updated are guaranteed to be unique by
  // the caller.
  bool unique_indices = 69;

  // RNG algorithm used by kRngBitGenerator.
  xla.RandomAlgorithm rng_algorithm = 70;

  // The comparison type used for kCompare.
  string comparison_type = 72;

  // Specifies if this is a cross-program-prefetch, used by kCopyStart.
  bool is_cross_program_prefetch = 73;

  // If a convolution is dynamic, a dynamic padding type will be specified.
  xla.PaddingType padding_type = 75;
}
266
// Serialization of HloComputation.
message HloComputationProto {
  // Field number and name that are no longer in use; reserved so they can
  // never be reused.
  reserved 3;
  reserved "root_name";

  // Name of the computation.
  string name = 1;

  // The array of instructions is always in a valid dependency order, where
  // operands appear before their users.
  repeated HloInstructionProto instructions = 2;

  // The program shape (with layout) of this computation.
  xla.ProgramShapeProto program_shape = 4;

  // The id of this computation.
  int64 id = 5;

  // The id of the root instruction of the computation.
  // NOTE(review): presumably equal to the `id` of one of `instructions` —
  // confirm.
  int64 root_id = 6;
}
288
// Serialization of an HLO schedule. An HLO schedule contains a total order of
// instructions for each non-fusion computation in the module.
message HloScheduleProto {
  // The scheduled order of the instructions of a single computation.
  message InstructionSequence {
    // Instruction ids, in scheduled order.
    // NOTE(review): presumably these are HloInstructionProto.id values —
    // confirm.
    repeated int64 instruction_ids = 1;
  }

  // Map from computation id to sequence.
  map<int64, InstructionSequence> sequences = 1;
}
299
// Kind of aliasing between an entry-computation parameter buffer and an output
// buffer. Used by HloInputOutputAliasProto.AliasEntryProto.kind.
//
// NOTE(review): this enum and its unprefixed values are declared at package
// (xla) scope, which makes name collisions easy. Renaming them would break
// generated code and text/JSON serializations, so they are left unchanged.
enum Kind {
  // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
  // behavior and missing has_*() APIs.
  UNDEFINED_ALIAS = 0;
  // The buffers may or may not alias at runtime.
  MAY_ALIAS = 1;
  // The buffers must alias at runtime.
  MUST_ALIAS = 2;
}
309
// Describes which parts of the entry computation's parameters alias parts of
// the root instruction's output.
message HloInputOutputAliasProto {
  // The following proto describes a pair of aliased buffers: an input
  // (described by parameter number and a ShapeIndex of the parameter)
  // and an output (described by a ShapeIndex of the root
  // instruction). For example:
  //
  // entry = {
  //  output_shape_index={1},
  //  parameter_number=0,
  //  parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element is
  // aliased with the {1} element of the root instruction.
  message AliasEntryProto {
    // ShapeIndex of the root hlo.
    repeated int64 output_shape_index = 1;
    // Number of the parameter in entry computation.
    int64 parameter_number = 2;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 3;
    // The kind of alias to be setup.
    Kind kind = 4;
  }

  // The alias entries, one per aliased (input, output) pair.
  repeated AliasEntryProto entries = 1;
}
337
message DynamicParameterBindingProto {
  // A list of bindings which indicates that the `target_param_dim_num` in
  // the subshape `target_param_index` of parameter `target_param_num`
  // is a dynamic dimension and its real dynamic size is represented
  // by `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // As an example, imagine we have a program:
  //
  // ENTRY main {
  //   a = f32[] parameter(0)
  //   b = f32[10] parameter(1)
  //   ROOT root = (f32[], f32[10]) tuple(%a, %b)
  // }
  //
  // Let's say 'b' (param index 1) is a dynamic shape whose input has
  // an upperbound of 10 and real size is determined at runtime. 'a'
  // represents the real size of b's first dimension.
  //
  // In this case, the fields are set in the following way ('a' holds the
  // size, so it is the dynamic param; 'b' has the dynamic dimension, so it
  // is the target):
  // dynamic_param_num = 0
  // dynamic_param_index = {}
  // target_param_num = 1
  // target_param_index = {}
  // target_param_dim_num = 0
  message Binding {
    // Number of the parameter that holds the dynamic size.
    int64 dynamic_param_num = 1;
    // ShapeIndex, within that parameter, of the dynamic size value.
    repeated int64 dynamic_param_index = 2;
    // Number of the parameter that has the dynamic dimension.
    int64 target_param_num = 3;
    // ShapeIndex, within that parameter, of the subshape with the dynamic
    // dimension.
    repeated int64 target_param_index = 4;
    // Index of the dynamic dimension within that subshape.
    int64 target_param_dim_num = 5;
  }

  // The bindings, one per dynamic dimension.
  repeated Binding entries = 1;
}
372
// Describes one cross-program prefetch (see
// HloModuleProto.cross_program_prefetches).
// NOTE(review): the field meanings below are inferred from their names —
// confirm against the C++ HloModule implementation.
message CrossProgramPrefetch {
  // Presumably the number of the entry-computation parameter being prefetched.
  int64 parameter = 1;
  // Presumably the ShapeIndex, within that parameter, of the prefetched buffer.
  repeated int64 index = 2;
}
377
// Serialization of HloModule.
message HloModuleProto {
  // Name of the module.
  string name = 1;
  // Name and id of the entry computation.
  // NOTE(review): presumably these match the name/id of one of `computations`
  // — confirm.
  string entry_computation_name = 2;
  int64 entry_computation_id = 6;

  // The array of computations is always in a valid dependency order, where
  // callees appear before their callers.
  repeated HloComputationProto computations = 3;

  // The host program shape (with layout) of the entry computation.
  xla.ProgramShapeProto host_program_shape = 4;

  // The id of this module.
  int64 id = 5;

  // The schedule for this module.
  HloScheduleProto schedule = 7;

  // Describes alias information between inputs and outputs.
  HloInputOutputAliasProto input_output_alias = 8;

  // Describes which parameter dimensions are dynamic and which parameters
  // hold their runtime sizes.
  DynamicParameterBindingProto dynamic_parameter_binding = 9;

  // Cross-program prefetches for this module.
  repeated CrossProgramPrefetch cross_program_prefetches = 10;
}
404
// Serialization of LogicalBuffer.
message LogicalBufferProto {
  // Location represents an instruction and its shape index, which uniquely
  // identifies a point where a buffer is needed.
  message Location {
    // NOTE: module_name isn't necessary, since all LogicalBuffers are
    // associated with a single HloModule.
    string computation_name = 1;
    string instruction_name = 2;
    // Shape index within the instruction's shape.
    repeated int64 shape_index = 3;
  }

  // Id of this logical buffer.
  int64 id = 1;
  // Size of the buffer. NOTE(review): presumably in bytes — confirm.
  int64 size = 2;

  // The location where the buffer is defined.
  Location defined_at = 3;

  // Color of the buffer. NOTE(review): its meaning is not documented here.
  int64 color = 4;
}
425
// Serialization of BufferAllocation.
//
// NOTE(review): field number 4 is neither used nor reserved in this message.
// Confirm whether it was used before; if so, reserve it.
message BufferAllocationProto {
  // Assigned represents a single LogicalBuffer that is assigned to this
  // BufferAllocation.
  message Assigned {
    // Id of the assigned LogicalBuffer.
    int64 logical_buffer_id = 1;
    // Offset and size of the logical buffer within this allocation.
    // NOTE(review): presumably in bytes — confirm.
    int64 offset = 2;
    int64 size = 3;
  }

  // Index of this allocation.
  int64 index = 1;
  // Size of the allocation. NOTE(review): presumably in bytes — confirm.
  int64 size = 2;
  bool is_thread_local = 3;
  bool is_tuple = 11;
  // NOTE(review): parameter_number and parameter_shape_index are presumably
  // only meaningful when is_entry_computation_parameter is true — confirm.
  bool is_entry_computation_parameter = 5;
  bool is_constant = 12;
  int64 parameter_number = 6;
  repeated int64 parameter_shape_index = 10;
  bool maybe_live_out = 7;
  int64 color = 8;
  // The logical buffers assigned to this allocation.
  repeated Assigned assigned = 9;
}
448
// A trace of a HeapSimulator run.
message HeapSimulatorTrace {
  // The trace includes a list of events, where each event describes one action
  // performed by the heap simulator.
  message Event {
    // The kind of action recorded by the event.
    enum Kind {
      ALLOC = 0;  // A memory region was allocated for the buffer.
      FREE = 1;   // A memory region was freed for the buffer.

      // A buffer was shared with another (canonical) buffer. This is similar to
      // ALLOC, except that instead of allocating a new region of memory, the
      // memory region of the canonical buffer is directly re-used. Multiple
      // buffers may share with the same canonical buffer. The lifetime of the
      // canonical buffer is extended to the union of all lifetimes.
      SHARE_WITH = 2;
    }
    Kind kind = 1;

    // The id of the LogicalBuffer that the event applies to.
    int64 buffer_id = 2;

    // The HloInstruction that the simulation was processing that caused this
    // event to occur, identified by its computation and instruction name. E.g.
    // buffers defined by instruction A are allocated when processing A.
    string computation_name = 3;
    string instruction_name = 4;

    // The id of the canonical LogicalBuffer that the buffer shares with. Only
    // set for SHARE_WITH events.
    int64 share_with_canonical_id = 5;
  }
  // The events of this trace.
  repeated Event events = 1;
  // NOTE(review): presumably true when the simulation covered the whole module
  // rather than a single computation — confirm.
  bool whole_module_simulation = 2;
}
483
// An abstraction representing a set of HLO modules built to run concurrently
// across different devices.
message HloModuleGroupProto {
  // Name of the module group.
  string name = 1;
  // The modules in the group.
  repeated HloModuleProto hlo_modules = 2;
}
490
// Serialization of BufferAssignment.
message BufferAssignmentProto {
  // Alias represents a source LogicalBuffer, and the buffer location that
  // aliases it.
  message BufferAlias {
    // Id of the source LogicalBuffer.
    int64 source_buffer_id = 1;
    // The location that aliases the source buffer.
    LogicalBufferProto.Location location = 2;
  }

  // Serialized logical buffers.
  repeated LogicalBufferProto logical_buffers = 1;
  // Aliases between source logical buffers and buffer locations.
  repeated BufferAlias buffer_aliases = 2;
  // Serialized buffer allocations.
  repeated BufferAllocationProto buffer_allocations = 3;
  // Traces of heap simulator runs.
  repeated HeapSimulatorTrace heap_simulator_traces = 4;
}
505
// Grouping message that bundles an HLO module with its buffer assignment.
message HloProto {
  // Field number and name that are no longer in use; reserved so they can
  // never be reused.
  reserved 2;
  reserved "hlo_ordering";

  // The HLO module.
  HloModuleProto hlo_module = 1;
  // The buffer assignment, presumably for hlo_module.
  BufferAssignmentProto buffer_assignment = 3;
}
514
// Bundles an HloProto with the arguments it was run on, the result it
// produced, and the platform that ran it. Used for analysis, replay, and
// file storage, among other purposes.
message HloSnapshot {
  // The HLO graph.
  HloProto hlo = 1;

  // Arguments passed to the graph.
  repeated xla.LiteralProto arguments = 2;

  // Result of running the graph.
  xla.LiteralProto result = 3;

  // Name of the platform the graph was run on.
  string execution_platform = 4;
}
531
// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering
// with the filename module_####.metadata.textproto, where #### is
// canonical_module_id.
message HloModuleMetadataProto {
  // Uniquely identifies an HloModuleMetadata. Equal to the first unique_id
  // of the module (a module may go through multiple unique_ids). If a module
  // is partitioned into multiple modules, those modules will each have a new
  // HloModuleMetadata with a different canonical_module_id.
  int64 canonical_module_id = 1;

  // Name of the module group that the module is part of.
  string module_group_name = 2;

  // The canonical module id of the module that this one is partitioned from,
  // if applicable.
  int64 original_module_id = 3;

  // The canonical module ids of the modules that this one is partitioned into,
  // if applicable.
  repeated int64 partitioned_module_ids = 4;

  // Metadata for the HLO passes that are run on the module.
  repeated HloPassMetadata pass_metadata = 5;
}
556
// Metadata for one run of an HLO pass on a module. Provides more information
// when processing debug dumps of HloProtos about the order of HLO passes and
// various other stats like duration. `pass_id` may also be used to identify a
// particular run of a pass in debug info that propagates through stages of
// compilation.
message HloPassMetadata {
  // For a given module, pass_id uniquely identifies a run of an HLO pass on
  // that module. Note that a pass_id may not always refer to the same pass
  // because the order of passes during compilation may change. For finding
  // metadata for a particular pass, pass_name and pipeline_name would be more
  // reliable, although note that they may not be unique.
  int64 pass_id = 1;
  // Name of the pass, and of the pipeline it ran in.
  string pass_name = 2;
  string pipeline_name = 3;

  // Filenames of the dumps of the module after this pass ran. Module may be
  // dumped in multiple formats, and the order of formats in this field will
  // stay consistent across passes.
  repeated string dump_filenames = 4;

  // Return value of pass.Run(). True if this pass changed the module, or, in
  // the case where the module was run through this pass as part of a module
  // group, true if this pass changed any module in the same module group.
  bool module_changed = 5;

  // The unique_id of the module that this pass is run on. May be different from
  // the canonical_module_id of the HloModuleMetadata that this HloPassMetadata
  // is inside.
  int64 module_id = 6;

  // If the module went through this pass as part of a module group, this is
  // set as the ids of all the modules in the module group. Empty otherwise.
  repeated int64 module_group_module_ids = 7;

  // Timestamp before and after the pass is run, in microseconds (per the _usec
  // suffix). Note they may be equal.
  int64 start_timestamp_usec = 8;
  int64 end_timestamp_usec = 9;
}
595