1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Rewrites TPUReplicate nodes into replicated computations on TPU.
17 //
18 // To represent a distributed TPU computation, we use the
19 // TPUReplicate operator, that describes a subgraph (represented as a
20 // Tensorflow function) to replicate across a TPU pod.
21 //
22 // Model parallelism and data parallelism:
23 // ---------------------------------------
24 // We support two different kinds of parallelism on TPU:
25 // * data parallelism (replication), or parallelization across batches, and
26 // * model parallelism, or parallelization within a batch.
27 //
28 // The function passed to a TPUReplicate operator is replicated many
29 // times across a TPU pod (data parallelism). The `num_replicas` attribute
30 // controls how many replicas of the computation to create. Replicas are mostly
31 // independent; replicas can only communicate using the CrossReplicaSum
32 // operator, which is typically used to communicate gradients during training.
33 //
34 // Each replica may optionally use more than one TPU core (model
35 // parallelism). The `num_cores_per_replica` attribute controls how many cores
36 // there are per replica. For each core, there is a virtual TPU_REPLICATED_CORE
37 // device that is only valid within replicated TPU computations (e.g.,
38 // TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each TPU_REPLICATED_CORE
39 // device corresponds to one TPU core in every replica.
// Each replica runs its own copy of the computation assigned to each
41 // TPU_REPLICATED_CORE device.
42 //
43 // The Python code is responsible for providing a device_assignment that
44 // describes how the replicated logical cores map to physical cores on the TPU
45 // topology.
46 //
47 // Inputs to TPUReplicate:
48 // ------------------------------
// The TPUReplicate operator takes four kinds of inputs, in the
// following order:
51 // * per-replica inputs. If there are three per-replica inputs (A, B, C) and two
52 //   replicas, the first six arguments to TPUReplicate will be:
53 //   A0 B0 C0 A1 B1 C1
54 //   where Ai is the A input to the i-th replica.
55 // * distributed inputs. These inputs follow the per-replica inputs.
56 //   If there are two distributed inputs (E, F) and two replicas, the following
57 //   arguments to TPUReplicate will be: E F.
//   Each distributed input is passed to TPUReplicate only once, but every
//   replica has its own local E and F.
59 // * broadcast inputs. These inputs follow the distributed inputs. All
60 //   replicas receive a copy of each of these inputs.
61 // * variables. Resource variables accessed by the computation follow the
62 //   broadcast inputs.
63 //
64 // For example, for a computation with two replicas, three per-replica inputs
// (A, B, C), two distributed inputs (E, F), two broadcast inputs (X, Y), and two
66 // variables (V, W), the arguments to TPUReplicate will be:
67 // A0 B0 C0 A1 B1 C1 E F X Y V W
68 // and each replica will receive the following arguments:
69 // A B C E F X Y V W
70 //
71 // Distributed TPU compilation requires that the shapes of all operators
72 // be known statically at compilation time, before any nodes have executed.
73 // Shapes are determined using shape information emitted by InferShapes. It
74 // is not possible to replicate Tensorflow operators with unknown or dynamic
75 // shapes for TPU at present.
76 //
77 // Graph rewrite:
78 // --------------
79 // Compilation replaces TPUReplicate operators with:
80 // * a single TPUCompile node that compiles the computations,
81 // * one TPUExecute node for each TPU device in the system that
82 //   executes the relevant computation,
83 // * one ReadVariableOp for each variable accessed by the replicated
84 //   computation,
85 // * one AssignVariableOp for each variable accessed by the replicated
86 //   computation. An assignment is built even if a variable is only read by the
87 //   computation. We do not know which variables are written until we apply the
88 //   XlaCompiler to the computation, but that does not happen until after the
89 //   rewrite. Conservatively, we write back the values of all variables after
90 //   the computation completes.
91 //   TODO(phawkins): only write back variables that the computation may write.
92 // * one Shape node for each Tensor or Variable input to the computation whose
93 //   shape is not statically known at rewrite time. The input shapes are fed
94 //   to the TPUCompile node.
95 //
96 // To ensure that the reads and writes seem to happen at the right time in the
97 // graph execution, we add control edges from all predecessors of the original
98 // TPUReplicate operator to each of the ReadVariableOp operators.
99 // Similarly, we add control edges from all of the AssignVariableOp operators to
100 // all of the successors of the TPUReplicate operator.
101 //
102 // The TPUReplicate rewrite must run before placement, since resource
103 // variable inputs will have DT_RESOURCE, which cannot be sent across devices,
104 // leading to objections from the placer. The rewrite rewrites the resource
105 // accesses into explicit ReadVariableOp and AssignVariableOp operators that the
106 // placer is free to colocate with the variables.
107 
108 #ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
109 #define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
110 
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"
123 
124 namespace tensorflow {
125 
126 // Replaces clusters assigned to TPU_SYSTEM devices with
127 // TPUCompile and TPUExecute nodes assigned to the corresponding
128 // TPU devices.
129 class DistributedTPURewritePass : public GraphOptimizationPass {
130  public:
131   static void SetDistributedTpuRewritePassOptions(
132       bool distribute_vars, bool allow_xla_spmd_partition,
133       bool replicate_inputs_outputs_by_default_for_xla_spmd,
134       bool enable_cross_replica_sharding_mirrored_variables,
135       bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast);
136 
137   Status Run(const GraphOptimizationPassOptions& options) override;
138 
139   // The following methods are public only for the use of unit tests.
140 
141   // See comment at the top of the file for how the inputs are ordered.
142   // Encapsulates the different TPU replicated node input and output
143   // information, and provide common APIs over them.
144   class ParameterInfo {
145    public:
ParameterInfo()146     ParameterInfo() {}
ParameterInfo(int64 num_replicas,int64 num_per_replica_args,int64 num_distributed_args,int64 num_broadcast_args,int64 num_variables,int64 num_guaranteed_constants,int64 num_retvals_per_replica)147     ParameterInfo(int64 num_replicas, int64 num_per_replica_args,
148                   int64 num_distributed_args, int64 num_broadcast_args,
149                   int64 num_variables, int64 num_guaranteed_constants,
150                   int64 num_retvals_per_replica)
151         : num_replicas_(num_replicas),
152           num_per_replica_args_(num_per_replica_args),
153           num_distributed_args_(num_distributed_args),
154           num_broadcast_args_(num_broadcast_args),
155           num_variables_(num_variables),
156           num_guaranteed_constants_(num_guaranteed_constants),
157           num_retvals_per_replica_(num_retvals_per_replica) {}
158 
NumReplicas()159     int64 NumReplicas() const { return num_replicas_; }
160 
NumPerReplicaArgs()161     int64 NumPerReplicaArgs() const { return num_per_replica_args_; }
162 
NumDistributedArgs()163     int64 NumDistributedArgs() const { return num_distributed_args_; }
164 
NumBroadcastArgs()165     int64 NumBroadcastArgs() const { return num_broadcast_args_; }
166 
NumVariables()167     int64 NumVariables() const { return num_variables_; }
168 
NumGuaranteedConstants()169     int64 NumGuaranteedConstants() const { return num_guaranteed_constants_; }
170 
NumRetvalsPerReplica()171     int64 NumRetvalsPerReplica() const { return num_retvals_per_replica_; }
172 
IsPerReplicaArg(int64 index)173     bool IsPerReplicaArg(int64 index) const {
174       return index < num_per_replica_args_;
175     }
176 
IsDistributedArg(int64 index)177     bool IsDistributedArg(int64 index) const {
178       return index >= num_per_replica_args_ &&
179              index < (num_per_replica_args_ + num_distributed_args_);
180     }
181 
IsBroadcastArg(int64 index)182     bool IsBroadcastArg(int64 index) const {
183       return index >= num_per_replica_args_ &&
184              index < (num_per_replica_args_ + num_distributed_args_ +
185                       num_broadcast_args_);
186     }
187 
IsVariableArg(int64 index)188     bool IsVariableArg(int64 index) const {
189       return index >= (num_per_replica_args_ + num_broadcast_args_) &&
190              index < (num_per_replica_args_ + num_distributed_args_ +
191                       num_broadcast_args_ + num_variables_);
192     }
193 
IsConstantArg(int64 index)194     bool IsConstantArg(int64 index) const {
195       return index >= (num_per_replica_args_ + num_distributed_args_ +
196                        num_broadcast_args_ + num_variables_) &&
197              index < (num_per_replica_args_ + num_distributed_args_ +
198                       num_broadcast_args_ + num_variables_ +
199                       num_guaranteed_constants_);
200     }
201 
202     // Returns the number of inputs which has been received by the host.
NumInputsFromHost()203     int64 NumInputsFromHost() const {
204       return num_replicas_ * num_per_replica_args_ + num_distributed_args_ +
205              num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
206     }
207 
208     // Returns the number of inputs which will be sent to each replica.
NumInputsToEachReplica()209     int64 NumInputsToEachReplica() const {
210       return num_per_replica_args_ + num_distributed_args_ +
211              num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
212     }
213 
214     // Returns the total number of output values returned to the host (for all
215     // replicas).
NumOutputsToHost()216     int64 NumOutputsToHost() const {
217       return num_replicas_ * num_retvals_per_replica_;
218     }
219 
220     // Returns the position of the first per-replica argument, within the set
221     // of all hosts arguments.
222     // Broadcast arguments follow the distributed arguments.
FirstBroadcastArgFromHost()223     int64 FirstBroadcastArgFromHost() const {
224       return num_replicas_ * num_per_replica_args_ + num_distributed_args_;
225     }
226 
227     // Indices of mirrored variables across replicas, which should be
228     // categorized as per_replica_args.
mirrored_variable_indices()229     const std::set<int64>& mirrored_variable_indices() const {
230       return mirrored_variable_indices_;
231     }
mutable_mirrored_variable_indices()232     std::set<int64>* mutable_mirrored_variable_indices() {
233       return &mirrored_variable_indices_;
234     }
235 
236    private:
237     int64 num_replicas_ = 1;
238     int64 num_per_replica_args_ = 0;
239     int64 num_distributed_args_ = 0;
240     int64 num_broadcast_args_ = 0;
241     int64 num_variables_ = 0;
242     int64 num_guaranteed_constants_ = 0;
243     int64 num_retvals_per_replica_ = 0;
244     std::set<int64> mirrored_variable_indices_;
245   };
246 
247   // Mapping from TPUReplicate cluster name to tpu device names. Value is a
248   // mapping from [replica][core] to a TF device name.
249   typedef absl::flat_hash_map<string, std::vector<std::vector<string>>>
250       TPUReplicateDeviceNamesMapping;
251 
252   // Determines which devices to use to run the computation.
253   // Inputs:
254   // * num_tpus_per_task: the number of TPU devices attached to each task
255   // * tpu_devices: a [task][device] collection of TPU devices
256   // * num_replicas: the number of replicas requested
257   // * num_cores_per_replica: the number of cores in each computation instance
258   // * topology_attr: the topology TPUReplicate attribute
259   // * device_assignment_attr: the device_assignment TPUReplicate attribute
260   // Outputs:
261   // * tf_device_assignment: a mapping from [replica][core] to a TF device name
262   // * xla_device_assignment: a mapping from [replica][core] to a linearized TPU
263   //   coordinate.
264   // TODO(phawkins): change tf_device_assignment to an xla::Array2D.
265   static Status BuildDeviceAssignment(
266       const tpu::TpuTopologyExternal& topology, int num_tpus_per_task,
267       const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
268       int num_cores_per_replica, const string& topology_attr,
269       absl::Span<const int> device_assignment_attr,
270       std::vector<std::vector<string>>* tf_device_assignment,
271       std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment);
272 
273   // Returns the `computation` graph attached to TPUReplicate operator
274   // `node`. `flr` is a FunctionLibraryRuntime to use when
275   // instantiating the function body. Sets `*arg_types` and
276   // `*retval_types` to the argument/return types of the function.
277   static Status GetComputationForTPUReplicateOp(const NameAttrList& function,
278                                                 FunctionLibraryRuntime* flr,
279                                                 Graph* computation,
280                                                 DataTypeVector* arg_types,
281                                                 DataTypeVector* retval_types);
282 
283   // Returns the shapes of the argument tensors and return values of the
284   // TPUReplicate operator `node` using the _output_shapes,
285   // _output_handle_shapes, and _output_handle_types annotations on the input
286   // nodes. Expects inputs in the following order (see comment at top of file):
287   // * num_replicas * num_per_replica_args per-replica inputs,
288   // * num_broadcast_args broadcast inputs,
289   // * num_variables variable inputs.
290   // Returns an error if the input shapes to `node` are not statically known.
291   // Also verifies that all replicas have identical input shapes for their
292   // per-replica inputs.
293   static Status GetArgAndRetvalShapes(
294       const GraphShapeInfo& shape_info, const Node& node,
295       const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
296       std::vector<InferredShape>* retval_shapes);
297 
298   // Assigns arguments and return values to cores. The assignment is represented
299   // as an XLA op sharding, so that an argument can be replicated across cores.
300   // `arg_sharding` and `retval_sharding` are vectors of shardings indexed by
301   // argument/retval number.
302   // `arg_fast_mem` is vector of fast_mem indication which is indexed by
303   // argument number.
304   static Status AssignArgsAndRetvalsToCores(
305       int num_cores_per_replica, const ParameterInfo& params_info,
306       const DataTypeVector& arg_types,
307       const std::vector<InferredShape>& arg_shapes,
308       const DataTypeVector& retval_types,
309       const std::vector<InferredShape>& retval_shapes, const Graph& graph,
310       const Node* replicate_node, FunctionLibraryRuntime* flr,
311       bool allow_parameter_replication_for_spmd,
312       std::vector<::xla::OpSharding>* arg_sharding,
313       std::vector<bool>* arg_fast_mem,
314       std::vector<::xla::OpSharding>* retval_sharding,
315       std::vector<std::string>* arg_names);
316 
317   // Populates `*variables` with the "variables" inputs to `index`-th output of
318   // `node`.
319   struct VariableInput {
320     Node* node;
321     int index;
322 
323     // Type of the variable's value. Note that this is different to the type of
324     // the output of 'variable', which is always DT_RESOURCE.
325     DataType dtype;
326   };
327   static Status FindVariableInputs(const Node& node,
328                                    const NameRangeMap& input_range_map,
329                                    std::vector<VariableInput>* variables);
330 
331   // Populates '*guaranteed_constants' with the "guaranteed_constants" inputs
332   // to 'node'.
333   static Status FindGuaranteedConstantInputs(
334       const Node& node, const NameRangeMap& input_range_map,
335       std::vector<Node*>* guaranteed_constants);
336 
337   // Builds Shape nodes that compute the shapes of arguments whose shapes are
338   // not statically known.
339   static Status BuildDynamicShapeNodes(
340       const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
341       const ParameterInfo& params_info,
342       const std::vector<Node*>& variable_reads, Graph* graph,
343       std::vector<Node*>* dynamic_shape_nodes);
344 
345   // Builds a TPUCompile node that compiles the computation in
346   // `function_names`. calls `nodes`.
347   // TODO(b/33943292): at present, for model parallelism with Send/Recv to work
348   // the `nodes` must correspond to the computations assigned to TPU:0,
349   // TPU:1, ... in order since XLA hard-codes the chip IDs in the generated
350   // executables.
351   static Status BuildCompileNode(
352       const Node* replicate_node, const NameAttrList& function,
353       uint64 library_fingerprint, const ParameterInfo& params_info,
354       const std::vector<InferredShape>& arg_shapes,
355       const DataTypeVector& arg_types,
356       const std::vector<Node*>& guaranteed_constant_nodes,
357       const string& session_handle,
358       const std::vector<::xla::OpSharding>& arg_sharding,
359       const std::vector<bool>& arg_fast_mem,
360       const std::vector<std::string>& arg_names,
361       const std::vector<::xla::OpSharding>& retval_sharding,
362       int num_cores_per_replica, const string& compile_device,
363       const xla::DeviceAssignment* xla_device_assignment,
364       const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
365       Node** compile_node, int64 autotuner_thresh);
366 
367   // Builds a TPUCompileSucceededAssert node that verifies that compilation
368   // succeeded and replaces the TPUCompilationStatus node in the graph.
369   static Status BuildCompilationStatusReturnNodes(
370       Node* replicate_node, Node* compile_node,
371       Node** control_after_compilation, Graph* graph);
372 
373   // Builds ReadVariableOp nodes that read `variables`, with a control
374   // edges that ensure they happen after `control_predecessor`.
375   static Status BuildVariableReads(absl::Span<const VariableInput> variables,
376                                    Node* control_predecessor, Graph* graph,
377                                    std::vector<Node*>* variable_reads);
378 
379   // Returns true if graph or functions contain resource write op, otherwise
380   // return false.
381   // TODO(b/137048563): Recognize unused resource rewrite op.
382   static bool ContainsResourceWriteOp(const Graph& graph,
383                                       const FunctionLibraryDefinition& fld);
384   // Struct that describes a variable value to be written back from TPUExecute.
385   struct VariableWrite {
386     // A node:output pair containing a boolean tensor that determines whether
387     // the value should be written back.
388     Node* predicate;
389     int predicate_output;
390 
391     // A node:output pair containing the value to be written back.
392     Node* value;
393     int value_output;
394   };
395 
396   // Builds AssignVariableOp nodes that write `variables` with the values from
397   // `variable_writes`, with control edges that ensure the writes happen before
398   // `control_successor`.
399   static Status BuildVariableWrites(
400       absl::Span<const VariableInput> variables, Node* control_successor,
401       absl::Span<const VariableWrite> variable_writes, Graph* graph);
402 
403   // Builds TPUExecute operators assigned to each TPU device
404   // involved in the computation.
405   // Arguments:
406   // * `params_info` is the structure containing the information about the
407   //    TPUReplicate node inputs and outputs.
408   // * `num_tasks` is the number of TensorFlow tasks in the slice.
409   // * `num_cores_per_replica` is the number of cores which are dedicated to
410   //    each replica.
411   // * `replicate_node` is the original TPUReplicate node.
412   // * `arg_names` are the names of the arguments to the computation function
413   //    passed as argument to TPUReplicate, including per-replica,
414   //    broadcast, and variable arguments.
415   // * `arg_types` are the corresponding types of the arguments.
416   // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
417   //    applicable).
418   // * `arg_shardings` and `retval_shardings` are mappings from
419   //    arguments/return indices to shardings, as returned by
420   //    `AssignArgsAndRetvalsToCores`.
421   // * `pod_devices` lists the devices to assign to each core of each replica.
422   // * `variable_reads` is a vectors of ReadVariableOp operators, one for each
423   //    variable argument to the computation.
424   // * The execute operators will have a control edge from
425   //   `control_predecessor` and another control edge to `control_successor`.
426   // Populates '*variable_writes' with information about variable values to
427   // write back.
428   static Status BuildExecuteNodes(
429       const ParameterInfo& params_info, int num_tasks,
430       int num_cores_per_replica, const Node& replicate_node,
431       const std::vector<std::string>& arg_names,
432       const DataTypeVector& arg_types,
433       const std::vector<InferredShape>& arg_shapes,
434       const DataTypeVector& retval_types,
435       const std::vector<::xla::OpSharding>& arg_shardings,
436       const std::vector<::xla::OpSharding>& retval_shardings,
437       const std::vector<std::vector<string>>& tpu_device_names,
438       Node* compile_node, const std::vector<Node*>& variable_reads,
439       Node* control_predecessor, Node* control_successor,
440       std::vector<VariableWrite>* variable_writes, Graph* graph);
441 
442   // Connects the compile node to all the host transfer nodes, and removes the
443   // key placeholder node that was previously standing in for it.
444   // Arguments:
445   // * `compile_node` is the TPUCompile node that has been added to the graph.
446   // * `key_placeholder_node` is the placeholder node to send the key to all the
447   // host
448   // * transfer nodes in the original graph.
449   // * `graph` is the graph being rewritten.
450   static Status ConnectHostComputeNodes(Node* compile_node,
451                                         Node* key_placeholder_node,
452                                         Graph* graph);
453 
454   // Map from a Node in an outside_compilation cluster in the original graph to
455   // the list of Nodes, one for each replica, that it is expanded into during
456   // replication.
457   typedef absl::node_hash_map<Node*, std::vector<Node*>> NodeToNodeReplicasMap;
458 
459   // Map from the name of an outside_compilation cluster to the model-parallel
460   // core index that the HostCompute Op should be placed on in that cluster.
461   typedef std::map<string, int> HostComputeCoreMap;
462 
463   // Map from the name of an outside_compilation cluster to the list of Nodes
464   // that should run on the host for that cluster.
465   typedef std::map<string, std::vector<Node*>> OutsideCompilationNodeMap;
466 
467   // Copies the outside_compilation nodes in a cluster to create replica
468   // replica_index.
469   static Status CopyOutsideCompilationNodes(
470       int replica_index, const std::vector<Node*>& outside_compilation_nodes,
471       const DeviceNameUtils::ParsedName& tpu_device,
472       const DeviceNameUtils::ParsedName& partial_device,
473       NodeToNodeReplicasMap* node_images, Graph* graph);
474 
475   // Replicates all the nodes in outside_compilation clusters in a compiled
476   // computation.
477   static Status ReplicateOutsideCompilationNodes(
478       const std::vector<std::vector<string>>& tf_device_assignment,
479       const HostComputeCoreMap& host_compute_core,
480       const OutsideCompilationNodeMap& outside_compilation_nodes,
481       NodeToNodeReplicasMap* node_images, Graph* graph);
482 
483   // Lifts the edges between original outside_compilation nodes in a cluster
484   // onto their replicas.
485   static Status CopyOutsideCompilationEdges(
486       const std::vector<Node*>& outside_compilation_nodes,
487       const NodeToNodeReplicasMap& node_images,
488       const std::unordered_map<string, Node*> outside_compilation_inputs,
489       Graph* graph);
490 
491   // Lifts all the edges in outside_compilation clusters in a compiled
492   // computation to their replicas.
493   static Status ReplicateOutsideCompilationEdges(
494       const OutsideCompilationNodeMap& outside_compilation_nodes,
495       const NodeToNodeReplicasMap& node_images,
496       const std::unordered_map<string, Node*> outside_compilation_inputs,
497       Graph* graph);
498 
499   // Removes all the original outside_compilation nodes from the graph,
500   // following replication.
501   static Status RemoveOutsideCompilationNodes(
502       const NodeToNodeReplicasMap& node_images, Graph* graph);
503 
504   // Lowers outside compilation functional nodes (If/While/function call).
505   // Otherwise, when we have multiple workers, device placer will not be able to
506   // place nodes if outside compilation has DT_RESOURCE inputs (e.g. a
507   // DT_RESOURCE input fed into multiple While nodes on different devices).
508   static Status LowerOutsideCompilationFunctionalNodes(
509       Graph* g, const FunctionLibraryDefinition& flib_def,
510       const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping);
511 
512   // Parses the 'host_compute_core' attribute on replicate_node to get the
513   // replicated core id of each outside_compilation cluster.
514   static Status ParseHostComputeCores(
515       const Node& replicate_node,
516       const OutsideCompilationNodeMap& outside_compilation_nodes,
517       HostComputeCoreMap* host_compute_core);
518 
519   // Gets the physical topology information about the TPU system.
520   static Status GetDeviceTopology(
521       const DeviceSet& device_set, const Node& replicate_node,
522       int* num_replicas, int* num_cores_per_replica, int* num_tasks,
523       std::vector<std::vector<string>>* tf_device_assignment,
524       std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
525       string* tpu_compilation_device);
526 
527   // Gets the types of args, retvals, and parameters.
528   static Status GetIOTypes(
529       int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
530       Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
531       std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
532       DataTypeVector* retval_types, ParameterInfo* params_info);
533 
534   // Find known constants and deals with variable reads.
535   static Status DealWithConstantsAndVariables(
536       const Node& replicate_node, const NameRangeMap& input_name_map,
537       Graph* graph, Node* host_transfer_sequencer, Node* control_before,
538       Node* control_after, absl::Span<const VariableInput> variable_nodes,
539       std::vector<Node*>* guaranteed_constant_nodes,
540       std::vector<Node*>* variable_reads);
541 
542   // Adds NoOp nodes for sequencing computation and variable reads/writes.
543   static Status BuildSequencingNodes(const string& tpu_compilation_device,
544                                      const Node& replicate_node, Graph* graph,
545                                      Node** host_transfer_sequencer,
546                                      Node** control_before,
547                                      Node** control_after);
548 
549   // Performs the pass's rewrite on a TPUReplicate node `node`.
550   static Status RewriteTPUReplicateNode(
551       const string& session_handle, const DeviceSet& device_set,
552       Node* replicate_node, FunctionLibraryDefinition* flib_def,
553       FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
554       const OutsideCompilationNodeMap& outside_compilation_nodes,
555       const std::vector<Node*>& head_tail_outside_compilation_nodes,
556       NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
557       const GraphShapeInfo& shape_info,
558       TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
559       int64 autotuner_thresh);
560 
561   // Performs host training loop optimization. For example, when TPUExecute
562   // node is inside a while loop, then model weight variables can be sharded
563   // in XLA preferred layout and then unsharded only at the very last iteration
564   // to reduce the number of all_gather.
565   static Status PerformHostTrainingLoopOptimization(
566       Graph* graph, FunctionLibraryDefinition* flib_def,
567       FunctionLibraryRuntime* flr);
568 
569   // Heuristically place some nodes with unassigned devices on TPUs for
570   // performance reasons.
571   static Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph);
572 
573   // Updates the head and tail outside compiled nodes so that nodes have the
574   // correct device and removes the replication and outside compilation
575   // attributes so that these nodes do not trigger further graph optimization
576   // passes.
577   static Status UpdateHeadTailOutsideCompilation(
578       const std::vector<std::vector<string>>& tf_device_assignment,
579       const std::vector<Node*>& head_tail_outside_compilation_nodes);
580 
581  private:
582   static bool distribute_vars_;
583   static bool allow_xla_spmd_partition_;
584   static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
585   static bool enable_cross_replica_sharding_mirrored_variables_;
586   static bool enable_automatic_model_parallelism_;
587   static bool enable_xla_param_broadcast_;
588 };
589 
590 }  // namespace tensorflow
591 
592 #endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
593