1syntax = "proto3";
2
3package tensorflow.tpu;
4
5import "tensorflow/compiler/xla/xla.proto";
6import "tensorflow/compiler/xla/xla_data.proto";
7import "tensorflow/core/framework/tensor_shape.proto";
8import "tensorflow/core/framework/types.proto";
9import "tensorflow/core/protobuf/tpu/dynamic_padding.proto";
10
11option cc_enable_arenas = true;
12
// This is an experimental proto used in the TF/XLA bridge to store metadata
// for a compile op (e.g. _TPUCompileMlir).
// TODO(lyandy): Deprecate proto once generic metadata proto is created.
message TPUCompileMetadataProto {
  // Description of the types and shapes of the arguments to a computation.
  message Arg {
    // How an argument is supplied to the computation.
    enum Kind {
      // Zero (default) value: the kind was not set. Not a valid kind.
      INVALID = 0;
      PARAMETER = 1;
      VARIABLE = 2;
      // These are args which have been guaranteed to be constants during the
      // session lifetime by the use of the GuaranteeConstOp (or ConstantOp).
      GUARANTEED_CONSTANT = 3;
    }
    // Element type of the argument.
    DataType dtype = 1;
    // Shape of the argument.
    TensorShapeProto shape = 2;
    // How the argument is supplied to the computation; see Kind.
    Kind kind = 3;

    // The cross-core sharding of this input within each replica, e.g.,
    // assigning to one core, or replicate across all cores.
    xla.OpSharding sharding = 4;

    // Whether this argument will receive the same data across all replicas.
    bool is_same_data_across_replicas = 5;

    // Whether XLA may shard/unshard an argument; see enable_xla_sharding.
    enum EnableXlaSharding {
      // XLA sharding is not allowed. This is the zero (default) value.
      DISALLOWED = 0;
      // Sharding is allowed if a host training loop exists.
      TENTATIVE = 1;
      // XLA sharding is allowed.
      ALLOWED = 2;
    }
    // Whether to allow XLA to produce separate programs to shard/unshard this
    // argument. Requires this arg to be an on-device Kind::VARIABLE, or a
    // Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of
    // a variable, and retval_index_for_sharding must be specified for the
    // corresponding updated value.
    EnableXlaSharding enable_xla_sharding = 6;

    // If XLA sharding is allowed on a Kind::PARAMETER, this field specifies
    // the index of the corresponding updated value in the return values. Use
    // -1 for variables that are not updated.
    // NOTE: this is an implicit-presence proto3 scalar, so an unset field
    // reads as 0 (i.e. return value 0), not -1. Writers must set -1
    // explicitly when there is no updated value.
    int32 retval_index_for_sharding = 8;

    // Whether this argument is placed on fast memory or not.
    bool fast_mem = 7;

    // Whether to let XLA decide the layout during compilation, as opposed to
    // using a fixed layout determined by the shape.
    bool unrestricted_layout = 9;

    // Name of the node that the arg comes from.
    string name = 10;
  }
  // Descriptions of the computation's arguments.
  repeated Arg args = 1;

  // Description of the return values from a computation.
  message Retval {
    // The cross-core sharding of this return value within each replica, e.g.,
    // assigning to one core, or replicate across all cores.
    xla.OpSharding sharding = 1;
  }
  // Descriptions of the computation's return values.
  repeated Retval retvals = 2;

  // Number of replicas of the computation and number of cores in each replica.
  // TODO(b/140721404): it may not be necessary to state the number of cores per
  // replica here. Reconsider when replicated model-parallelism is implemented
  // in XLA.
  int32 num_replicas = 3;
  int32 num_cores_per_replica = 4;

  // Removed fields. Both their numbers and their names are reserved so that
  // neither can be reused with a different meaning (names matter for the
  // JSON and text formats, numbers for the binary wire format).
  reserved 5;  // was device_names
  reserved 7;  // was replica_device_assignment
  reserved "device_names", "replica_device_assignment";

  // XLA device assignment for the computation.
  xla.DeviceAssignmentProto device_assignment = 8;

  // A fingerprint of the function library. Ensures that any functions called
  // by the computation have matching definitions.
  uint64 function_library_fingerprint = 6;

  // Unique session identifier. Can be empty.
  string session_handle = 9;

  // Fingerprint of guaranteed_const value. The fingerprint computation inside
  // tpu_compile_op may be slow. The computation can be avoided by setting the
  // fingerprint value here.
  string guaranteed_const_fingerprint = 10;

  // Dynamic padding maps; see tpu.PaddingMap in
  // tensorflow/core/protobuf/tpu/dynamic_padding.proto.
  repeated tpu.PaddingMap padding_maps = 11;

  // The location of step markers that XLA compile will instrument.
  xla.DebugOptions.StepMarkerLocation step_marker_location = 12;

  // Minimum number of batches run through the XLA graph before XLA fusion
  // autotuner is enabled. Default value of zero disables the autotuner.
  // The XLA fusion autotuner can improve performance by executing a heuristic
  // search on the compiler parameters.
  int64 xla_fusion_autotuner_thresh = 13;

  // Enables TPU compiler to add partitioning policies for inputs/outputs to
  // the XLA computation for model parallelism.
  bool enable_automatic_model_parallelism = 14;

  // Selects the partitioner used when compiler partitioning is requested:
  // true uses XLA's SPMD partitioner, false (the default) uses the MPMD
  // partitioner.
  bool use_spmd_for_xla_partitioning = 15;

  // Enables use of XLA collectives for broadcast of replicated parameters to
  // all replicas, instead of using TensorFlow Send/Recv.
  bool broadcast_replicated_parameters_via_collectives = 16;
}
123