1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Compilation for distributed TPU (TPU_REPLICATED_CORE devices).
17 
18 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
19 
20 #include <queue>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/strings/escaping.h"
26 #include "tensorflow/compiler/jit/encapsulate_util.h"
27 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
28 #include "tensorflow/compiler/tf2xla/sharding_util.h"
29 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
30 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
31 #include "tensorflow/compiler/xla/array3d.h"
32 #include "tensorflow/compiler/xla/array4d.h"
33 #include "tensorflow/compiler/xla/client/sharding_builder.h"
34 #include "tensorflow/compiler/xla/service/computation_placer.h"
35 #include "tensorflow/compiler/xla/xla.pb.h"
36 #include "tensorflow/core/common_runtime/function.h"
37 #include "tensorflow/core/common_runtime/graph_constructor.h"
38 #include "tensorflow/core/common_runtime/lower_function_call_op.h"
39 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
40 #include "tensorflow/core/common_runtime/lower_if_op.h"
41 #include "tensorflow/core/common_runtime/lower_while_op.h"
42 #include "tensorflow/core/common_runtime/optimization_registry.h"
43 #include "tensorflow/core/framework/function.h"
44 #include "tensorflow/core/framework/graph_to_functiondef.h"
45 #include "tensorflow/core/framework/node_def_builder.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/partial_tensor_shape.h"
48 #include "tensorflow/core/framework/tensor.pb.h"
49 #include "tensorflow/core/framework/types.pb.h"
50 #include "tensorflow/core/framework/versions.pb.h"
51 #include "tensorflow/core/graph/algorithm.h"
52 #include "tensorflow/core/graph/graph.h"
53 #include "tensorflow/core/lib/core/errors.h"
54 #include "tensorflow/core/lib/core/status.h"
55 #include "tensorflow/core/lib/gtl/cleanup.h"
56 #include "tensorflow/core/lib/strings/proto_serialization.h"
57 #include "tensorflow/core/lib/strings/str_util.h"
58 #include "tensorflow/core/platform/fingerprint.h"
59 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
60 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
61 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
62 #include "tensorflow/core/public/session_options.h"
63 #include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
64 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
65 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
66 #include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
67 #include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
68 #include "tensorflow/core/tpu/tpu_compile_interface.h"
69 #include "tensorflow/core/tpu/tpu_defs.h"
70 #include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
71 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
72 #include "tensorflow/core/util/device_name_utils.h"
73 #include "tensorflow/core/util/dump_graph.h"
74 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
75 
76 namespace tensorflow {
77 
78 namespace {
79 
// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
// topology.
constexpr int kTPUTopologyRank = 4;

// An upper bound on how many cores may be present in the topology.
static constexpr int kTPUMaxTopologySize = 4096;

// Attribute containing the serialized xla::OpSharding to be passed to the
// corresponding XLA HLO operation, which represents how a shape is distributed
// across logical cores, e.g., replication, single-device, or partitioning.
const char kShardingAttribute[] = "_XlaSharding";

// NOTE(review): presumably the op type names of the partitioned input/output
// nodes this pass looks for -- confirm against the uses later in the file.
const char kTPUPartitionedInput[] = "TPUPartitionedInput";
const char kTPUPartitionedOutput[] = "TPUPartitionedOutput";

// NOTE(review): presumably the op type name of variable handle nodes --
// confirm against the uses later in the file.
const char kVarHandleOp[] = "VarHandleOp";

// NOTE(review): these look like node attribute names consumed later in the
// file; their exact meaning is not visible here -- confirm at the use sites.
static const char* const kTPUCompilationResultAttr = "_tpu_compilation_status";
static const char* const kPostDeviceRewriteAttr = "_post_device_rewrite";
99 
// Small bookkeeping record embedded in each element kept in an IntrusiveHeap.
// It stores the element's slot index in the heap's backing array, or
// kNotMember when no slot has been recorded.
class IntrusiveHeapLink {
 public:
  using size_type = size_t;

  // Sentinel slot value (the largest size_type) meaning "no slot recorded".
  static constexpr size_type kNotMember = static_cast<size_type>(-1);

  // A default-constructed link holds kNotMember.
  IntrusiveHeapLink() = default;

  // Only IntrusiveHeap and LinkAccess objects should make these objects.
  explicit IntrusiveHeapLink(size_type pos) : pos_(pos) {}

  // Only IntrusiveHeap and LinkAccess should get the value.
  size_type get() const {
    return pos_;
  }

 private:
  size_type pos_ = kNotMember;
};
116 
117 template <typename T, IntrusiveHeapLink T::*M>
118 struct IntrusiveHeapDataMemberLinkAccess {
Gettensorflow::__anonf0ad67bd0111::IntrusiveHeapDataMemberLinkAccess119   IntrusiveHeapLink Get(const T* elem) const { return elem->*M; }
Settensorflow::__anonf0ad67bd0111::IntrusiveHeapDataMemberLinkAccess120   void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; }
121 };
122 
123 template <typename T>
124 struct DefaultIntrusiveHeapLinkAccess {
Gettensorflow::__anonf0ad67bd0111::DefaultIntrusiveHeapLinkAccess125   IntrusiveHeapLink Get(const T* elem) const { return elem->heap; }
Settensorflow::__anonf0ad67bd0111::DefaultIntrusiveHeapLinkAccess126   void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; }
127 };
128 
// Binary heap of T* pointers kept in a std::vector. Every element's current
// index in that vector is mirrored into an IntrusiveHeapLink reached through
// LinkAccess, which is what lets Adjust(), Remove() and Contains() locate an
// element without searching the vector.
//
// Ordering: PtrCompare is called as compare()(a, b) on element pointers; an
// element for which it returns true against its neighbours is kept closer to
// the top, and top() is the element at index 0.
//
// The heap stores pointers only; nothing in this class creates or destroys the
// pointed-to elements.
template <typename T, typename PtrCompare,
          typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>,
          typename Alloc = std::allocator<T*>>
class IntrusiveHeap {
 public:
  typedef typename IntrusiveHeapLink::size_type size_type;
  typedef T value_type;
  typedef T* pointer;
  typedef const T* const_pointer;
  typedef PtrCompare pointer_compare_type;
  typedef LinkAccess link_access_type;
  typedef Alloc allocator_type;

  // Creates an empty heap that uses copies of `comp` and `link_access`, and
  // `alloc` for the backing vector.
  explicit IntrusiveHeap(
      const pointer_compare_type& comp = pointer_compare_type(),
      const link_access_type& link_access = link_access_type(),
      const allocator_type& alloc = allocator_type())
      : rep_(comp, link_access, alloc) {}

  // Number of pointers currently stored.
  size_type size() const { return heap().size(); }

  bool empty() const { return heap().empty(); }

  // Return the top element, but don't remove it.
  // The heap must not be empty (only checked via DCHECK).
  pointer top() const {
    DCHECK(!empty());
    return heap()[0];
  }

  // Remove the top() pointer from the heap and return it.
  pointer Pop() {
    pointer t = top();
    Remove(t);
    return t;
  }

  // Insert 't' into the heap.
  void Push(pointer t) {
    // Record the slot 't' is about to occupy (the end of the vector) before
    // appending it, then sift it up to its final position.
    SetPositionOf(t, heap().size());
    heap().push_back(t);
    FixHeapUp(t);
  }

  // Adjust the heap to accommodate changes in '*t'.
  // 't' must currently be in the heap (only checked via DCHECK).
  void Adjust(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    // (h - 1) >> 1 is the parent index. Sift up only when 't' now compares
    // before its parent; otherwise try sifting down.
    if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) {
      FixHeapUp(t);
    } else {
      FixHeapDown(t);
    }
  }

  // Remove the specified pointer from the heap.
  // 't' must currently be in the heap (only checked via DCHECK). The link of
  // 't' is reset to kNotMember.
  void Remove(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    SetPositionOf(t, IntrusiveHeapLink::kNotMember);
    if (h == heap().size() - 1) {
      // Fast path for removing from back of heap.
      heap().pop_back();
      return;
    }
    // Move the element from the back of the heap to overwrite 't'.
    pointer& elem = heap()[h];
    elem = heap().back();
    SetPositionOf(elem, h);  // Element has moved, so update its link.
    heap().pop_back();
    Adjust(elem);  // Restore the heap invariant.
  }

  // Drops every pointer from the heap. The links stored in the elements are
  // not reset by this call.
  void Clear() { heap().clear(); }

  // Returns true if 't' is currently stored in this heap: its recorded
  // position must not be kNotMember, must be in range, and the slot at that
  // position must actually hold 't'.
  bool Contains(const_pointer t) const {
    size_type h = GetPositionOf(t);
    return (h != IntrusiveHeapLink::kNotMember) && (h < size()) &&
           heap()[h] == t;
  }

  // Forwards to the backing vector's reserve().
  void reserve(size_type n) { heap().reserve(n); }

  // Forwards to the backing vector's capacity().
  size_type capacity() const { return heap().capacity(); }

  allocator_type get_allocator() const { return rep_.heap_.get_allocator(); }

 private:
  typedef std::vector<pointer, allocator_type> heap_type;

  // Empty base class optimization for pointer_compare and link_access.
  // The heap_ data member retains a copy of the allocator, so it is not
  // stored explicitly.
  struct Rep : pointer_compare_type, link_access_type {
    explicit Rep(const pointer_compare_type& cmp,
                 const link_access_type& link_access,
                 const allocator_type& alloc)
        : pointer_compare_type(cmp),
          link_access_type(link_access),
          heap_(alloc) {}
    heap_type heap_;  // NOLINT
  };

  // Accessors for the comparator and link accessor, which are the base-class
  // subobjects of rep_.
  const pointer_compare_type& compare() const { return rep_; }

  const link_access_type& link_access() const { return rep_; }

  const heap_type& heap() const { return rep_.heap_; }
  heap_type& heap() { return rep_.heap_; }

  // Reads the heap index recorded in the link of 't'.
  size_type GetPositionOf(const_pointer t) const {
    return link_access().Get(t).get();
  }

  // Records 'pos' as the heap index in the link of 't'.
  void SetPositionOf(pointer t, size_type pos) const {
    return link_access().Set(t, IntrusiveHeapLink(pos));
  }

  // Sifts 't' towards the top, starting from its recorded position. Parents
  // are shifted down one level at a time and 't' is written once at the end.
  // The walk stops only when the parent compares before 't', so a parent that
  // ties with 't' is also shifted down.
  void FixHeapUp(pointer t) {
    size_type h = GetPositionOf(t);
    while (h != 0) {
      size_type parent = (h - 1) >> 1;
      if (compare()(heap()[parent], t)) {
        break;
      }
      heap()[h] = heap()[parent];
      SetPositionOf(heap()[h], h);
      h = parent;
    }
    heap()[h] = t;
    SetPositionOf(t, h);
  }

  // Sifts 't' towards the leaves, starting from its recorded position. At
  // each level the child that compares first is chosen; the walk stops when
  // 't' compares before that child or there are no children left.
  void FixHeapDown(pointer t) {
    size_type h = GetPositionOf(t);
    for (;;) {
      size_type kid = (h << 1) + 1;  // Index of the left child.
      if (kid >= heap().size()) {
        break;
      }
      // Prefer the right child when it compares before the left one.
      if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) {
        ++kid;
      }
      if (compare()(t, heap()[kid])) {
        break;
      }
      heap()[h] = heap()[kid];
      SetPositionOf(heap()[h], h);
      h = kid;
    }

    heap()[h] = t;
    SetPositionOf(t, h);
  }

  Rep rep_;
};
285 
CoreDeviceLabel(int core)286 string CoreDeviceLabel(int core) {
287   return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
288 }
289 
290 // Creates a unique node name with a particular prefix.
UniqueNodeName(const StringPiece prefix,Graph * graph)291 string UniqueNodeName(const StringPiece prefix, Graph* graph) {
292   return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId()));
293 }
294 
SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,const string & target_device_type,Node * node)295 Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,
296                                         const string& target_device_type,
297                                         Node* node) {
298   TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE);
299   TF_RET_CHECK(device.has_id);
300   TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName));
301 
302   // Store the device instance as an attr on the Node.
303   TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id));
304 
305   // Place the execute Op on the TPU_SYSTEM device so it can access the cache of
306   // compiled protos in the resource manager.
307   device.type = target_device_type;
308   device.id = 0;
309 
310   node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device));
311   return Status::OK();
312 }
313 
314 // Iterate over the nodes in the original graph and find all the TPUReplicate
315 // nodes, and all the nodes that are part of outside_compilation clusters.
FindTaggedNodes(Graph * graph,std::vector<Node * > * replicate_nodes,std::map<string,DistributedTPURewritePass::OutsideCompilationNodeMap> * outside_compilation_nodes,std::map<string,std::vector<Node * >> * head_tail_outside_compilation_nodes)316 Status FindTaggedNodes(
317     Graph* graph, std::vector<Node*>* replicate_nodes,
318     std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>*
319         outside_compilation_nodes,
320     std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) {
321   for (Node* node : graph->op_nodes()) {
322     if (node->type_string() == "_TPUReplicate") {
323       replicate_nodes->push_back(node);
324       const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
325       if (cluster_attr == nullptr) {
326         return errors::Internal("TPUReplicate node ", node->name(), " has no ",
327                                 kTPUReplicateAttr, " attr.");
328       } else {
329         const string& cluster = cluster_attr->s();
330         if (cluster.empty()) {
331           return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
332                                   node->name(), " has no string value.");
333         }
334         if (outside_compilation_nodes->find(cluster) !=
335             outside_compilation_nodes->end()) {
336           return errors::Internal(
337               "TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr,
338               " attr value '", cluster,
339               "' which is a duplicate of another TPUReplicate node in the "
340               "graph.");
341         }
342         (*outside_compilation_nodes)[cluster] =
343             DistributedTPURewritePass::OutsideCompilationNodeMap();
344         (*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>();
345       }
346     }
347   }
348   for (Node* node : graph->op_nodes()) {
349     if (node->type_string() != "_TPUReplicate") {
350       const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
351       const AttrValue* outside_compilation_attr =
352           node->attrs().Find(kOutsideCompilationAttr);
353       if (cluster_attr == nullptr) {
354         if (outside_compilation_attr != nullptr) {
355           return errors::Internal("Node ", node->name(), " has ",
356                                   kOutsideCompilationAttr, " attr but no ",
357                                   kTPUReplicateAttr, " attr.");
358         }
359       } else {
360         const string& cluster = cluster_attr->s();
361         if (cluster.empty()) {
362           return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
363                                   node->name(), " has no string value.");
364         }
365         const auto iter = outside_compilation_nodes->find(cluster);
366         if (iter == outside_compilation_nodes->end()) {
367           return errors::Internal(
368               "Attr ", kTPUReplicateAttr, " on node ", node->name(),
369               " does not correspond to a TPUReplicate node.");
370         }
371         if (outside_compilation_attr == nullptr) {
372           return errors::Internal("Node ", node->name(), " has ",
373                                   kTPUReplicateAttr, " attr but no ",
374                                   kOutsideCompilationAttr, " attr.");
375         }
376         const string& oc_cluster = outside_compilation_attr->s();
377         if (oc_cluster.empty()) {
378           return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ",
379                                   node->name(), " has no string value.");
380         }
381 
382         // Outside compilation cluster at head and tail of TPU computation has
383         // already been moved to host and is already replicated. As so, do not
384         // replicate outside compilation nodes with replica id attribute.
385         int replica_id;
386         if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) {
387           const AttrValue* head_attr =
388               node->attrs().Find("_xla_only_arg_or_oc_input");
389           const AttrValue* tail_attr =
390               node->attrs().Find("_xla_only_ret_or_oc_output");
391           if (((head_attr != nullptr) && (head_attr->b())) ||
392               ((tail_attr != nullptr) && (tail_attr->b()))) {
393             // This is safe as this has the same keys as
394             // outside_compilation_nodes which we already know has this key.
395             (*head_tail_outside_compilation_nodes)[cluster].push_back(node);
396           }
397           continue;
398         }
399         iter->second[oc_cluster].push_back(node);
400       }
401     }
402   }
403   return Status::OK();
404 }
405 
406 // Helper class to spread TPU computation arguments and return values
407 // across cores.
408 // If all shapes are fully defined, balance by their size.
409 // If some of them are not fully defined, the undefined shapes size will
410 // be estimated with the average size of the fully defined ones.
411 // If none are defined, fall back to round-robin.
412 class TensorDevicePlacer {
413  public:
414   // Creates a TensorDevicePlacer object to distribute arguments or
415   // return values to a set of num_devices devices, where the types and
416   // the inferred shapes of the inputs (arguments or return values) are
417   // passed in types and shapes.
TensorDevicePlacer(int64 num_devices,const DataTypeVector & types,const std::vector<InferredShape> & shapes)418   TensorDevicePlacer(int64 num_devices, const DataTypeVector& types,
419                      const std::vector<InferredShape>& shapes)
420       : index_nodes_(num_devices), sizes_(types.size()) {
421     int64 total_size = 0;
422     int64 num_defined = 0;
423     for (int64 i = 0; i < types.size(); ++i) {
424       sizes_[i] = GetInferredShapeSize(shapes[i], types[i]);
425       if (sizes_[i] >= 0) {
426         total_size += sizes_[i];
427         ++num_defined;
428       }
429     }
430     // If a shape is undefined, select a size for it which is the average
431     // of the defined shapes. If no shapes are defined, assign 1 so that we
432     // get round-robin behavior.
433     int64 undefined_shape_size =
434         (num_defined > 0) ? total_size / num_defined : 1;
435     for (int64 i = 0; i < sizes_.size(); ++i) {
436       if (sizes_[i] < 0) {
437         sizes_[i] = undefined_shape_size;
438       }
439     }
440 
441     for (int64 i = 0; i < num_devices; ++i) {
442       heap_.Push(&index_nodes_[i]);
443     }
444   }
445 
446   // Reports that the argument/return-value at index has been assigned
447   // by the user to a given device.
ReportDeviceAssigned(int64 device,int64 index)448   void ReportDeviceAssigned(int64 device, int64 index) {
449     DeviceNode* node = &index_nodes_.at(device);
450     node->size += sizes_.at(index);
451     heap_.Adjust(node);
452   }
453 
454   // Retrieves the device at which the argument/return-value at index
455   // should be assigned to.
RetrieveAssignment(int64 index)456   int64 RetrieveAssignment(int64 index) {
457     DeviceNode* node = heap_.top();
458     int64 device = node - index_nodes_.data();
459     node->size += sizes_.at(index);
460     heap_.Adjust(node);
461     return device;
462   }
463 
464  private:
465   struct DeviceNode {
466     struct Compare {
467       // Compare functor to implement a min heap using the ::gtl::IntrusiveHeap
468       // infrastructure.
operator ()tensorflow::__anonf0ad67bd0111::TensorDevicePlacer::DeviceNode::Compare469       bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const {
470         return lhs->size < rhs->size;
471       }
472     };
473 
474     IntrusiveHeapLink heap;
475     int64 size = 0;
476   };
477 
GetInferredShapeSize(const InferredShape & ishape,DataType dtype)478   static int64 GetInferredShapeSize(const InferredShape& ishape,
479                                     DataType dtype) {
480     return ishape.shape.IsFullyDefined()
481                ? ishape.shape.num_elements() * DataTypeSize(dtype)
482                : -1;
483   }
484 
485   std::vector<DeviceNode> index_nodes_;
486   IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_;
487   std::vector<int64> sizes_;
488 };
489 
ValidateCoreNumber(int64 core,int64 num_cores_per_replica)490 Status ValidateCoreNumber(int64 core, int64 num_cores_per_replica) {
491   if (core < 0 || core >= num_cores_per_replica) {
492     return tensorflow::errors::InvalidArgument("Invalid core ID: ", core,
493                                                ". The valid core IDs are [0..",
494                                                num_cores_per_replica, ")");
495   }
496   return Status::OK();
497 }
498 
FindHostComputeKeyPlaceholderNodes(const Graph * graph,const std::vector<Node * > & replicate_nodes,std::unordered_map<string,Node * > * host_compute_key_placeholder_map)499 Status FindHostComputeKeyPlaceholderNodes(
500     const Graph* graph, const std::vector<Node*>& replicate_nodes,
501     std::unordered_map<string, Node*>* host_compute_key_placeholder_map) {
502   host_compute_key_placeholder_map->clear();
503   for (const auto node : replicate_nodes) {
504     (*host_compute_key_placeholder_map)[node->name()] = nullptr;
505   }
506 
507   for (Node* node : graph->op_nodes()) {
508     if (node->type_string() == "Placeholder" &&
509         str_util::EndsWith(node->name(), "_key_placeholder")) {
510       const AttrValue* call_node_attr =
511           node->attrs().Find("_host_compute_call_node");
512       if (call_node_attr != nullptr) {
513         auto iter = host_compute_key_placeholder_map->find(call_node_attr->s());
514         if (iter == host_compute_key_placeholder_map->end()) {
515           return errors::InvalidArgument(
516               "Node ", node->name(), " has _host_compute_call_node attribute '",
517               call_node_attr->s(), "' that doesn't correspond to a call node");
518         }
519         if (iter->second != nullptr) {
520           return errors::InvalidArgument(
521               "Key placeholder node ", iter->second->name(), " for call node ",
522               call_node_attr->s(), " previously found as ",
523               iter->second->name());
524         }
525         iter->second = node;
526       }
527     }
528   }
529 
530   return Status::OK();
531 }
532 
ReplaceCompilationResultNodeWithIdentity(Graph * graph,Node ** node)533 Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) {
534   Node* old_node = *node;
535   // We want to replace the node with an identity node with the same name.
536   const string& node_name = old_node->name();
537 
538   // Create identity node.
539   TF_ASSIGN_OR_RETURN(
540       Node * id_node,
541       BuildIdentityNode(graph, node_name, DT_STRING,
542                         /*input=*/nullptr, /*requested_device=*/""));
543 
544   // No incoming edges are copied as a new one will be added from compile node
545   // to id_node.
546 
547   // Copy outgoing edges to the id node.
548   std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
549                                      old_node->out_edges().end());
550   for (const Edge* edge : out_edges) {
551     Node* dst = edge->dst();
552     int src_output = edge->src_output();
553     int dst_input = edge->dst_input();
554 
555     if (src_output == Graph::kControlSlot) {
556       graph->AddControlEdge(id_node, dst);
557     } else {
558       graph->AddEdge(id_node, src_output, dst, dst_input);
559     }
560     graph->RemoveEdge(edge);
561   }
562   graph->RemoveNode(old_node);
563 
564   *node = id_node;
565   return Status::OK();
566 }
567 
FillPaddingMap(const Node & replicate_node,protobuf::RepeatedPtrField<tpu::PaddingMap> * padding_maps)568 Status FillPaddingMap(
569     const Node& replicate_node,
570     protobuf::RepeatedPtrField<tpu::PaddingMap>* padding_maps) {
571   std::vector<string> padding_map_strs;
572   TF_RETURN_IF_ERROR(
573       GetNodeAttr(replicate_node.attrs(), "padding_map", &padding_map_strs));
574   padding_maps->Reserve(padding_map_strs.size());
575   for (const string& padding_map_str : padding_map_strs) {
576     tpu::PaddingMap* padding_map = padding_maps->Add();
577     if (!padding_map->ParseFromString(padding_map_str)) {
578       return errors::InvalidArgument(
579           "Malformed padding_map serialized string: ", padding_map_str);
580     }
581   }
582   return Status::OK();
583 }
584 
GetStepMarkerLocation(const Node & replicate_node,xla::DebugOptions::StepMarkerLocation * location)585 Status GetStepMarkerLocation(const Node& replicate_node,
586                              xla::DebugOptions::StepMarkerLocation* location) {
587   string step_marker_location_attr;
588   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location",
589                                  &step_marker_location_attr));
590   if (step_marker_location_attr.empty()) {
591     *location = xla::DebugOptions::STEP_MARK_AT_ENTRY;
592   } else {
593     if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr,
594                                                      location)) {
595       return errors::InvalidArgument("Malformed step_marker_location: ",
596                                      step_marker_location_attr);
597     }
598   }
599   return Status::OK();
600 }
601 
602 // Extracts a map of dimension and number of splits for tiled input from xla
603 // sharding attribute.
GetDimensionIndicesAndNumSplitsFromSharding(const xla::OpSharding & sharding,std::map<int,int> * split_dimension_map)604 Status GetDimensionIndicesAndNumSplitsFromSharding(
605     const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
606   int64 tensor_tile_rank = sharding.tile_assignment_dimensions_size();
607   if (sharding.replicate_on_last_tile_dim()) {
608     tensor_tile_rank--;
609   }
610   for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
611     if (sharding.tile_assignment_dimensions(dim_index) > 1) {
612       split_dimension_map->emplace(
613           dim_index, sharding.tile_assignment_dimensions(dim_index));
614     }
615   }
616 
617   if (split_dimension_map->empty()) {
618     return errors::InvalidArgument("Arg has unnecessary tiled sharding: ",
619                                    sharding.DebugString());
620   }
621   return Status::OK();
622 }
623 
624 // Updates contents of the function with `function_name` in function library
625 // definition `flib_def` to `new_graph`. This is required when graph
626 // transformation happens inside a function call body.
UpdateFunctionLibDefinition(const Graph & new_graph,const std::string & function_name,FunctionLibraryDefinition * flib_def)627 Status UpdateFunctionLibDefinition(const Graph& new_graph,
628                                    const std::string& function_name,
629                                    FunctionLibraryDefinition* flib_def) {
630   FunctionDef graph_fdef;
631   TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef));
632   TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef));
633   return Status::OK();
634 }
635 
// Pairs a graph node with an integer index.
// NOTE(review): presumably `index` is an output slot of `node` (see how
// ShardedInputInfo::sharded_inputs is described) -- confirm at the use sites.
struct NodeOut {
  Node* node;
  int index;
};
640 
// Key made of a replica id and an argument index. Keys are ordered
// lexicographically: first by replica_id, then by argument_index.
struct ShardedInputIndex {
  int replica_id;
  int argument_index;

  // Strict weak ordering over (replica_id, argument_index).
  bool operator<(const ShardedInputIndex& rhs) const {
    if (replica_id != rhs.replica_id) {
      return replica_id < rhs.replica_id;
    }
    return argument_index < rhs.argument_index;
  }
};
650 
// Describes how one tiled input was split into per-core shards.
struct ShardedInputInfo {
  // Split node that would be connected to tiled input Node.
  Node* split_node;
  // List of splits nodes and output index of the split node from which sharded
  // input will be connected to the TPUExecute node. The inputs are ordered by
  // logical core ids.
  std::vector<NodeOut> sharded_inputs;
};
659 
660 // Adds pad node after split node to graph for uneven sharding tiled inputs.
661 // |graph| owns the returned Node* instance.
CreatePadNode(const int padding,const int num_dims,const int split_dim,DataType dtype,Node * control_predecessor,Node * split_node,const int split_index,Graph * graph)662 xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims,
663                                    const int split_dim, DataType dtype,
664                                    Node* control_predecessor, Node* split_node,
665                                    const int split_index, Graph* graph) {
666   // Add paddings node.
667   Status s;
668   NodeDef paddings_def;
669   paddings_def.set_name(
670       graph->NewName(absl::StrCat(split_node->name(), "/paddings")));
671   paddings_def.set_op("Const");
672   AddNodeAttr("dtype", DT_INT32, &paddings_def);
673   paddings_def.set_device(split_node->assigned_device_name());
674   TensorProto sizes_tensor_proto;
675   sizes_tensor_proto.set_dtype(DT_INT32);
676   for (int i = 0; i < num_dims; ++i) {
677     sizes_tensor_proto.add_int_val(0);
678     if (i == split_dim) {
679       sizes_tensor_proto.add_int_val(padding);
680     } else {
681       sizes_tensor_proto.add_int_val(0);
682     }
683   }
684   TensorShape sizes_shape({num_dims, 2});
685   sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
686   AddNodeAttr("value", sizes_tensor_proto, &paddings_def);
687   Node* paddings_node = graph->AddNode(paddings_def, &s);
688   TF_RETURN_IF_ERROR(s);
689 
690   // Add Pad node.
691   NodeDef pad_def;
692   pad_def.set_name(graph->NewName(
693       absl::StrCat(split_node->name(), "/pad_shard_", split_index)));
694   pad_def.set_op("Pad");
695   pad_def.set_device(split_node->assigned_device_name());
696   AddNodeAttr("T", dtype, &pad_def);
697   AddNodeAttr("Tpaddings", DT_INT32, &pad_def);
698   pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index));
699   pad_def.add_input(absl::StrCat(paddings_node->name(), ":0"));
700   Node* pad_node = graph->AddNode(pad_def, &s);
701   pad_node->set_assigned_device_name(split_node->assigned_device_name());
702   TF_RETURN_IF_ERROR(s);
703   // Add edges for pad node.
704   graph->AddEdge(split_node, split_index, pad_node, 0);
705   graph->AddEdge(paddings_node, 0, pad_node, 1);
706   graph->AddControlEdge(control_predecessor, pad_node);
707   return pad_node;
708 }
709 
710 // Adds split node and split dimension node to graph for sharding tiled inputs.
711 // |graph| owns the returned Node* instance.
CreateSplitNode(const int num_splits,const int dim,const int num_dims,const int64 padding,const int orig_src_output,DataType dtype,absl::string_view name_prefix,Node * control_predecessor,Node * orig_src,Graph * graph)712 xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
713                                      const int num_dims, const int64 padding,
714                                      const int orig_src_output, DataType dtype,
715                                      absl::string_view name_prefix,
716                                      Node* control_predecessor, Node* orig_src,
717                                      Graph* graph) {
718   const std::string input_assigned_device = orig_src->assigned_device_name();
719   Node* to_split_node = orig_src;
720   int to_split_index = orig_src_output;
721   if (padding > 0) {
722     TF_ASSIGN_OR_RETURN(
723         Node * pad_node,
724         CreatePadNode(padding, num_dims, dim, dtype, control_predecessor,
725                       orig_src, orig_src_output, graph));
726     to_split_node = pad_node;
727     to_split_index = 0;
728   }
729 
730   // Add a split dimension node.
731   NodeDef split_dim_def;
732   split_dim_def.set_name(
733       graph->NewName(absl::StrCat(name_prefix, "/split_dim")));
734   split_dim_def.set_op("Const");
735   split_dim_def.set_device(input_assigned_device);
736   AddNodeAttr("dtype", DT_INT32, &split_dim_def);
737   TensorProto tensor_proto;
738   tensor_proto.set_dtype(DT_INT32);
739   tensor_proto.add_int_val(dim);
740   TensorShape shape({});
741   shape.AsProto(tensor_proto.mutable_tensor_shape());
742   AddNodeAttr("value", tensor_proto, &split_dim_def);
743   Status s;
744   Node* split_dim_node = graph->AddNode(split_dim_def, &s);
745   TF_RETURN_IF_ERROR(s);
746   // Add a split node.
747   NodeDef split_def;
748   split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split")));
749   split_def.set_op("Split");
750   split_def.set_device(input_assigned_device);
751   AddNodeAttr("num_split", num_splits, &split_def);
752   AddNodeAttr("T", dtype, &split_def);
753   split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
754   split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
755   Node* split_node = graph->AddNode(split_def, &s);
756   TF_RETURN_IF_ERROR(s);
757 
758   split_node->set_assigned_device_name(input_assigned_device);
759 
760   // If colocate the newly created split op to source node of input to TPU
761   // computation.
762   split_node->AddAttr(kColocationAttrName,
763                       std::vector<string>{absl::StrCat(kColocationGroupPrefix,
764                                                        orig_src->name())});
765 
766   graph->AddEdge(split_dim_node, 0, split_node, 0);
767   graph->AddEdge(to_split_node, to_split_index, split_node, 1);
768 
769   // Add a control dependency from `control_predecessor` to newly created
770   // constant node. This ensures that newly added split/split dim
771   // nodes are placed inside correct while loop frames when TPUExecute
772   // node is inside a host training loop.
773   graph->AddControlEdge(control_predecessor, split_dim_node);
774   return split_node;
775 }
776 
GetPadding(const int split_dim,const int num_splits,const PartialTensorShape & partial_tensor_shape)777 int64 GetPadding(const int split_dim, const int num_splits,
778                  const PartialTensorShape& partial_tensor_shape) {
779   // If dim dimension is not defined, no uneven sharding support.
780   if (partial_tensor_shape.dim_size(split_dim) <= 0) {
781     return 0;
782   }
783   int64 per_split_size = tensorflow::MathUtil::CeilOfRatio<int64>(
784       partial_tensor_shape.dim_size(split_dim), num_splits);
785   int64 total_padding =
786       per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim);
787   return total_padding;
788 }
789 
790 // Creates a set of splits nodes that shards tiled input node in graph.
CreateOrGetSplitNodesForInputSharding(const xla::OpSharding & sharding,int orig_arg_num,DataType dtype,const PartialTensorShape & partial_tensor_shape,int replica_id,int orig_src_output,Node * orig_src,Node * control_predecessor,Graph * graph,std::map<ShardedInputIndex,ShardedInputInfo> * arg_index_to_sharded_input_map)791 xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
792     const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
793     const PartialTensorShape& partial_tensor_shape, int replica_id,
794     int orig_src_output, Node* orig_src, Node* control_predecessor,
795     Graph* graph,
796     std::map<ShardedInputIndex, ShardedInputInfo>*
797         arg_index_to_sharded_input_map) {
798   ShardedInputIndex input_index{replica_id, orig_arg_num};
799   auto iter = arg_index_to_sharded_input_map->find(input_index);
800   if (iter != arg_index_to_sharded_input_map->end()) {
801     return iter->second;
802   }
803   // Maps input dimension and number of splits with which the
804   // dimension sharded.
805   std::map<int, int> split_dimension_map;
806   TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
807       sharding, &split_dimension_map));
808   TF_RET_CHECK(!split_dimension_map.empty())
809       << "Unnecessary sharding attribute found.";
810 
811   // For v1 while loop, nodes inside the loop body must either
812   //  1) Have data edges from while loop input node.
813   //  or
814   //  2) Have direct control dependency from while loop input control
815   //     node.
816   //
817   // As so, if we are adding Split node inside, while loop body,
818   // we must manually add a control dependency to a node inside
819   // a while loop (i.e. `control_predecessor`) to constant nodes
820   // without data in-edges to make sure that added split nodes
821   // have correct frame name. Else, placer will complain when
822   // `BuildControlFlow()` is invoked.
823 
824   auto sharding_it = split_dimension_map.begin();
825   std::queue<Node*> split_nodes_for_dimension;
826   absl::flat_hash_map<Node*, int> node_to_split_dim;
827   int split_dimension = sharding_it->first;
828   int num_split = sharding_it->second;
829 
830   // Creates a tree of split nodes for sharding tiled inputs. Splits nodes
831   // are created such that input data is sharded in row major order.
832   // Split nodes at ith depth from the original input node represent nodes
833   // that split the input data at ith dimension.
834   TF_ASSIGN_OR_RETURN(
835       Node * root_split_node,
836       CreateSplitNode(
837           num_split, split_dimension, partial_tensor_shape.dims(),
838           GetPadding(split_dimension, num_split, partial_tensor_shape),
839           orig_src_output, dtype,
840           absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
841                        split_dimension),
842           control_predecessor, orig_src, graph));
843   sharding_it++;
844 
845   split_nodes_for_dimension.emplace(root_split_node);
846   node_to_split_dim[root_split_node] = split_dimension;
847 
848   while (sharding_it != split_dimension_map.end()) {
849     split_dimension = sharding_it->first;
850     num_split = sharding_it->second;
851     int num_split_nodes_in_dimension = split_nodes_for_dimension.size();
852     for (int i = 0; i < num_split_nodes_in_dimension; ++i) {
853       Node* input_split_node = split_nodes_for_dimension.front();
854       split_nodes_for_dimension.pop();
855       for (int src_output_index = 0;
856            src_output_index < input_split_node->num_outputs();
857            ++src_output_index) {
858         TF_ASSIGN_OR_RETURN(
859             Node * split_node,
860             CreateSplitNode(
861                 num_split, split_dimension, partial_tensor_shape.dims(),
862                 GetPadding(split_dimension, num_split, partial_tensor_shape),
863                 src_output_index, dtype,
864                 absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
865                              split_dimension),
866                 control_predecessor, input_split_node, graph));
867         split_nodes_for_dimension.emplace(split_node);
868         node_to_split_dim[split_node] = split_dimension;
869       }
870     }
871     sharding_it++;
872   }
873 
874   // `split_nodes_for_dimension` now includes final split nodes
875   // from which sharded data will be fed into TPUExcute nodes -- sorted by
876   // row major order.
877   std::vector<NodeOut> sharded_inputs_list(
878       sharding.tile_assignment_devices_size());
879   int64 next_core_tile_index = 0;
880   while (!split_nodes_for_dimension.empty()) {
881     Node* split_node = split_nodes_for_dimension.front();
882     split_nodes_for_dimension.pop();
883     int num_splits;
884     TF_RETURN_IF_ERROR(
885         GetNodeAttr(split_node->def(), "num_split", &num_splits));
886     for (int out_index = 0; out_index < num_splits; ++out_index) {
887       int64 repeat_count = sharding.replicate_on_last_tile_dim()
888                                ? *sharding.tile_assignment_dimensions().rbegin()
889                                : 1;
890       for (int64 i = 0; i < repeat_count; ++i) {
891         int64 next_core =
892             sharding.tile_assignment_devices(next_core_tile_index++);
893         sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
894       }
895     }
896   }
897 
898   ShardedInputInfo sharded_input_info{root_split_node,
899                                       std::move(sharded_inputs_list)};
900   (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
901   return sharded_input_info;
902 }
903 
904 // Creates a concat node to be used for aggregating sharded retvals across
905 // logical cores.
CreateConcatNode(int dim,int num_splits,DataType dtype,absl::string_view name_prefix,const std::vector<NodeOut> & inputs,Graph * graph,absl::string_view device)906 xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
907                                       absl::string_view name_prefix,
908                                       const std::vector<NodeOut>& inputs,
909                                       Graph* graph, absl::string_view device) {
910   // Add a Concat dim node.
911   NodeDef concat_dim_def;
912   concat_dim_def.set_name(
913       graph->NewName(absl::StrCat(name_prefix, "/concat_dim")));
914   concat_dim_def.set_op("Const");
915   AddNodeAttr("dtype", DT_INT32, &concat_dim_def);
916   concat_dim_def.set_device(std::string(device));
917   TensorProto tensor_proto;
918   tensor_proto.set_dtype(DT_INT32);
919   tensor_proto.add_int_val(dim);
920   TensorShape shape({});
921   shape.AsProto(tensor_proto.mutable_tensor_shape());
922   AddNodeAttr("value", tensor_proto, &concat_dim_def);
923   Status s;
924   Node* concat_dim_node = graph->AddNode(concat_dim_def, &s);
925   TF_RETURN_IF_ERROR(s);
926 
927   // Add a Concat node.
928   NodeDef concat_def;
929   concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat")));
930   concat_def.set_op("Concat");
931   AddNodeAttr("N", num_splits, &concat_def);
932   AddNodeAttr("T", dtype, &concat_def);
933   concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0"));
934   concat_def.set_device(std::string(device));
935   for (const auto& i : inputs) {
936     concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index));
937   }
938   Node* concat_node = graph->AddNode(concat_def, &s);
939   TF_RETURN_IF_ERROR(s);
940 
941   graph->AddEdge(concat_dim_node, 0, concat_node, 0);
942 
943   // 0th input to concat node is a concat dim node. So we start from 1st input
944   // and add all input edges.
945   int dst_input = 1;
946   for (const auto& i : inputs) {
947     graph->AddEdge(i.node, i.index, concat_node, dst_input);
948     ++dst_input;
949   }
950   return concat_node;
951 }
952 
953 // Adds slice node after concat node to graph for uneven sharding tiled inputs.
CreateSliceNode(DataType dtype,const PartialTensorShape & shape,Node * concat_node,const int concat_out_index,Graph * graph,absl::string_view device)954 xla::StatusOr<Node*> CreateSliceNode(DataType dtype,
955                                      const PartialTensorShape& shape,
956                                      Node* concat_node,
957                                      const int concat_out_index, Graph* graph,
958                                      absl::string_view device) {
959   Status s;
960   // Add begin node for concat.
961   NodeDef begin_def;
962   begin_def.set_name(
963       graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin")));
964   begin_def.set_op("Const");
965   AddNodeAttr("dtype", DT_INT32, &begin_def);
966   begin_def.set_device(std::string(device));
967   TensorProto begin_tensor_proto;
968   begin_tensor_proto.set_dtype(DT_INT32);
969   for (int i = 0; i < shape.dims(); ++i) {
970     begin_tensor_proto.add_int_val(0);
971   }
972   TensorShape begin_shape({shape.dims()});
973   begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape());
974   AddNodeAttr("value", begin_tensor_proto, &begin_def);
975   Node* begin_node = graph->AddNode(begin_def, &s);
976   TF_RETURN_IF_ERROR(s);
977 
978   // Add size node.
979   NodeDef size_def;
980   size_def.set_name(
981       graph->NewName(absl::StrCat(concat_node->name(), "/slice_size")));
982   size_def.set_op("Const");
983   AddNodeAttr("dtype", DT_INT32, &size_def);
984   size_def.set_device(std::string(device));
985   TensorProto sizes_tensor_proto;
986   sizes_tensor_proto.set_dtype(DT_INT32);
987   for (int i = 0; i < shape.dims(); ++i) {
988     sizes_tensor_proto.add_int_val(shape.dim_size(i));
989   }
990   TensorShape sizes_shape({shape.dims()});
991   sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
992   AddNodeAttr("value", sizes_tensor_proto, &size_def);
993   Node* size_node = graph->AddNode(size_def, &s);
994   TF_RETURN_IF_ERROR(s);
995 
996   // Add Slice node.
997   NodeDef slice_def;
998   slice_def.set_name(
999       graph->NewName(absl::StrCat(concat_node->name(), "/slice")));
1000   slice_def.set_op("Slice");
1001   slice_def.set_device(std::string(device));
1002   AddNodeAttr("T", dtype, &slice_def);
1003   AddNodeAttr("Index", DT_INT32, &slice_def);
1004   slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index));
1005   slice_def.add_input(absl::StrCat(begin_node->name(), ":0"));
1006   slice_def.add_input(absl::StrCat(size_node->name(), ":0"));
1007   Node* slice_node = graph->AddNode(slice_def, &s);
1008   TF_RETURN_IF_ERROR(s);
1009   // Add edges for slice node.
1010   graph->AddEdge(concat_node, concat_out_index, slice_node, 0);
1011   graph->AddEdge(begin_node, 0, slice_node, 1);
1012   graph->AddEdge(size_node, 0, slice_node, 2);
1013   return slice_node;
1014 }
1015 
1016 // Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute
1017 // nodes into a single output. Sharded outputs are concatenated along row major
1018 // order. That is, tiled output along 0th dimension will be concatenated last.
// Args:
//   sharding: tiled sharding of the return value; determines the sharded
//     dimensions and the number of splits along each.
//   dtype: element type of the return value.
//   inferred_shape: shape of the unsharded return value; used to detect
//     padding and, if there is any, as the size of the final Slice.
//   replica_id: replica index, used only in the names of the created nodes.
//   orig_inputs: per-shard outputs to aggregate.
//   graph, device: forwarded to CreateConcatNode()/CreateSliceNode().
// Returns the node producing the aggregated value: a Slice node if any
// sharded dimension required padding, otherwise the last Concat node.
xla::StatusOr<Node*> CreateConcatNodesForRetval(
    const xla::OpSharding& sharding, DataType dtype,
    const PartialTensorShape& inferred_shape, int replica_id,
    const std::vector<NodeOut>& orig_inputs, Graph* graph,
    absl::string_view device) {
  std::map<int, int> split_dimension_map;
  TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
      sharding, &split_dimension_map));

  // Outputs still waiting to be merged; shrinks by a factor of `num_splits`
  // after each sharded dimension is processed.
  std::vector<NodeOut> pending = orig_inputs;
  bool needs_slice = false;

  // Walk the sharded dimensions from the highest index to the lowest.
  for (auto dim_it = split_dimension_map.rbegin();
       dim_it != split_dimension_map.rend(); ++dim_it) {
    const auto dim = dim_it->first;
    const auto num_splits = dim_it->second;
    const int group_count = pending.size() / num_splits;

    std::vector<NodeOut> merged;
    auto group_begin = pending.begin();
    for (int group = 0; group < group_count; ++group) {
      // Each Concat consumes `num_splits` consecutive pending outputs.
      const auto group_end = group_begin + num_splits;
      std::vector<NodeOut> group_inputs(group_begin, group_end);
      group_begin = group_end;

      TF_ASSIGN_OR_RETURN(
          Node * concat_node,
          CreateConcatNode(
              dim, num_splits, dtype,
              absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
              group_inputs, graph, device));
      if (GetPadding(dim, num_splits, inferred_shape) > 0) {
        needs_slice = true;
      }
      merged.emplace_back(NodeOut{concat_node, 0});
    }
    pending = merged;
  }

  TF_RET_CHECK(pending.size() == 1);
  Node* aggregated = pending.at(0).node;
  if (!needs_slice) {
    return aggregated;
  }
  // Padding was present along at least one dimension: slice the concatenated
  // value down to `inferred_shape`.
  TF_ASSIGN_OR_RETURN(Node * slice_node,
                      CreateSliceNode(dtype, inferred_shape, aggregated,
                                      /*concat_out_index*/ 0, graph, device));
  return slice_node;
}
1069 
1070 // Set the padding ops the same devices as the original inputs. If the original
1071 // inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
1072 // mode will be triggered, so we don't need to copy the data back to the host
1073 // to do the padding.
SetPaddingNodesDevices(Graph * graph)1074 Status SetPaddingNodesDevices(Graph* graph) {
1075   for (Node* n : graph->op_nodes()) {
1076     bool tpu_padding_attr;
1077     if (n->type_string() == "Pad" &&
1078         GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr)
1079             .ok()) {
1080       Node* unpadded_input;
1081       TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input));
1082 
1083       const string& requested_device = unpadded_input->requested_device();
1084       const string& assigned_device = unpadded_input->assigned_device_name();
1085       if (!requested_device.empty() || !assigned_device.empty()) {
1086         // The output nodes of the original unpadded inputs include the padded
1087         // inputs and real shapes of inputs, we assign those to the same device
1088         // as the original inputs.
1089         for (Node* out : unpadded_input->out_nodes()) {
1090           if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr,
1091                           &tpu_padding_attr)
1092                   .ok()) {
1093             out->set_requested_device(requested_device);
1094             out->set_assigned_device_name(assigned_device);
1095           }
1096         }
1097         // There might be a tf.shape node added before TPUCompileOp, we need to
1098         // set its device as well.
1099         for (Node* out : n->out_nodes()) {
1100           if (n->type_string() == "Shape") {
1101             out->set_requested_device(requested_device);
1102             out->set_assigned_device_name(assigned_device);
1103           }
1104         }
1105       }
1106     }
1107   }
1108   return Status::OK();
1109 }
1110 
AssignedOrRequestedDevice(const Node * node)1111 const string& AssignedOrRequestedDevice(const Node* node) {
1112   if (!node->assigned_device_name().empty()) {
1113     return node->assigned_device_name();
1114   }
1115   return node->requested_device();
1116 }
1117 
IsTpuDevice(const string & device_string)1118 bool IsTpuDevice(const string& device_string) {
1119   DeviceNameUtils::ParsedName device;
1120   return DeviceNameUtils::ParseFullName(device_string, &device) &&
1121          device.type == DEVICE_TPU_NODE;
1122 }
1123 
// Returns the set of ops that can be placed on TPU devices. There is no strict
// rule for deciding which ops belong in the list, but empirically they are
// mostly dummy ops such as Identity-like ops or control flow related ops.
// Other ops, such as Pad, can also be added to allow data to stay on the TPU.
PlaceOnTPUOpList()1128 const absl::flat_hash_set<std::string>& PlaceOnTPUOpList() {
1129   static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
1130       {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
1131        "NextIteration", "Shape", "_Retval"});
1132   return *place_on_tpu_ops;
1133 }
1134 
1135 // If an op satisfies the following conditions, it will be placed on the same
1136 // TPU device as its inputs:
1137 //   (1) The op can be placed on TPU (in the PlaceOnTPUOpList)
1138 //   (2) The op itself has no requested or assigned devices.
1139 //   (3) All the data inputs of this op are placed on the same device on TPUs.
1140 //       There are exceptions like the NextIterations input of Switch node can
1141 //       be placed on CPU as it is just a boolean.
1142 //
1143 // Returns true if the node device has been changed, otherwise returns false.
PlaceOpsOnTPU(Node * node)1144 bool PlaceOpsOnTPU(Node* node) {
1145   if (!AssignedOrRequestedDevice(node).empty() ||
1146       !PlaceOnTPUOpList().contains(node->type_string())) {
1147     return false;
1148   }
1149   string src_tpu_device = "";
1150   Node* src_node;
1151   for (const Edge* e : node->in_edges()) {
1152     if (e->IsControlEdge()) {
1153       continue;
1154     }
1155     Node* src = e->src();
1156     const string& src_device = AssignedOrRequestedDevice(src);
1157 
1158     // Make exceptions that we don't force the some inputs to place on TPUs.
1159     if (node->IsSwitch() && src->IsLoopCond()) {
1160       continue;
1161     }
1162 
1163     if (!IsTpuDevice(src_device) ||
1164         (!src_tpu_device.empty() && src_device != src_tpu_device)) {
1165       return false;
1166     }
1167     if (src_tpu_device.empty()) {
1168       src_tpu_device = src_device;
1169       src_node = src;
1170     }
1171   }
1172   node->set_assigned_device_name(src_node->assigned_device_name());
1173   node->set_requested_device(src_node->requested_device());
1174   return true;
1175 }
1176 
CreateOpMetadataFromNode(const Node & node)1177 xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
1178   xla::OpMetadata metadata;
1179   metadata.set_op_type(node.type_string());
1180   metadata.set_op_name(node.name());
1181   return metadata;
1182 }
1183 
1184 // Helper struct holding node (nullable) and associated sharding.
1185 struct NodeAndSharding {
NodeAndShardingtensorflow::__anonf0ad67bd0111::NodeAndSharding1186   explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding)
1187       : node(node), sharding(sharding) {}
1188 
1189   const Node* node;
1190   xla::OpSharding sharding;
1191 };
1192 
1193 // Validate sharding configuration derived from XlaSharding attribute.
1194 // Infer the core id from the OpSharding, if necessary.
ParseAndValidateSharding(const NodeAndSharding & node_and_sharding,const int num_cores_per_replica,int64 * inferred_core_id,absl::optional<NodeAndSharding> * result)1195 Status ParseAndValidateSharding(const NodeAndSharding& node_and_sharding,
1196                                 const int num_cores_per_replica,
1197                                 int64* inferred_core_id,
1198                                 absl::optional<NodeAndSharding>* result) {
1199   if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) {
1200     int64 core_annotation =
1201         node_and_sharding.sharding.tile_assignment_devices(0);
1202     TF_RETURN_IF_ERROR(
1203         ValidateCoreNumber(core_annotation, num_cores_per_replica));
1204     if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) {
1205       *inferred_core_id = core_annotation;
1206       result->emplace(node_and_sharding);
1207     }
1208   } else {
1209     if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) {
1210       for (int64 core : node_and_sharding.sharding.tile_assignment_devices()) {
1211         TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
1212       }
1213     }
1214 
1215     if (!result->has_value()) {
1216       *result = node_and_sharding;
1217     } else {
1218       std::string result_value_serialized;
1219       xla::OpSharding result_value = result->value().sharding;
1220       result_value.clear_metadata();
1221       SerializeToStringDeterministic(result_value, &result_value_serialized);
1222 
1223       std::string sharding_serialized;
1224       xla::OpSharding sharding = node_and_sharding.sharding;
1225       sharding.clear_metadata();
1226       SerializeToStringDeterministic(sharding, &sharding_serialized);
1227 
1228       // TODO(lyandy): Choose the more granular sharding instead of always
1229       // assigning to core 0 (maximal).
1230       if (result_value_serialized != sharding_serialized) {
1231         // We see different shardings, assign to core 0.
1232         auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
1233         DCHECK_NE(node_and_sharding.node, nullptr);
1234         *core_zero_sharding.add_metadata() =
1235             CreateOpMetadataFromNode(*node_and_sharding.node);
1236         result->emplace(
1237             NodeAndSharding(node_and_sharding.node, core_zero_sharding));
1238       }
1239     }
1240   }
1241   return Status::OK();
1242 }
1243 
// As an XlaSharding node may be followed by a Cast op or an Identity op,
// recursively walk the graph and aggregate nodes connected to
// |input_node| or to a Cast/Identity op following the |input_node|.
FindNodesMaybeContainingShardingInfo(const Node & input_node,std::vector<const Node * > * nodes)1247 void FindNodesMaybeContainingShardingInfo(const Node& input_node,
1248                                           std::vector<const Node*>* nodes) {
1249   if (input_node.IsIdentity() || input_node.type_string() == "Cast") {
1250     for (const Node* connected_node : input_node.out_nodes())
1251       FindNodesMaybeContainingShardingInfo(*connected_node, nodes);
1252   }
1253   nodes->emplace_back(&input_node);
1254 }
1255 
// Parse sharding configuration from |node| or its adjacent nodes.
// XlaSharding configuration may be derived from
//   a) Connected Identity op node.
//   b) Connected Cast op node.
1260 xla::StatusOr<absl::optional<NodeAndSharding>>
ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,const Node & node)1261 ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
1262                                    const Node& node) {
1263   // If |node| has `device` attribute or is a XlaSharding op,
1264   // return the parsed OpSharding.
1265   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
1266                       ParseShardingFromDevice(node, num_cores_per_replica,
1267                                               /*add_metadata=*/true));
1268   if (sharding.has_value()) {
1269     return absl::optional<NodeAndSharding>(NodeAndSharding(&node, *sharding));
1270   }
1271 
1272   // XlaShardingOp may be followed by an identity or followed by identity
1273   // and a Cast op.
1274   std::vector<const Node*> potential_nodes_with_input_sharding;
1275   FindNodesMaybeContainingShardingInfo(node,
1276                                        &potential_nodes_with_input_sharding);
1277   for (const Node* maybe_node_with_sharding_info :
1278        potential_nodes_with_input_sharding) {
1279     if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
1280 
1281     TF_ASSIGN_OR_RETURN(
1282         absl::optional<xla::OpSharding> sharding_config,
1283         ParseShardingFromDevice(*maybe_node_with_sharding_info,
1284                                 num_cores_per_replica, /*add_metadata=*/true));
1285     if (sharding_config.has_value()) {
1286       return absl::optional<NodeAndSharding>(
1287           NodeAndSharding(maybe_node_with_sharding_info, *sharding_config));
1288     }
1289   }
1290   return absl::optional<NodeAndSharding>();
1291 }
1292 
1293 // Walk the graph from an argument node to find OpSharding configuration
1294 // from its neighbor nodes. Sharding configuration may be inferred from
1295 //  1) Parsing XlaSharding attribute from neighboring node.
1296 //  2) If argument node is a resource, then by parsing adjacent nodes
1297 //     of the connected ReadVariable op.
ParseAndValidateShardingFromNeighbors(const int num_cores_per_replica,const std::string & arg_node_name,const Node & neighbor_node,int64 * inferred_core_id,bool * is_fast_mem,absl::optional<NodeAndSharding> * result)1298 Status ParseAndValidateShardingFromNeighbors(
1299     const int num_cores_per_replica, const std::string& arg_node_name,
1300     const Node& neighbor_node, int64* inferred_core_id, bool* is_fast_mem,
1301     absl::optional<NodeAndSharding>* result) {
1302   if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1303     *is_fast_mem = true;
1304     VLOG(2) << "place " << neighbor_node.name() << " on fast memory because "
1305             << arg_node_name << " has " << TPU_FAST_MEM_ATTR << " attribute";
1306   }
1307 
1308   // XlaSharding information may be encoded on node directly connected to the
1309   // argument node.
1310   TF_ASSIGN_OR_RETURN(
1311       absl::optional<NodeAndSharding> node_and_sharding,
1312       ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
1313   if (node_and_sharding.has_value()) {
1314     TF_RETURN_IF_ERROR(ParseAndValidateSharding(
1315         *node_and_sharding, num_cores_per_replica, inferred_core_id, result));
1316     return Status::OK();
1317   }
1318 
1319   // When we use variable in TPU computation, we always have a
1320   // XlaSharding op followed by a ReadVariableOp. As so, correctly parse
1321   // the users of ReadVariableOp for potential sharding configuration.
1322   if (neighbor_node.type_string() == "ReadVariableOp") {
1323     for (const Edge* e : neighbor_node.out_edges()) {
1324       if (e->IsControlEdge()) continue;
1325 
1326       if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1327         *is_fast_mem = true;
1328         VLOG(2) << "place " << arg_node_name << " on fast memory because "
1329                 << e->dst()->name() << TPU_FAST_MEM_ATTR << " attribute";
1330       }
1331 
1332       TF_ASSIGN_OR_RETURN(
1333           absl::optional<NodeAndSharding> node_and_sharding,
1334           ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
1335       if (node_and_sharding.has_value()) {
1336         TF_RETURN_IF_ERROR(ParseAndValidateSharding(*node_and_sharding,
1337                                                     num_cores_per_replica,
1338                                                     inferred_core_id, result));
1339         return Status::OK();
1340       }
1341     }
1342   }
1343   return Status::OK();
1344 }
1345 
1346 }  // namespace
1347 
1348 // Inputs:
1349 //   replication_spec_string: the device to which the TPUReplicate node was
1350 //     assigned.
1351 //   device_set: the set of TF devices.
1352 // Outputs:
1353 //   tpu_compilation_device: the name of the TPU compilation device.
1354 //   num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks
1355 //     have the same number of TPU devices.
1356 //   tpu_devices: the TPU devices, indexed by [task][device].
GetTPUDeviceNames(const string & replication_spec_string,const DeviceSet & device_set,string * tpu_compilation_device,int * num_tpus_per_task,std::vector<std::vector<Device * >> * tpu_devices)1357 static Status GetTPUDeviceNames(
1358     const string& replication_spec_string, const DeviceSet& device_set,
1359     string* tpu_compilation_device, int* num_tpus_per_task,
1360     std::vector<std::vector<Device*>>* tpu_devices) {
1361   // TODO(b/110910013) GetSystemDevice parses the spec and returns the name of
1362   // the tpu_system device, which we replace by the cpu device. We do this
1363   // replacement because we want to place the TPUCompileOp (and the compile
1364   // assert op) explicitly on cpu devices on the same job as the tpu_system
1365   // device.
1366   DeviceNameUtils::ParsedName replication_spec;
1367   Device* replication_device;
1368   TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
1369       replication_spec_string, device_set, &replication_spec,
1370       &replication_device));
1371   *tpu_compilation_device =
1372       str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM,
1373                               DEVICE_CPU, /*replace_all=*/true);
1374 
1375   // Finds the set of TPU devices attached to the tasks in the job.
1376   TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices(
1377       replication_spec, device_set, num_tpus_per_task, tpu_devices));
1378 
1379   return Status::OK();
1380 }
1381 
1382 // Parses the topology attribute of TPUReplicate, and populates *topology with
1383 // a physical mesh coordinate to (task, device) mapping.
ParseTopologyAttr(const string & topology_attr,const tpu::TpuTopologyExternal & tpu_topology,int num_tasks,int num_tpus_per_task,xla::Array4D<std::pair<int,int>> * topology)1384 static Status ParseTopologyAttr(const string& topology_attr,
1385                                 const tpu::TpuTopologyExternal& tpu_topology,
1386                                 int num_tasks, int num_tpus_per_task,
1387                                 xla::Array4D<std::pair<int, int>>* topology) {
1388   static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1389   tpu::TopologyProto proto;
1390   proto.ParseFromString(topology_attr);
1391   if (proto.mesh_shape_size() != kTPUTopologyRank) {
1392     return errors::InvalidArgument("TPU topology must be rank ",
1393                                    kTPUTopologyRank);
1394   }
1395   if (proto.num_tasks() != num_tasks) {
1396     return errors::InvalidArgument("Mismatched number of TPU tasks");
1397   }
1398   if (proto.num_tpu_devices_per_task() != num_tpus_per_task) {
1399     return errors::InvalidArgument("Mismatched number of TPUs per task (",
1400                                    proto.num_tpu_devices_per_task(),
1401                                    " != ", num_tpus_per_task, ").");
1402   }
1403   if (proto.device_coordinates_size() !=
1404       num_tasks * num_tpus_per_task * kTPUTopologyRank) {
1405     return errors::InvalidArgument(
1406         "device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x",
1407         kTPUTopologyRank, "; got ", proto.device_coordinates_size());
1408   }
1409 
1410   int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1411   *topology = xla::Array4D<std::pair<int, int>>(
1412       tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1413       tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1});
1414   int pos = 0;
1415   for (int task = 0; task < num_tasks; ++task) {
1416     for (int device = 0; device < num_tpus_per_task; ++device) {
1417       int32 x = proto.device_coordinates(pos++);
1418       int32 y = proto.device_coordinates(pos++);
1419       int32 z = proto.device_coordinates(pos++);
1420       int32 core = proto.device_coordinates(pos++);
1421 
1422       if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1423           core >= devices_per_chip) {
1424         return errors::InvalidArgument(
1425             "Mesh coordinates (", x, ",", y, ",", z, ",", core,
1426             ") are not valid for the current TPU topology");
1427       }
1428       if ((*topology)(x, y, z, core).first != -1) {
1429         return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1430                                        ",", z, ",", core, ") in TPU topology");
1431       }
1432       (*topology)(x, y, z, core) = {task, device};
1433     }
1434   }
1435   return Status::OK();
1436 }
1437 
1438 // Parses the value of the device_assignment attribute to TPUReplicate.
1439 // Populates *device_assignment; *device_assignment must be a 2D array with
1440 // shape (num_replicas, num_cores_per_replica).
ParseDeviceAssignmentAttr(absl::Span<const int> device_assignment_attr,const tpu::TpuTopologyExternal & tpu_topology,int num_replicas,int num_cores_per_replica,xla::Array2D<tpu::TpuCoreLocationExternal> * device_assignment)1441 static Status ParseDeviceAssignmentAttr(
1442     absl::Span<const int> device_assignment_attr,
1443     const tpu::TpuTopologyExternal& tpu_topology, int num_replicas,
1444     int num_cores_per_replica,
1445     xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) {
1446   static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1447 
1448   const int64 device_assignment_attr_size =
1449       num_replicas * num_cores_per_replica * kTPUTopologyRank;
1450   if (device_assignment_attr.size() != device_assignment_attr_size) {
1451     return errors::InvalidArgument(
1452         "Length of device_assignment attribute must be equal to num_replicas (",
1453         num_replicas, ") * num_cores_per_replica (", num_cores_per_replica,
1454         ") * ", kTPUTopologyRank, " got ", device_assignment_attr.size());
1455   }
1456   for (int core : device_assignment_attr) {
1457     if (core < 0 || core >= kTPUMaxTopologySize) {
1458       return errors::InvalidArgument(
1459           "Invalid core number in device assignment: ", core);
1460     }
1461   }
1462 
1463   *device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>(
1464       num_replicas, num_cores_per_replica);
1465   int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1466   xla::Array4D<int> replica_assignment(
1467       tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1468       tpu_topology.chip_bounds().z, devices_per_chip, -1);
1469   int pos = 0;
1470   for (int replica = 0; replica < num_replicas; ++replica) {
1471     for (int logical_core = 0; logical_core < num_cores_per_replica;
1472          ++logical_core) {
1473       int32 x = device_assignment_attr[pos++];
1474       int32 y = device_assignment_attr[pos++];
1475       int32 z = device_assignment_attr[pos++];
1476       int32 core = device_assignment_attr[pos++];
1477 
1478       if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1479           core >= devices_per_chip) {
1480         return errors::InvalidArgument(
1481             "Mesh coordinates (", x, ",", y, ",", core,
1482             ") are not valid for the current TPU topology");
1483       }
1484       tpu::TpuCoreLocationExternal core_location =
1485           tpu_topology.Core(kTensorCore, x, y, z, core);
1486 
1487       if (replica_assignment(x, y, z, core) != -1) {
1488         return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1489                                        ",", z, ",", core,
1490                                        ") in TPU device assignment");
1491       }
1492       replica_assignment(x, y, z, core) = replica;
1493       (*device_assignment)(replica, logical_core) = core_location;
1494     }
1495   }
1496   return Status::OK();
1497 }
1498 
1499 // Builds TensorFlow device assignments for the special case of a single core
1500 // computation that is replicated to every core in the mesh.
1501 // LINT.IfChange
BuildFullMeshDeviceAssignment(int num_replicas,const std::vector<std::vector<Device * >> & tpu_devices,int num_tasks,int num_tpus_per_task,std::vector<std::vector<string>> * tf_device_assignment)1502 static Status BuildFullMeshDeviceAssignment(
1503     int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices,
1504     int num_tasks, int num_tpus_per_task,
1505     std::vector<std::vector<string>>* tf_device_assignment) {
1506   // Assign TensorFlow devices to replicas arbitrarily.
1507   for (int i = 0; i < num_replicas; ++i) {
1508     int task = i / num_tpus_per_task;
1509     int device = i % num_tpus_per_task;
1510     TF_RET_CHECK(task >= 0 && task < num_tasks);
1511     TF_RET_CHECK(device >= 0 && device < num_tpus_per_task);
1512 
1513     // We don't actually know which TF device corresponds to which physical
1514     // device, but it doesn't matter—they're all identical.
1515     (*tf_device_assignment)[i] = {tpu_devices[task][device]->name()};
1516   }
1517   return Status::OK();
1518 }
1519 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1520 
1521 // Builds TensorFlow device assignments for a replicated computation and convert
1522 // device_assignment into xla_device_assignment.
BuildGeneralDeviceAssignment(int num_replicas,int num_cores_per_replica,const std::vector<std::vector<Device * >> & tpu_devices,const xla::Array2D<tpu::TpuCoreLocationExternal> & device_assignment,const xla::Array4D<std::pair<int,int>> & topology,std::vector<std::vector<string>> * tf_device_assignment,std::unique_ptr<xla::DeviceAssignment> * xla_device_assignment)1523 static Status BuildGeneralDeviceAssignment(
1524     int num_replicas, int num_cores_per_replica,
1525     const std::vector<std::vector<Device*>>& tpu_devices,
1526     const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment,
1527     const xla::Array4D<std::pair<int, int>>& topology,
1528     std::vector<std::vector<string>>* tf_device_assignment,
1529     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1530   // Assign TensorFlow devices to each computation's replicas according to
1531   // device_assignment and 'topology'.
1532   *xla_device_assignment = absl::make_unique<xla::DeviceAssignment>(
1533       num_replicas, num_cores_per_replica);
1534   for (int replica = 0; replica < num_replicas; ++replica) {
1535     for (int computation = 0; computation < num_cores_per_replica;
1536          ++computation) {
1537       const tpu::TpuCoreLocationExternal& core_location =
1538           device_assignment(replica, computation);
1539 
1540       int task;
1541       int device;
1542       std::tie(task, device) =
1543           topology(core_location.chip_coordinates().x,
1544                    core_location.chip_coordinates().y,
1545                    core_location.chip_coordinates().z, core_location.index());
1546 
1547       CHECK_LT(computation, num_cores_per_replica);
1548       (**xla_device_assignment)(replica, computation) = core_location.Id();
1549 
1550       // The communication pattern between replicas will be determined later by
1551       // BuildAllReduceRing.
1552       TF_RET_CHECK(task >= 0 && task < tpu_devices.size());
1553       TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size());
1554       (*tf_device_assignment)[replica].push_back(
1555           tpu_devices[task][device]->name());
1556     }
1557   }
1558   return Status::OK();
1559 }
1560 
BuildDeviceAssignment(const tpu::TpuTopologyExternal & tpu_topology,int num_tpus_per_task,const std::vector<std::vector<Device * >> & tpu_devices,int num_replicas,int num_cores_per_replica,const string & topology_attr,absl::Span<const int> device_assignment_attr,std::vector<std::vector<string>> * tf_device_assignment,std::unique_ptr<xla::DeviceAssignment> * xla_device_assignment)1561 /*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment(
1562     const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task,
1563     const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
1564     int num_cores_per_replica, const string& topology_attr,
1565     absl::Span<const int> device_assignment_attr,
1566     std::vector<std::vector<string>>* tf_device_assignment,
1567     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1568   const int num_tasks = tpu_devices.size();
1569   const int num_tpu_devices = num_tasks * num_tpus_per_task;
1570   VLOG(2) << "num_tasks=" << num_tasks
1571           << " num_tpus_per_task=" << num_tpus_per_task;
1572 
1573   // Checks num_replicas is sane first to avoid integer overflow.
1574   if (num_replicas > num_tpu_devices) {
1575 #ifdef PLATFORM_CLOUD_TPU
1576     return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1577                                    " but there are only ", num_tpu_devices,
1578                                    " cores in the TPU topology.");
1579 #else
1580     return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1581                                    " but there are only ", num_tpu_devices,
1582                                    " cores in the TPU topology.");
1583 #endif
1584   }
1585   if (num_replicas * num_cores_per_replica > num_tpu_devices) {
1586     return errors::InvalidArgument(
1587         "Requested num_replicas=", num_replicas, " with ",
1588         num_cores_per_replica, " cores per replica, but there are only ",
1589         num_tpu_devices, " cores in the TPU topology");
1590   }
1591 
1592   tf_device_assignment->clear();
1593   tf_device_assignment->resize(num_replicas);
1594 
1595   // Special case: we allow the user to omit the topology and device assignment
1596   // information in two cases:
1597   // * there is only one replica and one core per replica. In this case, we
1598   //   don't need to know topology information because we don't communicate with
1599   //   other cores.
1600   // * the number of replicas is equal to the number of cores in the slice. In
1601   //   this case, all cores are running the same program so we don't need to
1602   //   know which is which.
1603   if (topology_attr.empty()) {
1604     // LINT.IfChange
1605     if (num_replicas != 1 && num_replicas != num_tpu_devices) {
1606       return errors::InvalidArgument(
1607           "TPUReplicate asked to create ", num_replicas,
1608           " replicas, but the number of cores in the TPU topology is ",
1609           num_tpu_devices,
1610           " and no TPU device assignment was supplied. "
1611           "A TPU device assignment is required if the number of replicas is "
1612           "not 1 or the number of cores in the topology (",
1613           num_tpu_devices, ")");
1614     }
1615 
1616     if (num_cores_per_replica != 1) {
1617       return errors::InvalidArgument(
1618           "A TPU topology must be provided if num_cores_per_replica != 1");
1619     }
1620 
1621     if (!device_assignment_attr.empty()) {
1622       return errors::InvalidArgument(
1623           "A TPU topology must be provided if device_assignment_attr is "
1624           "non-empty");
1625     }
1626 
1627     // If there is only one replica, assign the Tensorflow computation to task 0
1628     // device 0, and leave the XLA device assignment empty. We don't know which
1629     // core this is in the TPU topology, but it doesn't matter—we don't need to
1630     // communicate with any other cores.
1631     if (num_replicas == 1) {
1632       (*tf_device_assignment)[0] = {tpu_devices[0][0]->name()};
1633       return Status::OK();
1634     }
1635 
1636     // Otherwise, num_replicas is equal to the number of cores, and we build a
1637     // device assignment that covers the entire mesh. We do not need to know
1638     // the topology to do so because all cores are identical.
1639     return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks,
1640                                          num_tpus_per_task,
1641                                          tf_device_assignment);
1642     // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1643   }
1644 
1645   // Array that maps mesh coordinates to {TF task, TF TPU device #} pairs.
1646   xla::Array4D<std::pair<int, int>> topology;
1647   TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks,
1648                                        num_tpus_per_task, &topology));
1649 
1650   // Array that maps logical (replica, core) pairs to physical mesh coordinates.
1651   xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment;
1652   TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr(
1653       device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica,
1654       &device_assignment));
1655 
1656   return BuildGeneralDeviceAssignment(
1657       num_replicas, num_cores_per_replica, tpu_devices, device_assignment,
1658       topology, tf_device_assignment, xla_device_assignment);
1659 }
1660 
GetComputationForTPUReplicateOp(const NameAttrList & function,FunctionLibraryRuntime * flr,Graph * computation,DataTypeVector * arg_types,DataTypeVector * retval_types)1661 Status DistributedTPURewritePass::GetComputationForTPUReplicateOp(
1662     const NameAttrList& function, FunctionLibraryRuntime* flr,
1663     Graph* computation, DataTypeVector* arg_types,
1664     DataTypeVector* retval_types) {
1665   FunctionLibraryRuntime::Handle handle;
1666 
1667   TF_RETURN_IF_ERROR(
1668       flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
1669 
1670   const FunctionBody* fbody = flr->GetFunctionBody(handle);
1671 
1672   CopyGraph(*fbody->graph, computation);
1673   *arg_types = fbody->arg_types;
1674   *retval_types = fbody->ret_types;
1675   return Status::OK();
1676 }
1677 
1678 // Grab the InferredShape corresponding to an edge input.
GetEdgeShape(const GraphShapeInfo & shape_info,const Edge & edge,const InferredShape ** info)1679 static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge,
1680                            const InferredShape** info) {
1681   auto it = shape_info.find(edge.src()->name());
1682   if (it == shape_info.end()) {
1683     return errors::InvalidArgument(
1684         "Input to replicated TPU computation is missing InferredShape: ",
1685         edge.src()->name());
1686   }
1687   TF_RET_CHECK(it->second.size() > edge.src_output());
1688   *info = &it->second[edge.src_output()];
1689   return Status::OK();
1690 }
1691 
GetArgAndRetvalShapes(const GraphShapeInfo & shape_info,const Node & node,const ParameterInfo & params_info,std::vector<InferredShape> * arg_shapes,std::vector<InferredShape> * retval_shapes)1692 Status DistributedTPURewritePass::GetArgAndRetvalShapes(
1693     const GraphShapeInfo& shape_info, const Node& node,
1694     const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
1695     std::vector<InferredShape>* retval_shapes) {
1696   std::vector<const Edge*> input_edges;
1697   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
1698 
1699   // If any replica's arg shape is unknown, we will mark the computation's arg
1700   // shape as being unknown. If the shapes differ the TpuExecute Op will raise a
1701   // runtime error.
1702   std::vector<bool> any_replica_shape_unknown(
1703       params_info.NumInputsToEachReplica());
1704   arg_shapes->clear();
1705   arg_shapes->resize(params_info.NumInputsToEachReplica());
1706   TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost());
1707   // Determines the shapes of the per-replica arguments and checks that all
1708   // replicas have identical shapes.
1709   int64 edge_pos = 0;
1710   auto check_shape = [&](int input_index) -> Status {
1711     const InferredShape* info;
1712     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1713     ++edge_pos;
1714 
1715     if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) ||
1716         (info->handle_type != DT_INVALID &&
1717          !info->handle_shape.IsFullyDefined())) {
1718       any_replica_shape_unknown[input_index] = true;
1719     }
1720     xla::StatusOr<InferredShape> status =
1721         MergeInferredShapes((*arg_shapes)[input_index], *info);
1722     if (!status.ok()) {
1723       return errors::InvalidArgument(
1724           "Mismatched shapes for input ", input_index, ": ",
1725           (*arg_shapes)[input_index].shape.DebugString(), " vs. ",
1726           info->shape.DebugString());
1727     }
1728     (*arg_shapes)[input_index] = status.ValueOrDie();
1729     return Status::OK();
1730   };
1731 
1732   for (int64 i = 0; i < params_info.NumReplicas(); ++i) {
1733     for (int64 j = 0; j < params_info.NumPerReplicaArgs(); ++j) {
1734       TF_RETURN_IF_ERROR(check_shape(j));
1735     }
1736   }
1737 
1738   for (int64 i = 0; i < params_info.NumDistributedArgs(); ++i) {
1739     TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i));
1740   }
1741 
1742   for (int64 i = 0;
1743        i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs();
1744        ++i) {
1745     if (any_replica_shape_unknown[i]) {
1746       (*arg_shapes)[i].shape = PartialTensorShape();
1747       (*arg_shapes)[i].handle_shape = PartialTensorShape();
1748     }
1749   }
1750 
1751   // Determines the shape of the broadcast arguments.
1752   for (int64 i = 0; i < params_info.NumBroadcastArgs(); ++i) {
1753     TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
1754     const InferredShape* info;
1755     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1756     (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
1757                   params_info.NumDistributedArgs()]
1758         .shape = info->shape;
1759     ++edge_pos;
1760   }
1761 
1762   // Determines the handle shape and handle type of the resource variable
1763   // arguments.
1764   for (int64 i = 0; i < params_info.NumVariables(); ++i) {
1765     TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE);
1766     const InferredShape* info;
1767     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1768     InferredShape& arg_shape =
1769         (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
1770                       params_info.NumDistributedArgs() +
1771                       params_info.NumBroadcastArgs()];
1772     arg_shape.shape = TensorShape();  // Variables are always scalars.
1773     arg_shape.handle_shape = info->handle_shape;
1774     arg_shape.handle_type = info->handle_type;
1775     TF_RET_CHECK(arg_shape.handle_type != DT_INVALID)
1776         << " input edge: " << input_edges[edge_pos]->DebugString();
1777     ++edge_pos;
1778   }
1779 
1780   // Determines the shape of the guaranteed constants.
1781   // TODO(vinuraja): Can be removed because they are not required for any
1782   // calculations. Leaving them here for symmetry with other structures like
1783   // arg_types, arg_sharding, etc.
1784   for (int64 i = 0; i < params_info.NumGuaranteedConstants(); ++i) {
1785     TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
1786     const InferredShape* info;
1787     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1788     (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
1789                   params_info.NumDistributedArgs() +
1790                   params_info.NumBroadcastArgs() + params_info.NumVariables()]
1791         .shape = info->shape;
1792     ++edge_pos;
1793   }
1794 
1795   // Extract the return value shapes.
1796   auto it = shape_info.find(node.name());
1797   retval_shapes->clear();
1798   if (it != shape_info.end()) {
1799     TF_RET_CHECK(it->second.size() >= node.num_outputs());
1800     retval_shapes->resize(node.num_outputs());
1801     for (int i = 0; i < node.num_outputs(); ++i) {
1802       (*retval_shapes)[i].shape = it->second[i].shape;
1803     }
1804   } else if (node.num_outputs() > 0) {
1805     return errors::InvalidArgument(
1806         "Replicated TPU computation is missing InferredShape: ",
1807         FormatNodeForError(node));
1808   }
1809   return Status::OK();
1810 }
1811 
1812 // Verifies that all nodes have legal sharding.
ValidateCoreNumbers(const Graph & graph,int num_cores_per_replica)1813 static Status ValidateCoreNumbers(const Graph& graph,
1814                                   int num_cores_per_replica) {
1815   for (Node* n : graph.nodes()) {
1816     TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
1817                         ParseShardingFromDevice(*n, num_cores_per_replica,
1818                                                 /*add_metadata=*/true));
1819   }
1820   return Status::OK();
1821 }
1822 
InferXlaShardingFromNeighbors(const Node & n,int num_cores_per_replica,FunctionLibraryRuntime * flr,CachedFunctionHandles * cached_function_handles,absl::optional<NodeAndSharding> * output_node_and_sharding,bool * is_fast_mem)1823 static Status InferXlaShardingFromNeighbors(
1824     const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr,
1825     CachedFunctionHandles* cached_function_handles,
1826     absl::optional<NodeAndSharding>* output_node_and_sharding,
1827     bool* is_fast_mem) {
1828   int64 core = -1;
1829   absl::optional<NodeAndSharding> result;
1830   // We assume the variable has been allocated on fast memory if any consuming
1831   // op has TPU_FAST_MEM_ATTR attribute. This is a protocol between runtime and
1832   // compiler.
1833   *is_fast_mem = false;
1834   for (const Edge* edge : n.out_edges()) {
1835     if (edge->IsControlEdge()) continue;
1836 
1837     TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
1838         num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem,
1839         &result));
1840 
1841     if (!flr) continue;
1842 
1843     // The nodes deciding this arg's device assignment might be in
1844     // FunctionDef. Instantiate FunctionDefs associated with this node
1845     // and check nodes using this arg.
1846     std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
1847         [&](const Edge* call_edge) {
1848           auto associated_functions = GetAssociatedFunctions(
1849               *call_edge->dst(), flr->GetFunctionLibraryDefinition());
1850           for (auto& associated_function : associated_functions) {
1851             FunctionLibraryRuntime::Handle handle;
1852             TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
1853                 associated_function.func_name(),
1854                 AttrSlice(&associated_function.attrs()), &handle));
1855             const FunctionBody* body = flr->GetFunctionBody(handle);
1856             Graph* g = body->graph;
1857 
1858             for (Node* body_node : g->nodes()) {
1859               if (!body_node->IsArg()) continue;
1860 
1861               int index;
1862               TF_RETURN_IF_ERROR(
1863                   GetNodeAttr(body_node->attrs(), "index", &index));
1864               if (index != call_edge->dst_input()) continue;
1865 
1866               for (const Edge* out_edge : body_node->out_edges()) {
1867                 if (out_edge->IsControlEdge()) continue;
1868 
1869                 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
1870                     num_cores_per_replica, n.name(), *out_edge->dst(), &core,
1871                     is_fast_mem, &result));
1872 
1873                 TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
1874               }
1875             }
1876           }
1877           return Status::OK();
1878         };
1879     TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
1880   }
1881   *output_node_and_sharding = result;
1882   return Status::OK();
1883 }
1884 
UseSpmdForXlaPartitioning(const Node * replicate_node)1885 bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
1886   bool spmd_attr;
1887   if (!replicate_node ||
1888       !TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
1889                       &spmd_attr)) {
1890     spmd_attr = false;
1891   }
1892   return spmd_attr;
1893 }
1894 
FormatNodeAndShardingMsg(const absl::optional<NodeAndSharding> & node_and_sharding)1895 std::string FormatNodeAndShardingMsg(
1896     const absl::optional<NodeAndSharding>& node_and_sharding) {
1897   DCHECK(node_and_sharding.has_value());
1898 
1899   xla::OpSharding sharding_no_metadata = node_and_sharding->sharding;
1900   sharding_no_metadata.clear_metadata();
1901   std::string escaped_sharding_str =
1902       absl::CEscape(sharding_no_metadata.SerializeAsString());
1903   if (node_and_sharding->node == nullptr) {
1904     return absl::StrCat(" via default sharding '", escaped_sharding_str, "'");
1905   }
1906 
1907   return absl::StrCat(" via node ", node_and_sharding->node->DebugString(),
1908                       " sharding '", escaped_sharding_str, "'");
1909 }
1910 
AssignArgsAndRetvalsToCores(int num_cores_per_replica,const ParameterInfo & params_info,const DataTypeVector & arg_types,const std::vector<InferredShape> & arg_shapes,const DataTypeVector & retval_types,const std::vector<InferredShape> & retval_shapes,const Graph & graph,const Node * replicate_node,FunctionLibraryRuntime * flr,bool allow_parameter_replication_for_spmd,std::vector<xla::OpSharding> * arg_sharding,std::vector<bool> * arg_fast_mem,std::vector<xla::OpSharding> * retval_sharding,std::vector<std::string> * arg_names)1911 Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
1912     int num_cores_per_replica, const ParameterInfo& params_info,
1913     const DataTypeVector& arg_types,
1914     const std::vector<InferredShape>& arg_shapes,
1915     const DataTypeVector& retval_types,
1916     const std::vector<InferredShape>& retval_shapes, const Graph& graph,
1917     const Node* replicate_node, FunctionLibraryRuntime* flr,
1918     bool allow_parameter_replication_for_spmd,
1919     std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
1920     std::vector<xla::OpSharding>* retval_sharding,
1921     std::vector<std::string>* arg_names) {
1922   // Builds vectors of the argument and return nodes.
1923   std::vector<Node*> args(arg_types.size());
1924   std::vector<Node*> retvals(retval_types.size());
1925   absl::flat_hash_map<int, Node*> partitioned_output_nodes;
1926   for (Node* node : graph.op_nodes()) {
1927     if (node->IsArg()) {
1928       int index;
1929       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
1930       TF_RET_CHECK(index >= 0 && index < args.size());
1931       args[index] = node;
1932     } else if (node->IsRetval()) {
1933       int index;
1934       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
1935       TF_RET_CHECK(index >= 0 && index < retvals.size());
1936       retvals[index] = node;
1937     }
1938   }
1939   for (const Edge* edge : replicate_node->out_edges()) {
1940     int num_partitioned_outputs = 0;
1941     for (const Edge* out_edge : edge->dst()->out_edges()) {
1942       if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
1943         partitioned_output_nodes[edge->src_output()] = out_edge->dst();
1944         num_partitioned_outputs++;
1945       }
1946     }
1947     if (num_partitioned_outputs > 1) {
1948       return errors::InvalidArgument(
1949           "More than one TPUPartitionedOutput per replciated output.");
1950     }
1951   }
1952 
1953   // Verifies there are no missing arguments/return values.
1954   for (int i = 0; i < args.size(); ++i) {
1955     if (args[i] == nullptr) {
1956       return errors::Internal("Missing function argument: ", i);
1957     }
1958   }
1959   for (int i = 0; i < retvals.size(); ++i) {
1960     if (retvals[i] == nullptr) {
1961       return errors::Internal("Missing function return value: ", i);
1962     }
1963   }
1964 
1965   // Assigns a core to each _Arg. Chooses the lowest-numbered core that
1966   // consumes the argument. We choose the lowest-numbered core so the
1967   // assignment is deterministic.
1968   TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types,
1969                                           arg_shapes);
1970   arg_sharding->resize(args.size());
1971   arg_names->resize(args.size());
1972   arg_fast_mem->resize(args.size());
1973   CachedFunctionHandles cached_function_handles(flr);
1974   const bool use_spmd = (UseSpmdForXlaPartitioning(replicate_node) ||
1975                          replicate_inputs_outputs_by_default_for_xla_spmd_) &&
1976                         allow_parameter_replication_for_spmd;
1977 
1978   // Offset _TPUReplicate non per replica argument indices by
1979   // (num_replicas - 1) * num_per_replica_args as _TPUReplicate nodes are
1980   // constructed with all per replica args across all replicas while the
1981   // encapsulated function only has 1 replica's per replica args. Per replica
1982   // args are ordered by replica first, so the index here does not require an
1983   // offset and the first replica's input nodes is sufficient for determining
1984   // argument sharding.
1985   const int index_offset =
1986       (params_info.NumReplicas() - 1) * params_info.NumPerReplicaArgs();
1987   for (int i = 0; i < args.size(); ++i) {
1988     const Node* n = args[i];
1989     absl::optional<int64> assigned_core;
1990     absl::optional<NodeAndSharding> node_and_sharding;
1991     bool is_fast_mem;
1992     TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors(
1993         *n, num_cores_per_replica, flr, &cached_function_handles,
1994         &node_and_sharding, &is_fast_mem));
1995 
1996     const bool is_per_replica_arg = params_info.IsPerReplicaArg(i);
1997     if (is_per_replica_arg || params_info.IsDistributedArg(i)) {
1998       Node* input_node;
1999       TF_RETURN_IF_ERROR(replicate_node->input_node(
2000           i + (is_per_replica_arg ? 0 : index_offset), &input_node));
2001       if (input_node->type_string() == kTPUPartitionedInput) {
2002         TF_ASSIGN_OR_RETURN(
2003             absl::optional<xla::OpSharding> parsed_sharding,
2004             GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2005         if (!parsed_sharding.has_value())
2006           return errors::InvalidArgument("Missing _XlaSharding attr from: ",
2007                                          input_node->DebugString());
2008         node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2009         VLOG(1) << "Arg " << i << " parsed sharding information from "
2010                 << input_node->DebugString() << " : "
2011                 << parsed_sharding->DebugString();
2012       }
2013     }
2014 
2015     if (params_info.IsVariableArg(i)) {
2016       Node* input_node;
2017       TF_RETURN_IF_ERROR(
2018           replicate_node->input_node(i + index_offset, &input_node));
2019       if (input_node->type_string() == kVarHandleOp) {
2020         TF_ASSIGN_OR_RETURN(
2021             absl::optional<xla::OpSharding> parsed_sharding,
2022             GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2023         if (parsed_sharding.has_value()) {
2024           node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2025           VLOG(1) << "Arg " << i << " parsed sharding information from "
2026                   << input_node->DebugString() << " : "
2027                   << parsed_sharding->DebugString();
2028         }
2029       }
2030     }
2031 
2032     if (node_and_sharding.has_value() && enable_automatic_model_parallelism_) {
2033       return tensorflow::errors::InvalidArgument(
2034           "Specifying manual sharding is not allowed when automatic "
2035           "model parallelism is enabled.",
2036           node_and_sharding->sharding.DebugString());
2037     }
2038 
2039     if (!node_and_sharding.has_value()) {
2040       if (use_spmd &&
2041           (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2042            ((params_info.IsPerReplicaArg(i) ||
2043              params_info.IsDistributedArg(i)) &&
2044             arg_types[i] != DT_RESOURCE))) {
2045         // Use replication for host variables or non-variable per-replica
2046         // inputs.
2047         node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2048                                             xla::sharding_builder::Replicate());
2049       } else {
2050         // TODO(dlibenzi): Distributing variables to cores other than 0 makes
2051         // learning/brain/research/babelfish/trainer:trainer_tpu_test fail.
2052         // For now distribute only per replica arguments, unless
2053         // tf_jf_distribute_vars is set, to allow debugging the issue.
2054         if (((params_info.IsPerReplicaArg(i) ||
2055               params_info.IsDistributedArg(i)) &&
2056              arg_types[i] != DT_RESOURCE) ||
2057             (distribute_vars_ && params_info.IsVariableArg(i))) {
2058           assigned_core = args_device_selector.RetrieveAssignment(i);
2059         } else {
2060           assigned_core = 0;
2061         }
2062         node_and_sharding = NodeAndSharding(
2063             /*node=*/nullptr,
2064             xla::sharding_builder::AssignDevice(*assigned_core));
2065       }
2066       *node_and_sharding->sharding.add_metadata() =
2067           CreateOpMetadataFromNode(*replicate_node);
2068     } else if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2069       assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2070     } else if (node_and_sharding->sharding.type() !=
2071                    xla::OpSharding::REPLICATED &&
2072                node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2073       return tensorflow::errors::InvalidArgument(
2074           "Unsupported argument sharding (for arg ", n->DebugString(),
2075           "): ", node_and_sharding->sharding.DebugString());
2076     }
2077     if (assigned_core.has_value()) {
2078       args_device_selector.ReportDeviceAssigned(*assigned_core, i);
2079       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2080               << ") to core " << *assigned_core
2081               << FormatNodeAndShardingMsg(node_and_sharding);
2082       args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2083     } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2084       for (int64 core : node_and_sharding->sharding.tile_assignment_devices()) {
2085         args_device_selector.ReportDeviceAssigned(core, i);
2086       }
2087       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2088               << ") with tiled sharding to cores "
2089               << absl::StrJoin(
2090                      node_and_sharding->sharding.tile_assignment_devices(), ",")
2091               << " " << FormatNodeAndShardingMsg(node_and_sharding);
2092     } else {
2093       DCHECK_EQ(node_and_sharding->sharding.type(),
2094                 xla::OpSharding::REPLICATED);
2095       for (int64 core = 0; core < num_cores_per_replica; ++core) {
2096         args_device_selector.ReportDeviceAssigned(core, i);
2097       }
2098       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2099               << ") to all cores"
2100               << FormatNodeAndShardingMsg(node_and_sharding);
2101     }
2102     (*arg_sharding)[i] = node_and_sharding->sharding;
2103     (*arg_fast_mem)[i] = is_fast_mem;
2104     (*arg_names)[i] = n->name();
2105     if (is_fast_mem) {
2106       VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to "
2107               << args[i]->name();
2108     }
2109     args[i]->AddAttr(kShardingAttribute,
2110                      node_and_sharding->sharding.SerializeAsString());
2111   }
2112   TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles());
2113 
2114   // Assigns each _Retval node to the core that produces its value.
2115   TensorDevicePlacer retvals_device_selector(num_cores_per_replica,
2116                                              retval_types, retval_shapes);
2117   retval_sharding->resize(retvals.size());
2118   for (int i = 0; i < retvals.size(); ++i) {
2119     const Edge* edge;
2120     TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge));
2121 
2122     TF_ASSIGN_OR_RETURN(
2123         absl::optional<xla::OpSharding> edge_sharding,
2124         ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
2125                                     /*add_metadata=*/true));
2126 
2127     absl::optional<NodeAndSharding> node_and_sharding;
2128     if (edge_sharding.has_value()) {
2129       node_and_sharding.emplace(NodeAndSharding(edge->src(), *edge_sharding));
2130     }
2131 
2132     if (partitioned_output_nodes.contains(i)) {
2133       Node* output_node = partitioned_output_nodes[i];
2134       TF_ASSIGN_OR_RETURN(
2135           absl::optional<xla::OpSharding> parsed_sharding,
2136           GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
2137       if (parsed_sharding.has_value()) {
2138         node_and_sharding = NodeAndSharding(output_node, *parsed_sharding);
2139         VLOG(1) << "Retval " << i << " parsed sharding information from "
2140                 << output_node->DebugString() << " : "
2141                 << parsed_sharding->DebugString();
2142       }
2143     }
2144     absl::optional<int64> assigned_core;
2145     if (node_and_sharding.has_value()) {
2146       if (enable_automatic_model_parallelism_) {
2147         return tensorflow::errors::InvalidArgument(
2148             "Specifying manual sharding is not allowed when automatic "
2149             "model parallelism is enabled.",
2150             node_and_sharding->sharding.DebugString());
2151       }
2152 
2153       if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2154         assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2155         TF_RETURN_IF_ERROR(
2156             ValidateCoreNumber(*assigned_core, num_cores_per_replica));
2157       } else if (node_and_sharding->sharding.type() !=
2158                      xla::OpSharding::REPLICATED &&
2159                  node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2160         return tensorflow::errors::InvalidArgument(
2161             "Unsupported argument sharding for retval ",
2162             retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
2163             node_and_sharding->sharding.DebugString());
2164       }
2165     } else {
2166       if (use_spmd) {
2167         node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2168                                             xla::sharding_builder::Replicate());
2169       } else {
2170         if (distribute_vars_) {
2171           assigned_core = retvals_device_selector.RetrieveAssignment(i);
2172         } else {
2173           assigned_core = 0;
2174         }
2175         node_and_sharding = NodeAndSharding(
2176             /*node=*/nullptr,
2177             xla::sharding_builder::AssignDevice(*assigned_core));
2178       }
2179       *node_and_sharding->sharding.add_metadata() =
2180           CreateOpMetadataFromNode(*replicate_node);
2181     }
2182     if (assigned_core.has_value()) {
2183       retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2184       retvals_device_selector.ReportDeviceAssigned(*assigned_core, i);
2185       VLOG(3) << "Assigning return value " << i << " ("
2186               << retvals[i]->DebugString() << ") to core " << *assigned_core
2187               << FormatNodeAndShardingMsg(node_and_sharding);
2188     } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2189       for (int64 core : node_and_sharding->sharding.tile_assignment_devices()) {
2190         retvals_device_selector.ReportDeviceAssigned(core, i);
2191       }
2192       VLOG(3) << "Assigning return value " << i << " ("
2193               << retvals[i]->DebugString() << ") with tiled sharding to cores "
2194               << absl::StrJoin(
2195                      node_and_sharding->sharding.tile_assignment_devices(), ",")
2196               << " " << FormatNodeAndShardingMsg(node_and_sharding);
2197     } else {
2198       DCHECK_EQ(node_and_sharding->sharding.type(),
2199                 xla::OpSharding::REPLICATED);
2200       for (int64 core = 0; core < num_cores_per_replica; ++core) {
2201         retvals_device_selector.ReportDeviceAssigned(core, i);
2202       }
2203       VLOG(3) << "Assigning return value " << i << " ("
2204               << retvals[i]->DebugString() << ") to all cores"
2205               << FormatNodeAndShardingMsg(node_and_sharding);
2206     }
2207     retvals[i]->AddAttr(kShardingAttribute,
2208                         node_and_sharding->sharding.SerializeAsString());
2209     (*retval_sharding)[i] = node_and_sharding->sharding;
2210   }
2211   if (use_spmd &&
2212       (absl::c_any_of(*arg_sharding,
2213                       [](const xla::OpSharding& s) {
2214                         return s.type() == xla::OpSharding::MAXIMAL;
2215                       }) ||
2216        absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) {
2217          return s.type() == xla::OpSharding::MAXIMAL;
2218        }))) {
2219     LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
2220                     "exist on every partition (sharded or replicated). Fall "
2221                     "back to MPMD.";
2222     return AssignArgsAndRetvalsToCores(
2223         num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
2224         retval_shapes, graph, replicate_node, flr,
2225         /*allow_parameter_replication_for_spmd=*/false, arg_sharding,
2226         arg_fast_mem, retval_sharding, arg_names);
2227   }
2228   return Status::OK();
2229 }
2230 
2231 // Builds Shape nodes that compute the shapes of arguments whose shapes are not
2232 // statically known.
BuildDynamicShapeNodes(const Node & replicate_node,const std::vector<InferredShape> & arg_shapes,const ParameterInfo & params_info,const std::vector<Node * > & variable_reads,Graph * graph,std::vector<Node * > * dynamic_shape_nodes)2233 /* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes(
2234     const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
2235     const ParameterInfo& params_info, const std::vector<Node*>& variable_reads,
2236     Graph* graph, std::vector<Node*>* dynamic_shape_nodes) {
2237   dynamic_shape_nodes->clear();
2238 
2239   std::vector<const Edge*> replicate_input_edges;
2240   TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
2241 
2242   // The compiler determines the shape of each constant by inspecting the value
2243   // of its corresponding host-memory tensor; this happens when a step is run.
2244   // As a result, the shapes of constants are not needed at graph rewrite time.
2245   const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants();
2246   TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() +
2247                                params_info.NumDistributedArgs() +
2248                                params_info.NumBroadcastArgs() +
2249                                params_info.NumVariables());
2250 
2251   for (int i = 0; i < num_args; ++i) {
2252     const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID
2253                                           ? &arg_shapes[i].shape
2254                                           : &arg_shapes[i].handle_shape;
2255     if (!shape->IsFullyDefined()) {
2256       Node* src;
2257       int src_output;
2258       if (params_info.IsPerReplicaArg(i)) {
2259         TF_RET_CHECK(i < replicate_input_edges.size());
2260         // All replicas must have the same input shapes. Uses the shape of the
2261         // inputs from the first replica.
2262         src = replicate_input_edges[i]->src();
2263         src_output = replicate_input_edges[i]->src_output();
2264       } else if (params_info.IsDistributedArg(i) ||
2265                  params_info.IsBroadcastArg(i)) {
2266         int64 input_num =
2267             params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i -
2268             params_info.NumPerReplicaArgs();
2269         TF_RET_CHECK(0 <= input_num &&
2270                      input_num < replicate_input_edges.size());
2271         src = replicate_input_edges[input_num]->src();
2272         src_output = replicate_input_edges[input_num]->src_output();
2273       } else {
2274         int64 var_num = i - params_info.NumPerReplicaArgs() -
2275                         params_info.NumDistributedArgs() -
2276                         params_info.NumBroadcastArgs();
2277         TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size());
2278         src = variable_reads[var_num];
2279         src_output = 0;
2280       }
2281 
2282       NodeDef def;
2283       def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape")));
2284       def.set_op("Shape");
2285       def.set_device(src->assigned_device_name());
2286       AddNodeAttr("T", src->output_type(src_output), &def);
2287       AddNodeAttr("out_type", DT_INT64, &def);
2288       MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
2289 
2290       Status status;
2291       Node* shape_node = graph->AddNode(def, &status);
2292       if (!status.ok()) return status;
2293       dynamic_shape_nodes->push_back(shape_node);
2294 
2295       shape_node->set_assigned_device_name(src->assigned_device_name());
2296       graph->AddEdge(src, src_output, shape_node, 0);
2297     }
2298   }
2299   return Status::OK();
2300 }
2301 
2302 // Builds a TPUCompile node that compiles the bodies of the function call
2303 // `nodes`.
BuildCompileNode(const Node * replicate_node,const NameAttrList & function,uint64 library_fingerprint,const ParameterInfo & params_info,const std::vector<InferredShape> & arg_shapes,const DataTypeVector & arg_types,const std::vector<Node * > & guaranteed_constant_nodes,const string & session_handle,const std::vector<xla::OpSharding> & arg_sharding,const std::vector<bool> & arg_fast_mem,const std::vector<std::string> & arg_names,const std::vector<xla::OpSharding> & retval_sharding,int num_cores_per_replica,const string & compile_device,const xla::DeviceAssignment * xla_device_assignment,const std::vector<Node * > & dynamic_shape_nodes,Graph * graph,Node ** compile_node,int64 autotuner_thresh)2304 Status DistributedTPURewritePass::BuildCompileNode(
2305     const Node* replicate_node, const NameAttrList& function,
2306     uint64 library_fingerprint, const ParameterInfo& params_info,
2307     const std::vector<InferredShape>& arg_shapes,
2308     const DataTypeVector& arg_types,
2309     const std::vector<Node*>& guaranteed_constant_nodes,
2310     const string& session_handle,
2311     const std::vector<xla::OpSharding>& arg_sharding,
2312     const std::vector<bool>& arg_fast_mem,
2313     const std::vector<std::string>& arg_names,
2314     const std::vector<xla::OpSharding>& retval_sharding,
2315     int num_cores_per_replica, const string& compile_device,
2316     const xla::DeviceAssignment* xla_device_assignment,
2317     const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
2318     Node** compile_node, int64 autotuner_thresh) {
2319   VLOG(1) << "BuildCompileNode";
2320 
2321   tpu::TPUCompileMetadataProto proto;
2322   proto.set_num_replicas(params_info.NumReplicas());
2323   proto.set_num_cores_per_replica(num_cores_per_replica);
2324   proto.set_function_library_fingerprint(library_fingerprint);
2325   proto.set_enable_automatic_model_parallelism(
2326       enable_cross_replica_sharding_mirrored_variables_);
2327   const bool use_spmd =
2328       UseSpmdForXlaPartitioning(replicate_node) && allow_xla_spmd_partition_ &&
2329       !absl::c_any_of(arg_sharding,
2330                       [](const xla::OpSharding& s) {
2331                         return s.type() == xla::OpSharding::MAXIMAL;
2332                       }) &&
2333       !absl::c_any_of(retval_sharding, [](const xla::OpSharding& s) {
2334         return s.type() == xla::OpSharding::MAXIMAL;
2335       });
2336   proto.set_use_spmd_for_xla_partitioning(use_spmd);
2337   proto.set_broadcast_replicated_parameters_via_collectives(
2338       enable_xla_param_broadcast_);
2339 
2340   // Get and fill padding map.
2341   if (replicate_node != nullptr) {
2342     TF_RETURN_IF_ERROR(
2343         FillPaddingMap(*replicate_node, proto.mutable_padding_maps()));
2344     xla::DebugOptions::StepMarkerLocation location;
2345     TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location));
2346     proto.set_step_marker_location(location);
2347   }
2348 
2349   if (xla_device_assignment != nullptr) {
2350     TF_RETURN_IF_ERROR(
2351         xla_device_assignment->Serialize(proto.mutable_device_assignment()));
2352   }
2353 
2354   const int num_args = arg_types.size();
2355   const int num_guaranteed_constants = guaranteed_constant_nodes.size();
2356   const int guaranteed_const_start_index = num_args - num_guaranteed_constants;
2357   TF_RET_CHECK(num_args == arg_shapes.size());
2358   TF_RET_CHECK(num_args == arg_sharding.size())
2359       << num_args << " != " << arg_sharding.size();
2360 
2361   for (int i = 0; i < num_args; ++i) {
2362     tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args();
2363     DataType type = arg_types[i];
2364     const InferredShape& arg_shape = arg_shapes[i];
2365     arg->set_name(arg_names[i]);
2366     if (type == DT_RESOURCE) {
2367       TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i;
2368       arg->set_dtype(arg_shape.handle_type);
2369       arg_shape.handle_shape.AsProto(arg->mutable_shape());
2370       arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE);
2371       arg->set_fast_mem(arg_fast_mem[i]);
2372     } else {
2373       arg->set_dtype(type);
2374       arg_shape.shape.AsProto(arg->mutable_shape());
2375       if (i >= guaranteed_const_start_index) {
2376         const DataType edge_type =
2377             guaranteed_constant_nodes[i - guaranteed_const_start_index]
2378                 ->output_type(0);
2379         TF_RET_CHECK(type == edge_type)
2380             << "Arg type: " << type << " but edge type: " << edge_type;
2381         arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT);
2382       } else {
2383         arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
2384       }
2385     }
2386     // As long as the argument is not a per-replica one, it should have the same
2387     // value for all replicas. For clarity, we keep the (redundant) checks for
2388     // variable, broadcast and constant types, to prevent bugs in case new types
2389     // with different semantics are introduced in the future.
2390     arg->set_is_same_data_across_replicas(
2391         !params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) &&
2392         (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2393          params_info.IsConstantArg(i)));
2394     if (params_info.mirrored_variable_indices().count(i) > 0) {
2395       CHECK_EQ(type, DT_RESOURCE);
2396       arg->set_is_same_data_across_replicas(true);
2397       // 64-bit type is not shardable by XLA:TPU yet.
2398       bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 &&
2399                                arg_shape.handle_type != DT_INT64 &&
2400                                arg_shape.handle_type != DT_UINT64 &&
2401                                arg_shape.handle_type != DT_DOUBLE);
2402       arg->set_enable_xla_sharding(
2403           sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE
2404                            : tpu::TPUCompileMetadataProto::Arg::DISALLOWED);
2405     }
2406     *arg->mutable_sharding() = arg_sharding[i];
2407   }
2408 
2409   const int num_retvals = retval_sharding.size();
2410   for (int i = 0; i < num_retvals; ++i) {
2411     *proto.add_retvals()->mutable_sharding() = retval_sharding[i];
2412   }
2413   proto.set_session_handle(session_handle);
2414 
2415   DataTypeVector constant_arg_types;
2416   constant_arg_types.reserve(num_guaranteed_constants);
2417   for (int i = 0; i < num_guaranteed_constants; ++i) {
2418     constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]);
2419   }
2420   proto.set_xla_fusion_autotuner_thresh(autotuner_thresh);
2421 
2422   string metadata;
2423   proto.SerializeToString(&metadata);
2424 
2425   NodeDef def;
2426   def.set_name(UniqueNodeName("TPUReplicate/_compile", graph));
2427   def.set_op("TPUCompile");
2428   def.set_device(compile_device);
2429   if (replicate_node) {
2430     MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
2431   }
2432 
2433   AddNodeAttr("function", function, &def);
2434   AddNodeAttr("num_computations", num_cores_per_replica, &def);
2435   AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()),
2436               &def);
2437   AddNodeAttr("metadata", metadata, &def);
2438   AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def);
2439 
2440   Status status;
2441   *compile_node = graph->AddNode(def, &status);
2442   TF_RETURN_IF_ERROR(status);
2443 
2444   (*compile_node)->set_assigned_device_name(compile_device);
2445 
2446   for (int i = 0; i < dynamic_shape_nodes.size(); ++i) {
2447     graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i);
2448   }
2449 
2450   for (int i = 0; i < num_guaranteed_constants; ++i) {
2451     graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node,
2452                    dynamic_shape_nodes.size() + i);
2453   }
2454   VLOG(1) << "BuildCompileNode(): " << status;
2455   return status;
2456 }
2457 
FindGuaranteedConstantInputs(const Node & node,const NameRangeMap & input_range_map,std::vector<Node * > * guaranteed_constants)2458 Status DistributedTPURewritePass::FindGuaranteedConstantInputs(
2459     const Node& node, const NameRangeMap& input_range_map,
2460     std::vector<Node*>* guaranteed_constants) {
2461   std::vector<const Edge*> input_edges;
2462   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2463   std::pair<int, int> variables_limits =
2464       input_range_map.at("guaranteed_constants");
2465   for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2466     guaranteed_constants->push_back(input_edges[i]->src());
2467   }
2468   return Status::OK();
2469 }
2470 
FindVariableInputs(const Node & node,const NameRangeMap & input_range_map,std::vector<VariableInput> * variables)2471 Status DistributedTPURewritePass::FindVariableInputs(
2472     const Node& node, const NameRangeMap& input_range_map,
2473     std::vector<VariableInput>* variables) {
2474   std::vector<const Edge*> input_edges;
2475   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2476   std::pair<int, int> variables_limits = input_range_map.at("variables");
2477   for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2478     Node* node = input_edges[i]->src();
2479 
2480     // Find the type of the VarHandleOp that feeds this node, looking through
2481     // any wrapping Enter or Switch nodes.
2482     while (node->IsEnter() || node->IsSwitch()) {
2483       TF_RETURN_IF_ERROR(node->input_node(0, &node));
2484     }
2485     // Fix the variable device assignment if it is requested with a full name.
2486     if (!node->has_assigned_device_name() &&
2487         !node->requested_device().empty()) {
2488       DeviceNameUtils::ParsedName var_device;
2489       TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(),
2490                                                   &var_device));
2491       if (var_device.has_job && var_device.has_replica && var_device.has_task &&
2492           var_device.has_type && var_device.has_id) {
2493         node->set_assigned_device_name(node->requested_device());
2494         if (node != input_edges[i]->src() &&
2495             !input_edges[i]->src()->has_assigned_device_name()) {
2496           input_edges[i]->src()->set_assigned_device_name(
2497               node->requested_device());
2498         }
2499       }
2500     }
2501     if (node->type_string() == kVarHandleOp) {
2502       DataType dtype;
2503       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype));
2504       variables->push_back(VariableInput{input_edges[i]->src(),
2505                                          input_edges[i]->src_output(), dtype});
2506     } else if (node->type_string() == "_Arg") {
2507       std::vector<DataType> dtypes;
2508       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes));
2509       if (dtypes.empty()) {
2510         return errors::Internal(
2511             "_Arg node with resource output must have non-empty _handle_dtypes "
2512             "attribute: ",
2513             node->DebugString());
2514       }
2515       variables->push_back(VariableInput{
2516           input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]});
2517     } else {
2518       return errors::Internal(
2519           "Cannot handle variable input with node type other than VarHandleOp "
2520           "and _Arg: ",
2521           node->DebugString());
2522     }
2523   }
2524   return Status::OK();
2525 }
2526 
2527 // Builds a NoOp node, used for building control dependencies.
BuildNoopNode(const Node & source,StringPiece name,const string & device,Graph * graph,Node ** node)2528 static Status BuildNoopNode(const Node& source, StringPiece name,
2529                             const string& device, Graph* graph, Node** node) {
2530   NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
2531   if (!device.empty()) {
2532     builder.Device(device);
2533   }
2534   NodeDef def;
2535   TF_RETURN_IF_ERROR(builder.Finalize(&def));
2536 
2537   Status status;
2538   *node = graph->AddNode(def, &status);
2539   if (!device.empty()) {
2540     (*node)->set_assigned_device_name(device);
2541   }
2542   return status;
2543 }
2544 
ConnectHostComputeNodes(Node * compile_node,Node * key_placeholder_node,Graph * graph)2545 Status DistributedTPURewritePass::ConnectHostComputeNodes(
2546     Node* compile_node, Node* key_placeholder_node, Graph* graph) {
2547   // First find all the downstream nodes of the key placeholder node, since we
2548   // want to delete the connecting edges from key_placeholder_node which would
2549   // invalidate the out_nodes iterator.
2550   std::vector<Node*> host_transfer_nodes;
2551   for (Node* node : key_placeholder_node->out_nodes()) {
2552     host_transfer_nodes.push_back(node);
2553   }
2554   for (Node* node : host_transfer_nodes) {
2555     int input_index = -1;
2556     for (int i = 0; i < node->num_inputs(); i++) {
2557       const Edge* e;
2558       TF_RETURN_IF_ERROR(node->input_edge(i, &e));
2559       if (e->src() == key_placeholder_node) {
2560         if (input_index != -1) {
2561           return errors::Internal(
2562               "Node ", node->name(),
2563               " has multiple input edges from key placeholder node");
2564         }
2565         input_index = e->dst_input();
2566       }
2567     }
2568     if (input_index == -1) {
2569       return errors::Internal("Node ", node->name(),
2570                               " has no input edge from key placeholder node");
2571     }
2572     const Edge* key_edge;
2573     TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge));
2574     graph->RemoveEdge(key_edge);
2575     graph->AddEdge(compile_node, 1, node, input_index);
2576   }
2577   graph->RemoveNode(key_placeholder_node);
2578   return Status::OK();
2579 }
2580 
BuildVariableReads(absl::Span<const VariableInput> variables,Node * control_predecessor,Graph * graph,std::vector<Node * > * variable_reads)2581 Status DistributedTPURewritePass::BuildVariableReads(
2582     absl::Span<const VariableInput> variables, Node* control_predecessor,
2583     Graph* graph, std::vector<Node*>* variable_reads) {
2584   variable_reads->resize(variables.size());
2585   for (int i = 0; i < variables.size(); ++i) {
2586     string name =
2587         graph->NewName(strings::StrCat(variables[i].node->name(), "/read"));
2588     NodeDefBuilder builder(name, "ReadVariableOp",
2589                            NodeDebugInfo(*variables[i].node));
2590 
2591     builder.Attr("dtype", variables[i].dtype);
2592     builder.Device(variables[i].node->assigned_device_name());
2593     builder.Input(variables[i].node->name(), 0, DT_RESOURCE);
2594     NodeDef def;
2595     TF_RETURN_IF_ERROR(builder.Finalize(&def));
2596 
2597     Status status;
2598     Node* read_node;
2599     (*variable_reads)[i] = read_node = graph->AddNode(def, &status);
2600     if (!status.ok()) return status;
2601 
2602     read_node->set_requested_device(variables[i].node->requested_device());
2603     read_node->set_assigned_device_name(
2604         variables[i].node->assigned_device_name());
2605     graph->AddEdge(variables[i].node, variables[i].index, read_node, 0);
2606 
2607     graph->AddControlEdge(control_predecessor, read_node);
2608   }
2609   return Status::OK();
2610 }
2611 
ContainsResourceWriteOp(const Graph & graph,const FunctionLibraryDefinition & fld)2612 bool DistributedTPURewritePass::ContainsResourceWriteOp(
2613     const Graph& graph, const FunctionLibraryDefinition& fld) {
2614   for (const Node* n : graph.nodes()) {
2615     const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
2616     if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2617       VLOG(2) << "Found write resource op inside computation";
2618       return true;
2619     }
2620   }
2621   for (const string& func_name : fld.ListFunctionNames()) {
2622     const FunctionDef* func_def = fld.Find(func_name);
2623     for (const NodeDef& n : func_def->node_def()) {
2624       const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op());
2625       if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2626         VLOG(2) << "Found write resource op inside " << func_name;
2627         return true;
2628       }
2629     }
2630   }
2631   return false;
2632 }
2633 
BuildVariableWrites(absl::Span<const VariableInput> variables,Node * control_successor,absl::Span<const VariableWrite> variable_writes,Graph * graph)2634 Status DistributedTPURewritePass::BuildVariableWrites(
2635     absl::Span<const VariableInput> variables, Node* control_successor,
2636     absl::Span<const VariableWrite> variable_writes, Graph* graph) {
2637   CHECK_EQ(variables.size(), variable_writes.size());
2638   for (int i = 0; i < variables.size(); ++i) {
2639     const VariableWrite& write = variable_writes[i];
2640     NodeDebugInfo debug_info(*variables[i].node);
2641 
2642     auto name = [&](string suffix) {
2643       return graph->NewName(
2644           strings::StrCat(variables[i].node->name(), "/", suffix));
2645     };
2646 
2647     Node* write_node;
2648     TF_RETURN_IF_ERROR(
2649         IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info)
2650             .AddAttr("dtype", variables[i].dtype)
2651             .Device(variables[i].node->assigned_device_name())
2652             .Build(graph, &write_node));
2653 
2654     // Colocate the control flow with the variable.
2655     CondBuilder cb(variables[i].node->name(),
2656                    variables[i].node->assigned_device_name(), debug_info,
2657                    graph);
2658 
2659     // Inputs to conditional.
2660     Node* switch_val;
2661     TF_RETURN_IF_ERROR(
2662         cb.AddInput("switch_val", variables[i].dtype,
2663                     /*device=*/write.value->assigned_device_name(), debug_info,
2664                     &switch_val));
2665     Node* switch_var;
2666     TF_RETURN_IF_ERROR(
2667         cb.AddInput("switch_var", DT_RESOURCE,
2668                     /*device=*/variables[i].node->assigned_device_name(),
2669                     debug_info, &switch_var));
2670     // Conditionally write the value back.
2671     graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0);
2672     graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0);
2673     graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1);
2674     // Add control edge from the write to value that will be merged. There is no
2675     // output from the write so this control edge ensures the write completes.
2676     graph->AddControlEdge(write_node, cb.switch_t());
2677 
2678     graph->AddControlEdge(cb.control_successor(), control_successor);
2679 
2680     graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0);
2681     graph->AddEdge(write.value, write.value_output, switch_val, 0);
2682   }
2683   return Status::OK();
2684 }
2685 
2686 namespace {
2687 
2688 // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
CreatePerHostDummyArgs(const InferredShape & raw_var_shape,const string & host_cpu_device,Node * var_read,absl::string_view name_prefix,Graph * graph)2689 xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
2690                                             const string& host_cpu_device,
2691                                             Node* var_read,
2692                                             absl::string_view name_prefix,
2693                                             Graph* graph) {
2694   Status status;
2695   DataType dtype;
2696   TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
2697 
2698   if (!(dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
2699         dtype == DT_BOOL)) {
2700     return var_read;
2701   }
2702 
2703   TensorShape var_shape;
2704   if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) &&
2705       !raw_var_shape.shape.AsTensorShape(&var_shape)) {
2706     return Status(error::FAILED_PRECONDITION, "Failed to read arg shape.");
2707   }
2708 
2709   // Const - shape_as_tensor
2710   NodeDef shape_tensor_def;
2711   shape_tensor_def.set_op("Const");
2712   shape_tensor_def.set_name(graph->NewName(
2713       strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
2714   AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
2715   TensorProto tensorshape_proto;
2716   tensorshape_proto.set_dtype(DT_INT32);
2717   for (int i = 0; i < var_shape.dims(); ++i) {
2718     tensorshape_proto.add_int_val(var_shape.dim_size(i));
2719   }
2720   TensorShape shape_shape({var_shape.dims()});
2721   shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
2722   AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
2723   Node* shape_as_tensor_node = graph->AddNode(shape_tensor_def, &status);
2724   TF_RETURN_IF_ERROR(status);
2725 
2726   // Const - initializer value
2727   NodeDef init_val_def;
2728   init_val_def.set_op("Const");
2729   init_val_def.set_name(graph->NewName(
2730       strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
2731   TensorProto tensor_proto;
2732   tensor_proto.set_dtype(dtype);
2733   if (dtype == DT_FLOAT) {
2734     tensor_proto.add_float_val(0.0f);
2735   } else if (dtype == DT_BFLOAT16) {
2736     tensor_proto.add_half_val(0);
2737   } else if (dtype == DT_INT32) {
2738     tensor_proto.add_int_val(0);
2739   } else if (dtype == DT_BOOL) {
2740     tensor_proto.add_bool_val(false);
2741   } else {
2742     return errors::Internal(
2743         "Unable to create zero-init dummy arg tensor for type ", dtype);
2744   }
2745   TensorShape scalar_shape({});
2746   scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
2747   AddNodeAttr("value", tensor_proto, &init_val_def);
2748   AddNodeAttr("dtype", dtype, &init_val_def);
2749   Node* init_val_node = graph->AddNode(init_val_def, &status);
2750   TF_RETURN_IF_ERROR(status);
2751 
2752   // Fill node
2753   NodeDef fill_def;
2754   fill_def.set_op("Fill");
2755   fill_def.set_device(host_cpu_device);
2756   fill_def.set_name(
2757       graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
2758   AddNodeAttr("T", dtype, &fill_def);
2759   AddNodeAttr("index_type", DT_INT32, &fill_def);
2760   Node* fill_node = graph->AddNode(fill_def, &status);
2761   TF_RETURN_IF_ERROR(status);
2762   graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
2763   graph->AddEdge(init_val_node, 0, fill_node, 1);
2764 
2765   return fill_node;
2766 }
2767 
2768 // Helper that creates an IdentityN node containing all of the variables
2769 // values on CPU device 'device', except for those that will be split across
2770 // cores. (For split variables, this may cause additional cross-host data
2771 // transfers if more than 1 devices share the same variable partition on a
2772 // remote host.)
2773 //
2774 // A previous iteration of this code built one Identity node per TPU core per
2775 // variable, but this can rapidly become hundreds of thousands of nodes. This
2776 // formulation creates a single IdentityN node containing all of the variables
2777 // on each host. This may cause some unnecessary variable copies if only a
2778 // subset of hosts consume a given variable, but has the virtue of being
2779 // simple, and most models use pure replication where all cores want all the
2780 // variables.
2781 //
2782 // If enable_xla_param_broadcast is set to true, then per-host dummy
2783 // tensor args are created on all hosts except for the primary host. In this
2784 // scheme, the dummy args feed the IdentityN node on their local host. All
2785 // are zero-initialized.
2786 //
2787 // Returns the node and its output index to be consumed by TPUExecute for the
2788 // requested variable index.
CreateOrGetPerHostVariableCopy(const string & host_cpu_device,int64 var_index,const std::vector<Node * > & variable_reads,const DistributedTPURewritePass::ParameterInfo & params_info,const std::vector<xla::OpSharding> & arg_shardings,const Node & replicate_node,const bool enable_xla_param_broadcast,const int num_cores_per_replica,const std::vector<InferredShape> & arg_shapes,absl::flat_hash_map<string,std::vector<NodeOut>> * per_host_var_copies,Graph * graph)2789 xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
2790     const string& host_cpu_device, int64 var_index,
2791     const std::vector<Node*>& variable_reads,
2792     const DistributedTPURewritePass::ParameterInfo& params_info,
2793     const std::vector<xla::OpSharding>& arg_shardings,
2794     const Node& replicate_node, const bool enable_xla_param_broadcast,
2795     const int num_cores_per_replica,
2796     const std::vector<InferredShape>& arg_shapes,
2797     absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
2798     Graph* graph) {
2799   auto it = per_host_var_copies->find(host_cpu_device);
2800   if (it != per_host_var_copies->end()) {
2801     return it->second[var_index];
2802   }
2803 
2804   // Variable replication relies on identification of a master.
2805   DeviceNameUtils::ParsedName parsed_device;
2806   TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
2807   TF_RET_CHECK(parsed_device.has_task);
2808   VLOG(1) << "Creating per-host IdentityN node for task " << parsed_device.task;
2809 
2810   DataTypeVector dtypes;
2811   // Per-variable data source for TPUExecute.
2812   std::vector<NodeOut> index_mapping;
2813   index_mapping.reserve(variable_reads.size());
2814   dtypes.reserve(variable_reads.size());
2815   for (int64 i = 0; i < variable_reads.size(); ++i) {
2816     Node* read = variable_reads[i];
2817     int64 orig_arg_num =
2818         i + params_info.NumPerReplicaArgs() + params_info.NumBroadcastArgs();
2819     if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
2820       // We haven't built the IdentityN node yet, so temporarily use nullptr.
2821       index_mapping.push_back(
2822           NodeOut{nullptr, static_cast<int>(dtypes.size())});
2823       dtypes.push_back(read->output_type(0));
2824     } else {
2825       // Do not copy the full tensor of partitioned variables.
2826       index_mapping.push_back(NodeOut{read, 0});
2827     }
2828   }
2829   NodeDef ndef;
2830   ndef.set_name(
2831       graph->NewName(absl::StrCat(replicate_node.name(), "/_variable_copy")));
2832   ndef.set_op("IdentityN");
2833   ndef.set_device(host_cpu_device);
2834   AddNodeAttr("T", dtypes, &ndef);
2835   // TF meta-optimizer should skip this node for constant folding.
2836   AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
2837   Status s;
2838   Node* id_node = graph->AddNode(ndef, &s);
2839   TF_RETURN_IF_ERROR(s);
2840   id_node->set_assigned_device_name(host_cpu_device);
2841 
2842   for (int64 i = 0; i < variable_reads.size(); ++i) {
2843     if (index_mapping[i].node == nullptr) {
2844       // Fill index_mapping with the actual IdentityN node.
2845       index_mapping[i].node = id_node;
2846       if (parsed_device.task == 0 || !enable_xla_param_broadcast) {
2847         // XLA broadcast mode is not enabled, so use the variable reads as args
2848         // to TPUExecuteOp. For task 0, variable reads are always used
2849         // regardless of XLA broadcast.
2850 
2851         // Add the variable read edge to id_node.
2852         graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
2853       } else {
2854         // XLA broadcast mode is enabled. Create zero-valued dummy tensors to
2855         // use as variable args in the TPUExecuteOp.
2856         int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
2857                              params_info.NumBroadcastArgs();
2858         if (num_cores_per_replica > 1) {
2859           LOG(WARNING) << "XLA parameter broadcast is only supported for "
2860                           "replicated parameters. Falling back to "
2861                           "non-broadcast mode for the parameter associated "
2862                           "with the following variable read: "
2863                        << variable_reads[i]->name();
2864           graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
2865           continue;
2866         }
2867         string dummy_name =
2868             strings::StrCat(variable_reads[i]->name(),
2869                             absl::StrFormat("/dummy_%d", parsed_device.task));
2870         TF_ASSIGN_OR_RETURN(
2871             Node * var_read,
2872             CreatePerHostDummyArgs(arg_shapes[orig_arg_num], host_cpu_device,
2873                                    variable_reads[i], dummy_name, graph));
2874         graph->AddEdge(var_read, 0, id_node, index_mapping[i].index);
2875       }
2876     }
2877   }
2878 
2879   auto result = index_mapping[var_index];
2880   (*per_host_var_copies)[host_cpu_device] = std::move(index_mapping);
2881   return result;
2882 }
2883 
2884 }  // namespace
2885 
BuildExecuteNodes(const ParameterInfo & params_info,int num_tasks,int num_cores_per_replica,const Node & replicate_node,const std::vector<std::string> & arg_names,const DataTypeVector & arg_types,const std::vector<InferredShape> & arg_shapes,const DataTypeVector & retval_types,const std::vector<xla::OpSharding> & arg_shardings,const std::vector<xla::OpSharding> & retval_shardings,const std::vector<std::vector<string>> & tpu_device_names,Node * compile_node,const std::vector<Node * > & variable_reads,Node * control_predecessor,Node * control_successor,std::vector<VariableWrite> * variable_writes,Graph * graph)2886 Status DistributedTPURewritePass::BuildExecuteNodes(
2887     const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
2888     const Node& replicate_node, const std::vector<std::string>& arg_names,
2889     const DataTypeVector& arg_types,
2890     const std::vector<InferredShape>& arg_shapes,
2891     const DataTypeVector& retval_types,
2892     const std::vector<xla::OpSharding>& arg_shardings,
2893     const std::vector<xla::OpSharding>& retval_shardings,
2894     const std::vector<std::vector<string>>& tpu_device_names,
2895     Node* compile_node, const std::vector<Node*>& variable_reads,
2896     Node* control_predecessor, Node* control_successor,
2897     std::vector<VariableWrite>* variable_writes, Graph* graph) {
2898   VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString();
2899   TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size());
2900 
2901   const int num_variables = variable_reads.size();
2902   const int num_retvals_per_replica = retval_types.size();
2903 
2904   variable_writes->resize(num_variables);
2905 
2906   std::vector<const Edge*> replicate_input_edges;
2907   TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
2908 
2909   // Map from replicate input index to the fan_in node;
2910   absl::flat_hash_map<int, std::vector<Node*>> replicate_input_fan_in_nodes;
2911   absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes;
2912   absl::flat_hash_map<int, std::vector<int>>
2913       replicate_output_fan_out_dst_inputs;
2914   std::vector<Node*> to_be_removed_nodes;
2915 
2916   for (const Edge* e : replicate_input_edges) {
2917     if (e->src()->type_string() == kTPUPartitionedInput) {
2918       int num_users = 0;
2919       for (const auto& ue : e->src()->out_edges()) {
2920         if (!ue->IsControlEdge()) ++num_users;
2921       }
2922       if (num_users != 1) {
2923         return tensorflow::errors::InvalidArgument(
2924             e->src()->name(), " must only have one user. Found ", num_users);
2925       }
2926       to_be_removed_nodes.push_back(e->src());
2927       std::vector<Node*>& nodes = replicate_input_fan_in_nodes[e->dst_input()];
2928       nodes.resize(num_cores_per_replica, nullptr);
2929       VLOG(2) << "allocate " << num_cores_per_replica
2930               << " for replicate_input_fan_in_nodes[" << e->dst_input() << "]";
2931       std::vector<const Edge*> fan_in_edges;
2932       TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges));
2933       TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica);
2934 
2935       for (const Edge* fe : fan_in_edges) {
2936         nodes[fe->dst_input()] = fe->src();
2937         VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "]["
2938                 << fe->dst_input() << "] = " << fe->src()->name();
2939       }
2940     }
2941   }
2942 
2943   // Replicate output edges are sorted by replica id and then by outputs for
2944   // each replica. For example, if TPU Computation has outputs (output_1,
2945   // output_2, and output_3) and number of replicas is 2, then
2946   // replicate_output_edges order would be:
2947   // output_1_replica_1, output_2_replica_1, output_3_replica_1,
2948   // output_1_replica_2, output_2_replica_2, output_3_replica_2.
2949   std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
2950                                                   nullptr);
2951   for (const Edge* edge : replicate_node.out_edges()) {
2952     if (edge->IsControlEdge()) continue;
2953 
2954     int num_partitioned_outputs = 0;
2955 
2956     for (const Edge* out_edge : edge->dst()->out_edges()) {
2957       if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
2958         num_partitioned_outputs++;
2959         // Paths between replicate_node and replicate_output_fan_out_nodes:
2960         // ReplicateNode->TpuOutIdenity->kTPUPartitionedOutput->fan-out-nodes
2961         TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
2962         to_be_removed_nodes.push_back(edge->dst());
2963         to_be_removed_nodes.push_back(out_edge->dst());
2964         // Get the right replicated id from the replicate_output_edge.
2965         std::vector<Node*>& nodes =
2966             replicate_output_fan_out_nodes[edge->src_output()];
2967         std::vector<int>& dst_inputs =
2968             replicate_output_fan_out_dst_inputs[edge->src_output()];
2969         nodes.resize(num_cores_per_replica, nullptr);
2970         dst_inputs.resize(num_cores_per_replica, 0);
2971         TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
2972                      num_cores_per_replica);
2973 
2974         for (const Edge* fe : out_edge->dst()->out_edges()) {
2975           nodes[fe->src_output()] = fe->dst();
2976           dst_inputs[fe->src_output()] = fe->dst_input();
2977           VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
2978                   << "][" << fe->src_output()
2979                   << "] = " << fe->dst()->DebugString() << " with dst_input "
2980                   << fe->dst_input();
2981         }
2982       }
2983     }
2984     replicate_output_edges[edge->src_output()] = edge;
2985     if (num_partitioned_outputs > 1) {
2986       return errors::InvalidArgument(
2987           "More than one TPUPartitionedOutput per replciated output.");
2988     }
2989   }
2990 
2991   const int num_execute_args =
2992       arg_shardings.size() - params_info.NumGuaranteedConstants();
2993   // Inverts the arg_shardings and retval_shardings mappings to
2994   // form core -> {argument number} maps.
2995   std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
2996   for (int i = 0; i < num_execute_args; ++i) {
2997     const auto& sharding = arg_shardings[i];
2998     if (sharding.type() == xla::OpSharding::MAXIMAL) {
2999       int core = sharding.tile_assignment_devices(0);
3000       TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3001       core_arg_nums[core].push_back(i);
3002     } else if (sharding.type() == xla::OpSharding::OTHER) {
3003       for (int64 core : sharding.tile_assignment_devices()) {
3004         core_arg_nums[core].push_back(i);
3005       }
3006     } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3007       for (int core = 0; core < num_cores_per_replica; ++core) {
3008         core_arg_nums[core].push_back(i);
3009       }
3010     } else {
3011       return tensorflow::errors::InvalidArgument(
3012           "Unsupported argument sharding for arg=", arg_names[i],
3013           " shape=", arg_shapes[i].shape.DebugString(), ": ",
3014           sharding.DebugString());
3015     }
3016   }
3017   std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
3018   for (int i = 0; i < retval_shardings.size(); ++i) {
3019     const auto& sharding = retval_shardings[i];
3020     if (sharding.type() == xla::OpSharding::MAXIMAL) {
3021       int core = sharding.tile_assignment_devices(0);
3022       TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3023       core_retval_nums[core].push_back(i);
3024     } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3025       for (int core = 0; core < num_cores_per_replica; ++core) {
3026         core_retval_nums[core].push_back(i);
3027       }
3028     } else if (sharding.type() == xla::OpSharding::OTHER) {
3029       for (int64 core : sharding.tile_assignment_devices()) {
3030         core_retval_nums[core].push_back(i);
3031       }
3032     } else {
3033       return tensorflow::errors::InvalidArgument(
3034           "Unsupported argument sharding: ", sharding.DebugString());
3035     }
3036   }
3037 
3038   // Maps host device name to a list of per-variable pairs (variable_copy_node,
3039   // output_index_of_copy_node).
3040   absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies;
3041 
3042   // Mapping from original resource arg number to a second level map. Second
3043   // level map is from core id to output index of updated variable value.
3044   absl::flat_hash_map<int, absl::flat_hash_map<int, int>>
3045       orig_arg_num_to_output_index_mapping;
3046   // Mapping from retval index to a second level map. Second level map is from
3047   // core id to output index of sharded output value.
3048   std::unordered_map<int, std::unordered_map<int, int>>
3049       retval_index_to_output_index_mapping;
3050 
3051   // Represents mapping of argument index of sharded input to each
3052   // TPUExecute node to its corresponding Split node and its output index
3053   // from which sharded input will be fed into TPUExecute node.
3054   std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs;
3055 
3056   // Builds one TPUExecute node per core per replica.
3057   std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas());
3058   for (int core = 0; core < num_cores_per_replica; ++core) {
3059     DataTypeVector core_retval_types;
3060     for (int output : core_retval_nums[core]) {
3061       core_retval_types.push_back(retval_types[output]);
3062     }
3063     DataTypeVector core_arg_types;
3064     std::vector<int> core_variable_writes;
3065     for (int input : core_arg_nums[core]) {
3066       // Resource variables can be passed either by reference (as a DT_RESOURCE)
3067       // tensor or by value (as the variable's current value). Per-replica or
3068       // distributed resource arguments are always passed by reference and
3069       // broadcast variables are always passed by value.
3070       if (arg_types[input] == DT_RESOURCE &&
3071           !params_info.IsPerReplicaArg(input) &&
3072           !params_info.IsDistributedArg(input)) {
3073         DataType handle_type = arg_shapes[input].handle_type;
3074         TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type);
3075         core_arg_types.push_back(handle_type);
3076         int base = input - params_info.NumPerReplicaArgs() -
3077                    params_info.NumDistributedArgs() -
3078                    params_info.NumBroadcastArgs();
3079         // Variables passed by value will have a corresponding additional output
3080         // containing an updated value for the variable.
3081         core_variable_writes.push_back(base);
3082         core_retval_types.push_back(handle_type);
3083       } else {
3084         core_arg_types.push_back(arg_types[input]);
3085       }
3086     }
3087 
3088     NodeDef def;
3089     def.set_op("TPUExecute");
3090     MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
3091     AddNodeAttr("Targs", core_arg_types, &def);
3092     AddNodeAttr("Tresults", core_retval_types, &def);
3093 
3094     for (int64 replica = 0; replica < params_info.NumReplicas(); ++replica) {
3095       def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica,
3096                                    "_", core));
3097 
3098       Status status;
3099       Node* node = graph->AddNode(def, &status);
3100       if (!status.ok()) return status;
3101       execute_nodes[replica].push_back(node);
3102 
3103       node->set_assigned_device_name(tpu_device_names[replica][core]);
3104 
3105       // Add control edges to ensure that execution happens after
3106       // `control_predecessor`, happens before `control_successor`, and is
3107       // triggered by evaluating any operator that depends on the original
3108       // TPUReplicate operator. See the comment at the top of the header file
3109       // for more details.
3110       graph->AddControlEdge(control_predecessor, node);
3111       graph->AddControlEdge(node, control_successor);
3112 
3113       // Add data input edges.
3114       for (int64 i = 0; i < core_arg_nums[core].size(); ++i) {
3115         int64 orig_arg_num = core_arg_nums[core][i];
3116         VLOG(2) << " replica " << replica << " core " << core << " i " << i
3117                 << " orig_arg_num " << orig_arg_num;
3118         if (params_info.IsPerReplicaArg(orig_arg_num) ||
3119             params_info.IsDistributedArg(orig_arg_num)) {
3120           // Per-replica input and distributed input
3121           int64 input_num = params_info.IsPerReplicaArg(orig_arg_num)
3122                                 ? replica * params_info.NumPerReplicaArgs() +
3123                                       core_arg_nums[core][i]
3124                                 : params_info.NumReplicas() *
3125                                           params_info.NumPerReplicaArgs() +
3126                                       core_arg_nums[core][i] -
3127                                       params_info.NumPerReplicaArgs();
3128 
3129           const Edge* edge = replicate_input_edges[input_num];
3130           VLOG(2) << "replicate_input_edges[" << input_num << "]";
3131           DataType dtype = edge->src()->output_type(edge->src_output());
3132           if (dtype == DT_RESOURCE) {
3133             DataType handle_dtype = arg_shapes[orig_arg_num].handle_type;
3134             if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(),
3135                           handle_dtype) == kTpuAllTypes.end()) {
3136               return errors::InvalidArgument(
3137                   "Unsupported resource variable data type for TPU: ",
3138                   DataTypeString(handle_dtype), ", caused by output ",
3139                   edge->src()->name(), ":", edge->src_output());
3140             }
3141           } else {
3142             if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3143                 kTpuAllTypes.end()) {
3144               return errors::InvalidArgument(
3145                   "Unsupported data type for TPU: ", DataTypeString(dtype),
3146                   ", caused by output ", edge->src()->name(), ":",
3147                   edge->src_output());
3148             }
3149           }
3150           if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3151             // Don't automatically add a split node when input node is
3152             // kTPUPartitionedInput
3153             if (edge->src()->type_string() == kTPUPartitionedInput) {
3154               VLOG(2) << "Connect "
3155                       << replicate_input_fan_in_nodes[input_num][core]->name()
3156                       << " to " << node->name() << " at " << i;
3157               graph->AddEdge(replicate_input_fan_in_nodes[input_num][core], 0,
3158                              node, i);
3159             } else {
3160               if (dtype == DT_RESOURCE) {
3161                 return errors::InvalidArgument(
3162                     "Tiled sharding for per-replica DT_RESOURCE input must",
3163                     "be TPUPartitionedInput. Here got ",
3164                     edge->src()->type_string());
3165               }
3166               const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
3167 
3168               // Create or get the Split node.
3169               TF_ASSIGN_OR_RETURN(
3170                   ShardedInputInfo sharded_input_info,
3171                   CreateOrGetSplitNodesForInputSharding(
3172                       sharding, orig_arg_num, dtype,
3173                       arg_shapes[orig_arg_num].handle_shape, replica,
3174                       edge->src_output(), edge->src(), control_predecessor,
3175                       graph, &input_index_to_sharded_inputs));
3176               NodeOut split_node_and_index =
3177                   sharded_input_info.sharded_inputs.at(core);
3178               // Connect with Split node output.
3179               graph->AddEdge(split_node_and_index.node,
3180                              split_node_and_index.index, node, i);
3181             }
3182           } else if (edge->src()->type_string() == kTPUPartitionedInput &&
3183                      arg_shardings[orig_arg_num].type() ==
3184                          xla::OpSharding::REPLICATED) {
3185             graph->AddEdge(replicate_input_fan_in_nodes[input_num][core], 0,
3186                            node, i);
3187           } else {
3188             graph->AddEdge(edge->src(), edge->src_output(), node, i);
3189           }
3190         } else if (params_info.IsBroadcastArg(orig_arg_num)) {
3191           // Broadcast input.
3192           int64 input_num = params_info.FirstBroadcastArgFromHost() +
3193                             core_arg_nums[core][i] -
3194                             params_info.NumPerReplicaArgs() -
3195                             params_info.NumDistributedArgs();
3196           const Edge* edge = replicate_input_edges[input_num];
3197           DataType dtype = edge->src()->output_type(edge->src_output());
3198           if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3199               kTpuAllTypes.end()) {
3200             return errors::InvalidArgument(
3201                 "Unsupported data type for TPU: ", DataTypeString(dtype),
3202                 ", caused by output ", edge->src()->name(), ":",
3203                 edge->src_output());
3204           }
3205           graph->AddEdge(edge->src(), edge->src_output(), node, i);
3206         } else {
3207           // Variable input.
3208           int64 variable_num = orig_arg_num - params_info.NumPerReplicaArgs() -
3209                                params_info.NumDistributedArgs() -
3210                                params_info.NumBroadcastArgs();
3211           TF_RET_CHECK(variable_num < num_variables);
3212 
3213           Node* variable_read = variable_reads[variable_num];
3214           DataType dtype = variable_read->output_type(0);
3215           if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3216               kTpuAllTypes.end()) {
3217             return errors::InvalidArgument(
3218                 "Unsupported resource variable data type for TPU: ",
3219                 DataTypeString(dtype), ", caused by ReadVariableOp ",
3220                 variable_read->DebugString());
3221           }
3222           DeviceNameUtils::ParsedName requested_device;
3223           string requested = variable_read->requested_device();
3224           TF_RET_CHECK(
3225               DeviceNameUtils::ParseFullName(requested, &requested_device));
3226           if (requested_device.type != "TPU") {
3227             // Stage the value via the CPU device on the remote host. The graph
3228             // partitioner will introduce an intermediate copy rather than
3229             // copying the same tensor multiple times across the network, and we
3230             // would prefer that intermediate copy to be in host memory to avoid
3231             // running out of memory if the TPUExecute op on the staging device
3232             // starts running before the _Send ops to the other TPU devices on
3233             // the same host complete. We don't do this if the variables are
3234             // already placed on TPU, otherwise it will cause an unnecessary
3235             // round trip copy.
3236             // TODO(b/79580121): give each replica its own on-device variable
3237             // replica and then delete this code.
3238             string device;
3239             TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3240                 tpu_device_names[replica][core], &device));
3241             TF_ASSIGN_OR_RETURN(
3242                 auto var_data,
3243                 CreateOrGetPerHostVariableCopy(
3244                     device, variable_num, variable_reads, params_info,
3245                     arg_shardings, replicate_node, enable_xla_param_broadcast_,
3246                     num_cores_per_replica, arg_shapes, &per_host_var_copies,
3247                     graph));
3248 
3249             if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3250               const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
3251               // Create or get the Split node.
3252               TF_ASSIGN_OR_RETURN(
3253                   ShardedInputInfo sharded_input_info,
3254                   CreateOrGetSplitNodesForInputSharding(
3255                       sharding, orig_arg_num,
3256                       arg_shapes[orig_arg_num].handle_type,
3257                       arg_shapes[orig_arg_num].handle_shape, replica,
3258                       var_data.index, var_data.node, control_predecessor, graph,
3259                       &input_index_to_sharded_inputs));
3260               NodeOut split_node_and_index =
3261                   sharded_input_info.sharded_inputs[core];
3262               // Connect with Split node output.
3263               graph->AddEdge(split_node_and_index.node,
3264                              split_node_and_index.index, node, i);
3265 
3266             } else {
3267               graph->AddEdge(var_data.node, var_data.index, node, i);
3268             }
3269           } else {
3270             graph->AddEdge(variable_reads[variable_num], 0, node, i);
3271           }
3272         }
3273       }
3274 
3275       // Adds a program input edge from the compiler.
3276       graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1);
3277 
3278       // Add data output edges.
3279       int num_outputs = core_retval_nums[core].size();
3280       for (int i = 0; i < num_outputs; ++i) {
3281         int output_num =
3282             replica * num_retvals_per_replica + core_retval_nums[core][i];
3283         const auto& sharding = retval_shardings[core_retval_nums[core][i]];
3284         if (sharding.type() == xla::OpSharding::OTHER) {
3285           int retval_index = core_retval_nums[core][i];
3286           retval_index_to_output_index_mapping[retval_index][core] = i;
3287           bool is_last_core =
3288               core ==
3289               *std::max_element(sharding.tile_assignment_devices().begin(),
3290                                 sharding.tile_assignment_devices().end());
3291           bool isPartitionOutNode = false;
3292 
3293           const Edge* e = replicate_output_edges[output_num];
3294           const Edge* e_out;
3295           for (const Edge* out_edge : e->dst()->out_edges()) {
3296             if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3297               isPartitionOutNode = true;
3298               e_out = out_edge;
3299             }
3300           }
3301           if (isPartitionOutNode) {
3302             graph->AddEdge(
3303                 node, i, replicate_output_fan_out_nodes[output_num][core],
3304                 replicate_output_fan_out_dst_inputs[output_num][core]);
3305             VLOG(2) << "Connect " << node->name() << " at " << i << " to "
3306                     << replicate_output_fan_out_nodes[output_num][core]->name()
3307                     << " at "
3308                     << replicate_output_fan_out_dst_inputs[output_num][core];
3309             if (is_last_core) {
3310               graph->RemoveEdge(e);
3311               graph->RemoveEdge(e_out);
3312             }
3313             continue;
3314           }
3315 
3316           // Do this in the iteration of last core in tile assignment, so all
3317           // TPUExecute nodes have been created.
3318           if (!is_last_core) {
3319             continue;
3320           }
3321 
3322           // Add a Concat node.
3323           std::vector<NodeOut> orig_inputs;
3324           for (int64 tile_index = 0;
3325                tile_index < sharding.tile_assignment_devices_size();
3326                ++tile_index) {
3327             int64 last_tile_dim_size =
3328                 *sharding.tile_assignment_dimensions().rbegin();
3329             if (sharding.replicate_on_last_tile_dim() &&
3330                 tile_index % last_tile_dim_size != 0) {
3331               continue;
3332             }
3333             int64 core_id = sharding.tile_assignment_devices(tile_index);
3334             int core_retval_index =
3335                 retval_index_to_output_index_mapping[retval_index][core_id];
3336             orig_inputs.push_back(
3337                 NodeOut{execute_nodes[replica][core_id],
3338                         static_cast<int>(
3339                             core_retval_nums[core_id][core_retval_index])});
3340           }
3341           DataType dtype = e->src()->output_type(e->src_output());
3342           TF_ASSIGN_OR_RETURN(
3343               Node * concat_node,
3344               CreateConcatNodesForRetval(
3345                   sharding, dtype, /*inferred_shape*/ PartialTensorShape(),
3346                   replica, orig_inputs, graph, /*device=*/""));
3347 
3348           const Edge* edge = replicate_output_edges[output_num];
3349           Node* dst = edge->dst();
3350           int dst_input = edge->dst_input();
3351           graph->RemoveEdge(edge);
3352           graph->AddEdge(concat_node, 0, dst, dst_input);
3353 
3354           continue;
3355         }
3356 
3357         // If this is a replicated output, outputs on all cores will be the
3358         // same, and we only take the output from core 0.
3359         if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3360           continue;
3361         }
3362 
3363         // If output has maximal sharding, make sure we only use output from
3364         // TPUExecute node with logical core id equal to core id defined by the
3365         // xla sharding.
3366         if (sharding.type() == xla::OpSharding::MAXIMAL &&
3367             core != sharding.tile_assignment_devices(0)) {
3368           continue;
3369         }
3370 
3371         const Edge* replicate_edge_to_replace =
3372             replicate_output_edges[output_num];
3373         Node* dst = replicate_edge_to_replace->dst();
3374         int dst_input = replicate_edge_to_replace->dst_input();
3375         graph->RemoveEdge(replicate_edge_to_replace);
3376         graph->AddEdge(node, i, dst, dst_input);
3377       }
3378 
3379       // Feed the updated variable values from the first replica to the
3380       // variable write nodes.
3381       if (replica == 0) {
3382         for (int i = 0; i < core_variable_writes.size(); ++i) {
3383           int orig_arg_num =
3384               core_variable_writes[i] + params_info.NumPerReplicaArgs() +
3385               params_info.NumDistributedArgs() + params_info.NumBroadcastArgs();
3386           const auto& sharding = arg_shardings[orig_arg_num];
3387           // If this is a tiling sharded variable, concat variable updates from
3388           // all cores.
3389           if (sharding.type() == xla::OpSharding::OTHER) {
3390             orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i;
3391 
3392             // Do this in the iteration of last core in tile assignment, so all
3393             // TPUExecute nodes have been created.
3394             if (core !=
3395                 *std::max_element(sharding.tile_assignment_devices().begin(),
3396                                   sharding.tile_assignment_devices().end())) {
3397               continue;
3398             }
3399 
3400             // Add a Concat node.
3401             std::vector<NodeOut> orig_inputs;
3402             for (int64 tile_index = 0;
3403                  tile_index < sharding.tile_assignment_devices_size();
3404                  ++tile_index) {
3405               int64 last_tile_dim_size =
3406                   *sharding.tile_assignment_dimensions().rbegin();
3407               if (sharding.replicate_on_last_tile_dim() &&
3408                   tile_index % last_tile_dim_size != 0) {
3409                 continue;
3410               }
3411               int64 core_id = sharding.tile_assignment_devices(tile_index);
3412               int core_retval_num =
3413                   orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
3414               orig_inputs.push_back(
3415                   NodeOut{execute_nodes[0][core_id],
3416                           static_cast<int>(core_retval_nums[core_id].size() +
3417                                            core_retval_num)});
3418             }
3419 
3420             // Use the variable read's device for the concat. They should both
3421             // be collocated with the variable.
3422             absl::string_view device =
3423                 variable_reads[core_variable_writes[i]]->assigned_device_name();
3424             TF_ASSIGN_OR_RETURN(
3425                 Node * concat_node,
3426                 CreateConcatNodesForRetval(
3427                     sharding, arg_shapes[orig_arg_num].handle_type,
3428                     arg_shapes[orig_arg_num].handle_shape, replica, orig_inputs,
3429                     graph, device));
3430             // Populate VariableWrite.
3431             VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3432             write.value = concat_node;
3433             write.value_output = 0;
3434             write.predicate = compile_node;
3435             write.predicate_output = num_cores_per_replica + core + 1;
3436 
3437             continue;
3438           }
3439 
3440           // If this is a replicated variable, outputs on all cores will be the
3441           // same, and we only take the output from core 0 for the varialbe
3442           // update.
3443           if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3444             continue;
3445           }
3446           VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3447           write.value = node;
3448           write.value_output = num_outputs + i;
3449           write.predicate = compile_node;
3450           write.predicate_output = num_cores_per_replica + core + 1;
3451         }
3452       }
3453     }
3454   }
3455 
3456   for (Node* node : to_be_removed_nodes) {
3457     graph->RemoveNode(node);
3458   }
3459   return Status::OK();
3460 }
3461 
CopyOutsideCompilationNodes(int replica_index,const std::vector<Node * > & outside_compilation_nodes,const DeviceNameUtils::ParsedName & tpu_device,const DeviceNameUtils::ParsedName & partial_device,NodeToNodeReplicasMap * node_images,Graph * graph)3462 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
3463     int replica_index, const std::vector<Node*>& outside_compilation_nodes,
3464     const DeviceNameUtils::ParsedName& tpu_device,
3465     const DeviceNameUtils::ParsedName& partial_device,
3466     NodeToNodeReplicasMap* node_images, Graph* graph) {
3467   for (Node* node : outside_compilation_nodes) {
3468     NodeDef image_def = node->def();
3469     MergeDebugInfo(NodeDebugInfo(node->def()), &image_def);
3470     const string suffix = strings::StrCat("/R", replica_index);
3471     // In addition to node name, make the frame name unique to avoid multiple
3472     // LoopCond nodes in one frame.
3473     TF_RETURN_IF_ERROR(
3474         AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def));
3475     Status status;
3476     Node* image = graph->AddNode(image_def, &status);
3477     image->AddAttr(kXlaReplicaIdAttrName, replica_index);
3478     TF_RETURN_IF_ERROR(status);
3479     if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) {
3480       TF_RETURN_IF_ERROR(
3481           SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image));
3482     } else {
3483       const string& original_device_string =
3484           node->assigned_device_name().empty() ? node->requested_device()
3485                                                : node->assigned_device_name();
3486       DeviceNameUtils::ParsedName device;
3487       TF_RET_CHECK(
3488           DeviceNameUtils::ParseFullName(original_device_string, &device));
3489       // If the requested device can be merged with the replica's host device,
3490       // then do so. For example, if the requested device is "/CPU:0" or
3491       // "/GPU:0" then it will be placed on the CPU/GPU of the host where this
3492       // replica is running. But if the requested device is
3493       // "/task:3/replica:2/CPU:0" then it will be placed on that task/replica.
3494       if (DeviceNameUtils::IsSpecification(device, partial_device)) {
3495         TF_RETURN_IF_ERROR(
3496             DeviceNameUtils::MergeDevNames(&device, partial_device));
3497       }
3498       image->set_requested_device(DeviceNameUtils::ParsedNameToString(device));
3499     }
3500     std::vector<Node*>& node_image_vector = (*node_images)[node];
3501     node_image_vector.resize(replica_index + 1);
3502     node_image_vector[replica_index] = image;
3503   }
3504   return Status::OK();
3505 }
3506 
ReplicateOutsideCompilationNodes(const std::vector<std::vector<string>> & tf_device_assignment,const HostComputeCoreMap & host_compute_core,const OutsideCompilationNodeMap & outside_compilation_nodes,NodeToNodeReplicasMap * node_images,Graph * graph)3507 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes(
3508     const std::vector<std::vector<string>>& tf_device_assignment,
3509     const HostComputeCoreMap& host_compute_core,
3510     const OutsideCompilationNodeMap& outside_compilation_nodes,
3511     NodeToNodeReplicasMap* node_images, Graph* graph) {
3512   // Iterate over replicas.
3513   for (int i = 0; i < tf_device_assignment.size(); ++i) {
3514     const auto& core_devices = tf_device_assignment[i];
3515     for (const auto& oc_cluster_iter : outside_compilation_nodes) {
3516       const string& oc_cluster_name = oc_cluster_iter.first;
3517       const auto& oc_cluster_nodes = oc_cluster_iter.second;
3518       // We previously validated that host_compute_core contains an entry for
3519       // each cluster.
3520       int core = host_compute_core.at(oc_cluster_name);
3521       TF_RET_CHECK(core >= 0 && core < core_devices.size());
3522       // tpu_device is the device the HostCompute XLA Op for this cluster runs
3523       // on.
3524       DeviceNameUtils::ParsedName tpu_device;
3525       TF_RET_CHECK(
3526           DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device));
3527       // partial_device contains the replica and task but not the type.
3528       DeviceNameUtils::ParsedName partial_device = tpu_device;
3529       partial_device.has_type = false;
3530       partial_device.has_id = false;
3531 
3532       if (tf_device_assignment.size() == 1) {
3533         // With a single replica don't copy any nodes just put the original
3534         // nodes into the image map. We leave the device placement alone, except
3535         // that we have to fill in the correct core for the host send and
3536         // receive nodes.
3537         for (Node* node : oc_cluster_nodes) {
3538           (*node_images)[node] = {node};
3539           node->AddAttr(kXlaReplicaIdAttrName, 0);
3540           if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
3541             TF_RETURN_IF_ERROR(
3542                 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node));
3543           }
3544         }
3545       } else {
3546         // Iterate over outside_compilation clusters in this computation, adding
3547         // all the nodes with appropriate device assignments.
3548         TF_RETURN_IF_ERROR(
3549             CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device,
3550                                         partial_device, node_images, graph));
3551       }
3552     }
3553   }
3554   return Status::OK();
3555 }
3556 
CopyOutsideCompilationEdges(const std::vector<Node * > & outside_compilation_nodes,const NodeToNodeReplicasMap & node_images,const std::unordered_map<string,Node * > outside_compilation_inputs,Graph * graph)3557 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges(
3558     const std::vector<Node*>& outside_compilation_nodes,
3559     const NodeToNodeReplicasMap& node_images,
3560     const std::unordered_map<string, Node*> outside_compilation_inputs,
3561     Graph* graph) {
3562   for (Node* node : outside_compilation_nodes) {
3563     const auto& images = node_images.at(node);
3564     // Make a copy of all edges and iterate on "in_edges", because we might
3565     // remove edges when iteratating through them.
3566     std::vector<const Edge*> in_edges(node->in_edges().begin(),
3567                                       node->in_edges().end());
3568     for (const Edge* edge : in_edges) {
3569       Node* src = edge->src();
3570       const auto iter = node_images.find(src);
3571       if (iter == node_images.end()) {
3572         if (images.size() > 1) {
3573           // The source node is a 'normal' node not part of any
3574           // rewrite. Broadcast the value to all replicas. (If images.size() ==
3575           // 1 the cluster is not replicated and we can leave the original edge
3576           // in place.)
3577           for (Node* dst : images) {
3578             graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
3579           }
3580         }
3581         continue;
3582       }
3583 
3584       // The source node is a replicated outside_compilation node.
3585       const auto& src_images = iter->second;
3586       if (src_images.size() != images.size()) {
3587         return errors::InvalidArgument(
3588             "Graph contains an edge from node ", src->name(),
3589             " in an outside_compilation block replicated ", src_images.size(),
3590             " ways to node ", node->name(),
3591             " in an outside_compilation block replicated ", images.size(),
3592             " ways. Replication factors must match. Leave a comment on "
3593             "tracking bug b/76419636 if you need this to be supported.");
3594       }
3595       bool is_lifted_arg;
3596       string outside_compilation_cluster;
3597       if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg)
3598               .ok() &&
3599           GetNodeAttr(src->def(), kOutsideCompilationAttr,
3600                       &outside_compilation_cluster)
3601               .ok()) {
3602         const auto input_iter =
3603             outside_compilation_inputs.find(outside_compilation_cluster);
3604         TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
3605         TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
3606         int dst_input = edge->dst_input();
3607         if (src_images.size() == 1) {
3608           graph->RemoveEdge(edge);
3609         }
3610         for (int i = 0; i < src_images.size(); ++i) {
3611           graph->AddEdge(input_iter->second, i, images[i], dst_input);
3612         }
3613         continue;
3614       }
3615 
3616       bool is_placeholder_for_arg;
3617       string outside_compilation_input_attr;
3618       if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg,
3619                       &is_placeholder_for_arg)
3620               .ok() &&
3621           GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName,
3622                       &outside_compilation_input_attr)
3623               .ok()) {
3624         const auto input_iter =
3625             outside_compilation_inputs.find(outside_compilation_input_attr);
3626         TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
3627         TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
3628         int dst_input = edge->dst_input();
3629         if (src_images.size() == 1) {
3630           graph->RemoveEdge(edge);
3631         }
3632         for (int i = 0; i < src_images.size(); ++i) {
3633           graph->AddEdge(input_iter->second, i, images[i], dst_input);
3634         }
3635         continue;
3636       }
3637 
3638       if (images.size() > 1) {
3639         // If images.size() == 1 neither cluster is replicated and we can
3640         // leave the original edges in place.
3641         for (int i = 0; i < src_images.size(); ++i) {
3642           graph->AddEdge(src_images[i], edge->src_output(), images[i],
3643                          edge->dst_input());
3644         }
3645       }
3646     }
3647     for (const Edge* edge : node->out_edges()) {
3648       Node* dst = edge->dst();
3649       const auto iter = node_images.find(dst);
3650       if (iter == node_images.end()) {
3651         // The source node is a 'normal' node not part of any rewrite.
3652         if (edge->IsControlEdge()) {
3653           // Make the dst node have a control dependency on every replica.
3654           if (images.size() > 1) {
3655             for (int i = 0; i < images.size(); ++i) {
3656               graph->AddControlEdge(images[i], dst);
3657             }
3658           }
3659           // else the cluster is not replicated so we can leave the original
3660           // edge in place.
3661         } else {
3662           // The edge
3663           // is only valid if the outside_compilation block is not replicated.
3664           if (images.size() > 1) {
3665             return errors::InvalidArgument(
3666                 "Graph contains an edge from node ", node->name(),
3667                 " in an outside_compilation block replicated ", images.size(),
3668                 " ways to node ", dst->name(),
3669                 " that is not part of an outside_compilation block. Edges from "
3670                 "outside_compilation to regular graph nodes are only supported "
3671                 "for replication factors of 1. Leave a comment on tracking bug "
3672                 "b/76419636 if you need this to be supported.");
3673           }
3674           // else the cluster is not replicated so we can leave the original
3675           // edge in place.
3676         }
3677       }
3678       // The case where src and dst are both in node_images is covered elsewhere
3679       // when iterating over in_edges of dst.
3680     }
3681   }
3682   return Status::OK();
3683 }
3684 
ReplicateOutsideCompilationEdges(const OutsideCompilationNodeMap & outside_compilation_nodes,const NodeToNodeReplicasMap & node_images,const std::unordered_map<string,Node * > outside_compilation_inputs,Graph * graph)3685 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges(
3686     const OutsideCompilationNodeMap& outside_compilation_nodes,
3687     const NodeToNodeReplicasMap& node_images,
3688     const std::unordered_map<string, Node*> outside_compilation_inputs,
3689     Graph* graph) {
3690   for (const auto& oc_cluster_iter : outside_compilation_nodes) {
3691     TF_RETURN_IF_ERROR(
3692         CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images,
3693                                     outside_compilation_inputs, graph));
3694   }
3695   return Status::OK();
3696 }
3697 
RemoveOutsideCompilationNodes(const NodeToNodeReplicasMap & node_images,Graph * graph)3698 /* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes(
3699     const NodeToNodeReplicasMap& node_images, Graph* graph) {
3700   for (const auto& iter : node_images) {
3701     if (iter.second.size() > 1) {
3702       // The cluster was replicated so remove the original node.
3703       Node* node = iter.first;
3704       graph->RemoveNode(node);
3705     }
3706   }
3707   return Status::OK();
3708 }
3709 
3710 /* static */ Status
LowerOutsideCompilationFunctionalNodes(Graph * g,const FunctionLibraryDefinition & flib_def,const TPUReplicateDeviceNamesMapping & tpu_replicate_device_names_mapping)3711 DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
3712     Graph* g, const FunctionLibraryDefinition& flib_def,
3713     const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) {
3714   bool modified = false;
3715   do {
3716     std::vector<Node*> nodes_to_lower;
3717     for (Node* n : g->op_nodes()) {
3718       if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
3719         continue;
3720       }
3721 
3722       if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) {
3723         // Only lower functional ops with DT_RESOURCE input, because otherwise
3724         // placer will complain. For normal cases, lowering will cause slowdown
3725         // when related functions are huge (b/139037679).
3726         bool has_resource_input = false;
3727         for (const Edge* e : n->in_edges()) {
3728           if (!e->IsControlEdge() &&
3729               e->src()->output_type(e->src_output()) == DT_RESOURCE) {
3730             has_resource_input = true;
3731             break;
3732           }
3733         }
3734         if (has_resource_input) {
3735           nodes_to_lower.push_back(n);
3736         }
3737       }
3738     }
3739 
3740     modified = !nodes_to_lower.empty();
3741 
3742     auto lower_functional_node = [&flib_def, &g](Node* n) -> Status {
3743       // Clear device assignment. Otherwise all lowered nodes will have
3744       // device assignment, which is not what we want.
3745       n->set_requested_device("");
3746 
3747       int replica_id;
3748       TF_RETURN_IF_ERROR(
3749           GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
3750 
3751       string outside_compilation_attr;
3752       TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr,
3753                                      &outside_compilation_attr));
3754 
3755       // There are two different kinds of functional outside compilation nodes:
3756       // 1. Nodes that are in outside compilation blocks already. They are
3757       //    generated by FunctionalizeControlFlowForXlaPass, and only have
3758       //    attribute kOutsideCompilationAttr.
3759       // 2. Mirrored control flow built for outside compilation in functional
3760       //    nodes. They are generated by ExtractOutsideCompilationPass, and have
3761       //    both kOutsideCompilationAttr and kXlaHasHostTransferAttrName.
3762       // When lowering them, they need to be treated differently.
3763       // For 1), their body functions are always V1 functions written by users,
3764       // and their "control outputs" are control inputs of _Retval nodes. They
3765       // should be lowered as V1 functions.
3766       // For 2), we always add necessary "control outputs"
3767       // (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their
3768       // FunctionDef's. They should be lowered as V2 functions.
3769       bool is_host_side_mirrored_control_flow =
3770           HasNodeAttr(n->def(), kXlaHasHostTransferAttrName);
3771 
3772       int num_node_ids = g->num_node_ids();
3773       bool is_call_node = IsFunctionCall(flib_def, *n);
3774       if (n->IsWhileNode()) {
3775         TF_RETURN_IF_ERROR(RewriteWhileNode(n, g,
3776                                             /*keep_node_fetchable=*/false));
3777       } else if (n->IsIfNode()) {
3778         TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false));
3779       } else {
3780         TF_RET_CHECK(is_call_node);
3781         // See comments for "is_host_side_mirrored_control_flow" above.
3782         // If this is a node that's in outside compilation block, lower it as
3783         // V1 function. This is controlled by removing
3784         // kLowerAsMultiDeviceFunctionAttr from the node.
3785         if (!is_host_side_mirrored_control_flow) {
3786           n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
3787         } else {
3788           n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
3789           n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
3790                      true);
3791         }
3792         TF_RETURN_IF_ERROR(
3793             RewriteFunctionCallNode(n, g, flib_def,
3794                                     /*keep_caller_fetchable=*/false));
3795       }
3796 
3797       for (int i = num_node_ids; i < g->num_node_ids(); i++) {
3798         Node* node = g->FindNodeId(i);
3799         if (!node) {
3800           continue;
3801         }
3802 
3803         if (!is_call_node && is_host_side_mirrored_control_flow &&
3804             IsFunctionCall(flib_def, *node)) {
3805           // For If/While nodes, if they are host side mirrored control flow,
3806           // mark their body function calls with kXlaHasHostTransferAttrName
3807           // attribute to make sure we lower them as V2 function.
3808           node->AddAttr(kXlaHasHostTransferAttrName, true);
3809         }
3810 
3811         if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() ||
3812             node->IsIfNode()) {
3813           // Set kOutsideCompilationAttr attribute so we lower these
3814           // nested function call nodes later.
3815           node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr);
3816           // Set kXlaReplicaIdAttrName attribute so we know replica id when we
3817           // lower this function call node.
3818           node->AddAttr(kXlaReplicaIdAttrName, replica_id);
3819         } else if (node->type_string() == "_XlaRecvAtHost" ||
3820                    node->type_string() == "_XlaSendFromHost") {
3821           // For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they
3822           // have kXlaReplicaIdAttrName attribute so later we know which host
3823           // device to assign.
3824           node->AddAttr(kXlaReplicaIdAttrName, replica_id);
3825         }
3826       }
3827       return Status::OK();
3828     };
3829 
3830     for (Node* n : nodes_to_lower) {
3831       TF_RETURN_IF_ERROR(lower_functional_node(n));
3832     }
3833   } while (modified);
3834 
3835   // Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes.
3836   for (Node* n : g->op_nodes()) {
3837     if (n->type_string() != "_XlaRecvAtHost" &&
3838         n->type_string() != "_XlaSendFromHost") {
3839       continue;
3840     }
3841 
3842     string replicate;
3843     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate));
3844     auto iter = tpu_replicate_device_names_mapping.find(replicate);
3845     TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end());
3846     const auto& tpu_device_names = iter->second;
3847 
3848     int replica_id;
3849     TF_RETURN_IF_ERROR(
3850         GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
3851     TF_RET_CHECK(replica_id < tpu_device_names.size());
3852     const string& tpu_device_name = tpu_device_names[replica_id][0];
3853     string host_device_name;
3854     TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3855         tpu_device_name, &host_device_name));
3856     n->set_assigned_device_name(host_device_name);
3857     // We may run TPU rewrite passes again on the subgraphs of the resulting
3858     // graph. Clear kTPUReplicateAttr and kOutsideCompilationAttr for
3859     // "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make
3860     // sure that TPU rewrite passes take no effect on host-side subgraphs for
3861     // outside compilation.
3862     n->ClearAttr(kTPUReplicateAttr);
3863     n->ClearAttr(kOutsideCompilationAttr);
3864   }
3865 
3866   // Remove IdentityN nodes generated for outside compilation. IdentityN is
3867   // exempt from resource edge colocation, but here we do need input and output
3868   // for these IdentityN nodes to be colocated.
3869   std::vector<Node*> identityn_nodes;
3870   for (Node* n : g->op_nodes()) {
3871     if (n->type_string() == "IdentityN" &&
3872         HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) {
3873       identityn_nodes.push_back(n);
3874     }
3875   }
3876   for (Node* n : identityn_nodes) {
3877     std::vector<const Edge*> out_edges(n->out_edges().begin(),
3878                                        n->out_edges().end());
3879     for (const Edge* e : out_edges) {
3880       if (e->IsControlEdge()) {
3881         continue;
3882       }
3883 
3884       int src_output = e->src_output();
3885       const Edge* input_edge;
3886       TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge));
3887       Node* dst = e->dst();
3888       int dst_input = e->dst_input();
3889       g->RemoveEdge(e);
3890       g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
3891     }
3892     g->RemoveNode(n);
3893   }
3894 
3895   return Status::OK();
3896 }
3897 
ParseHostComputeCores(const Node & replicate_node,const OutsideCompilationNodeMap & outside_compilation_nodes,HostComputeCoreMap * host_compute_core)3898 /* static */ Status DistributedTPURewritePass::ParseHostComputeCores(
3899     const Node& replicate_node,
3900     const OutsideCompilationNodeMap& outside_compilation_nodes,
3901     HostComputeCoreMap* host_compute_core) {
3902   std::vector<string> hc_core_string;
3903   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core",
3904                                  &hc_core_string));
3905   TF_RETURN_IF_ERROR(
3906       ParseHostComputeCoreList(hc_core_string, host_compute_core));
3907   for (const auto& iter : outside_compilation_nodes) {
3908     const string& oc_cluster_name = iter.first;
3909     if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) {
3910       // By default put host compute Ops on replicated core 0.
3911       (*host_compute_core)[oc_cluster_name] = 0;
3912     }
3913   }
3914   return Status::OK();
3915 }
3916 
GetDeviceTopology(const DeviceSet & device_set,const Node & replicate_node,int * num_replicas,int * num_cores_per_replica,int * num_tasks,std::vector<std::vector<string>> * tf_device_assignment,std::unique_ptr<xla::DeviceAssignment> * xla_device_assignment,string * tpu_compilation_device)3917 /* static */ Status DistributedTPURewritePass::GetDeviceTopology(
3918     const DeviceSet& device_set, const Node& replicate_node, int* num_replicas,
3919     int* num_cores_per_replica, int* num_tasks,
3920     std::vector<std::vector<string>>* tf_device_assignment,
3921     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
3922     string* tpu_compilation_device) {
3923   TF_RETURN_IF_ERROR(
3924       GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas));
3925   if (*num_replicas < 1) {
3926     return errors::InvalidArgument("num_replicas must be >= 1, got ",
3927                                    *num_replicas);
3928   }
3929 
3930   // Find the set of TPU devices in the TF job.
3931   // Indexed by [task number][tpu device number].
3932   std::vector<std::vector<Device*>> tpu_devices;
3933   int num_tpus_per_task;
3934   TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
3935                                        device_set, tpu_compilation_device,
3936                                        &num_tpus_per_task, &tpu_devices));
3937   *num_tasks = tpu_devices.size();
3938 
3939   string topology;
3940   TF_RETURN_IF_ERROR(
3941       GetNodeAttr(replicate_node.attrs(), "topology", &topology));
3942   TF_RETURN_IF_ERROR(GetNodeAttr(
3943       replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica));
3944   std::vector<int> device_assignment;
3945   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment",
3946                                  &device_assignment));
3947 
3948   // TODO(cwhipkey): since we can control multiple pods of different shapes
3949   // from a single worker, it may be desirable to propagate the remote device
3950   // information around (e.g., in DeviceAttributes). This can lead to the mesh
3951   // topology proto being leaked to cloud TPU users (e.g. through GetStatus
3952   // calls); this may be okay, but to be conservative, just assume that the
3953   // master session has the proper flags set.
3954 
3955   // We do not initialize platform right now, but we can still retrieve the
3956   // TPU topology even with an uninitialized platform.
3957   auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
3958       /*initialize_platform=*/false);
3959   TF_RET_CHECK(tpu_platform);
3960   tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr());
3961   TF_RET_CHECK(num_tpus_per_task ==
3962                tpu_topology.LogicalDevicesPerHost(kTensorCore));
3963   TF_RETURN_IF_ERROR(BuildDeviceAssignment(
3964       tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas,
3965       *num_cores_per_replica, topology, device_assignment, tf_device_assignment,
3966       xla_device_assignment));
3967 
3968   return Status::OK();
3969 }
3970 
GetIOTypes(int num_replicas,const Node & replicate_node,FunctionLibraryRuntime * flr,Graph * graph,NameRangeMap * input_name_map,const NameAttrList ** function,std::unique_ptr<Graph> * computation,DataTypeVector * arg_types,DataTypeVector * retval_types,ParameterInfo * params_info)3971 /* static */ Status DistributedTPURewritePass::GetIOTypes(
3972     int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
3973     Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
3974     std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
3975     DataTypeVector* retval_types, ParameterInfo* params_info) {
3976   DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types;
3977   TF_RETURN_IF_ERROR(
3978       GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types));
3979   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs",
3980                                  &broadcast_input_types));
3981   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
3982                                  "Tguaranteed_constants",
3983                                  &guaranteed_constant_types));
3984   int num_distributed_vars;
3985   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
3986                                  "num_distributed_variables",
3987                                  &num_distributed_vars));
3988   const int num_per_replica_inputs = input_types.size() - num_distributed_vars;
3989 
3990   if (num_per_replica_inputs % num_replicas != 0) {
3991     return errors::InvalidArgument(
3992         "Number of inputs to TPUReplicate (", num_per_replica_inputs,
3993         ") is not divisible by the number of replicas (", num_replicas, ").");
3994   }
3995 
3996   int num_variables;
3997   TF_RETURN_IF_ERROR(
3998       GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables));
3999 
4000   NameRangeMap output_name_map;
4001   TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(),
4002                                        input_name_map, &output_name_map));
4003 
4004   TF_RETURN_IF_ERROR(
4005       GetNodeAttr(replicate_node.attrs(), "computation", function));
4006 
4007   *computation = absl::make_unique<Graph>(graph->op_registry());
4008   TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp(
4009       **function, flr, computation->get(), arg_types, retval_types));
4010 
4011   *params_info = ParameterInfo(
4012       num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars,
4013       broadcast_input_types.size(), num_variables,
4014       guaranteed_constant_types.size(), retval_types->size());
4015 
4016   if (arg_types->size() != params_info->NumInputsToEachReplica()) {
4017     return errors::InvalidArgument(
4018         "Computation argument to TPUReplicate has wrong number of "
4019         "arguments. Expected ",
4020         params_info->NumInputsToEachReplica(), " inputs, got ",
4021         arg_types->size());
4022   }
4023   if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) {
4024     return errors::InvalidArgument(
4025         "Wrong number of outputs from TPUReplicate. Expected ",
4026         params_info->NumOutputsToHost(), " outputs, got ",
4027         replicate_node.num_outputs());
4028   }
4029   if (enable_cross_replica_sharding_mirrored_variables_) {
4030     std::vector<int> mirrored_variable_indices;
4031     TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4032                                    TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
4033                                    &mirrored_variable_indices));
4034     for (int index : mirrored_variable_indices) {
4035       TF_RET_CHECK(params_info->IsPerReplicaArg(index) ||
4036                    params_info->IsDistributedArg(index))
4037           << "Mirrored variables not categorized as per-replica arguments, "
4038              "index: "
4039           << index;
4040       params_info->mutable_mirrored_variable_indices()->insert(index);
4041     }
4042   }
4043   return Status::OK();
4044 }
4045 
BuildSequencingNodes(const string & tpu_compilation_device,const Node & replicate_node,Graph * graph,Node ** host_transfer_sequencer,Node ** control_before,Node ** control_after)4046 /* static */ Status DistributedTPURewritePass::BuildSequencingNodes(
4047     const string& tpu_compilation_device, const Node& replicate_node,
4048     Graph* graph, Node** host_transfer_sequencer, Node** control_before,
4049     Node** control_after) {
4050   *host_transfer_sequencer = nullptr;
4051 
4052   TF_RETURN_IF_ERROR(
4053       BuildNoopNode(replicate_node,
4054                     graph->NewName(strings::StrCat(replicate_node.name(), "/",
4055                                                    "control_before")),
4056                     /*device=*/"", graph, control_before));
4057   for (const Edge* e : replicate_node.in_edges()) {
4058     if (!e->IsControlEdge()) {
4059       continue;
4060     }
4061     Node* predecessor = e->src();
4062     if (predecessor->IsSource()) continue;
4063     if (predecessor->type_string() == "NoOp" &&
4064         predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) {
4065       // The node is the sequencer for host transfer operations. Its control
4066       // dependency needs to be placed after the execute node, not before.
4067       if (*host_transfer_sequencer != nullptr) {
4068         return errors::Internal("Replicate node ", replicate_node.name(),
4069                                 " has two transfer sequencer nodes: ",
4070                                 (*host_transfer_sequencer)->name(), " and ",
4071                                 predecessor->name());
4072       }
4073       // Set the correct device to match the other sequencing nodes.
4074       predecessor->set_assigned_device_name(tpu_compilation_device);
4075       *host_transfer_sequencer = predecessor;
4076     } else {
4077       graph->AddControlEdge(predecessor, *control_before);
4078     }
4079   }
4080 
4081   TF_RETURN_IF_ERROR(
4082       BuildNoopNode(replicate_node,
4083                     graph->NewName(strings::StrCat(replicate_node.name(), "/",
4084                                                    "control_after")),
4085                     /*device=*/tpu_compilation_device, graph, control_after));
4086   for (Node* successor : replicate_node.out_nodes()) {
4087     if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) {
4088       graph->AddControlEdge(successor, *control_after);
4089     } else {
4090       graph->AddControlEdge(*control_after, successor);
4091     }
4092   }
4093   return Status::OK();
4094 }
4095 
DealWithConstantsAndVariables(const Node & replicate_node,const NameRangeMap & input_name_map,Graph * graph,Node * host_transfer_sequencer,Node * control_before,Node * control_after,absl::Span<const VariableInput> variable_nodes,std::vector<Node * > * guaranteed_constant_nodes,std::vector<Node * > * variable_reads)4096 /* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables(
4097     const Node& replicate_node, const NameRangeMap& input_name_map,
4098     Graph* graph, Node* host_transfer_sequencer, Node* control_before,
4099     Node* control_after, absl::Span<const VariableInput> variable_nodes,
4100     std::vector<Node*>* guaranteed_constant_nodes,
4101     std::vector<Node*>* variable_reads) {
4102   TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs(
4103       replicate_node, input_name_map, guaranteed_constant_nodes));
4104 
4105   TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph,
4106                                         variable_reads));
4107   // Add the control dependency from host transfer nodes.
4108   if (host_transfer_sequencer != nullptr) {
4109     graph->AddControlEdge(host_transfer_sequencer, control_after);
4110   }
4111   return Status::OK();
4112 }
4113 
4114 /* static */ Status
BuildCompilationStatusReturnNodes(Node * replicate_node,Node * compile_node,Node ** control_after_compilation,Graph * graph)4115 DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
4116     Node* replicate_node, Node* compile_node, Node** control_after_compilation,
4117     Graph* graph) {
4118   const Edge* compilation_edge = nullptr;
4119   for (const auto* e : replicate_node->out_edges()) {
4120     if (e->IsControlEdge() &&
4121         e->dst()->type_string() == "TPUCompilationResult") {
4122       TF_RET_CHECK(compilation_edge == nullptr)
4123           << "Multiple compilation result nodes attached to the same replicate "
4124              "cluster.";
4125       compilation_edge = e;
4126     }
4127   }
4128 
4129   // TODO(jpienaar): This should be checked by default, current tests not using
4130   // this are ones that use the "abort upon successful compile flag" which will
4131   // be removed. Leaving this in until then.
4132   if (compilation_edge != nullptr) {
4133     Node* compilation_status = compilation_edge->dst();
4134     const AttrValue* compile_status_cluster_attr =
4135         compilation_status->attrs().Find(kTPUCompilationResultAttr);
4136     TF_RET_CHECK(compile_status_cluster_attr != nullptr);
4137     const string& compile_status_cluster = compile_status_cluster_attr->s();
4138     TF_RET_CHECK(!compile_status_cluster.empty());
4139     const AttrValue* replicate_cluster_attr =
4140         replicate_node->attrs().Find(kTPUReplicateAttr);
4141     TF_RET_CHECK(replicate_cluster_attr != nullptr);
4142     const string& replicate_cluster = replicate_cluster_attr->s();
4143     TF_RET_CHECK(!replicate_cluster.empty());
4144     TF_RET_CHECK(compile_status_cluster == replicate_cluster);
4145 
4146     TF_RETURN_IF_ERROR(
4147         ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status));
4148     graph->AddEdge(compile_node, 0, compilation_status, 0);
4149   }
4150 
4151   NodeDef def;
4152   def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph));
4153   // Create an op to assert that compilation succeeded. The alternative would
4154   // have been to have each execute op check and return an error.
4155   def.set_op("TPUCompileSucceededAssert");
4156   MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
4157   Status status;
4158   Node* compile_succeeded = graph->AddNode(def, &status);
4159   compile_succeeded->set_assigned_device_name(
4160       compile_node->assigned_device_name());
4161   TF_RETURN_IF_ERROR(status);
4162   graph->AddEdge(compile_node, 0, compile_succeeded, 0);
4163 
4164   // Build a sequencing node for when compilation has completed.
4165   TF_RETURN_IF_ERROR(
4166       BuildNoopNode(*replicate_node,
4167                     graph->NewName(strings::StrCat(compile_node->name(), "/",
4168                                                    "after_compilation")),
4169                     /*device=*/"", graph, control_after_compilation));
4170   graph->AddControlEdge(compile_succeeded, *control_after_compilation);
4171 
4172   return Status::OK();
4173 }
4174 
4175 // Updates the head and tail outside compiled nodes so that nodes have the
4176 // correct device and removes the replication and outside compilation attributes
4177 // so that these nodes do not trigger further graph optimization passes.
UpdateHeadTailOutsideCompilation(const std::vector<std::vector<string>> & tf_device_assignment,const std::vector<Node * > & head_tail_outside_compilation_nodes)4178 /* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation(
4179     const std::vector<std::vector<string>>& tf_device_assignment,
4180     const std::vector<Node*>& head_tail_outside_compilation_nodes) {
4181   for (Node* node : head_tail_outside_compilation_nodes) {
4182     int replica_id;
4183     TF_RETURN_IF_ERROR(
4184         GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id));
4185     // Since we set the device, this will now run on a task other than 0. We
4186     // clear the two following attributes so that we don't trigger encapsulation
4187     // again on the remote host (which will fail due to a missing
4188     // _TPUReplicateMetadata node for the cluster).
4189     for (const Edge* e : node->in_edges()) {
4190       // Resource consuming ops should colocate with its resource input.
4191       if (e->src()->IsArg() &&
4192           e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4193         node->set_requested_device(tf_device_assignment[replica_id][0]);
4194       }
4195     }
4196     if (node->requested_device().empty()) {
4197       string cpu_device;
4198       TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4199           tf_device_assignment[replica_id][0], &cpu_device));
4200       node->set_requested_device(cpu_device);
4201     }
4202     node->ClearAttr(kTPUReplicateAttr);
4203     node->ClearAttr(kOutsideCompilationAttr);
4204   }
4205   return Status::OK();
4206 }
4207 
4208 // Performs the rewrite on a single TPUReplicate node.
RewriteTPUReplicateNode(const string & session_handle,const DeviceSet & device_set,Node * replicate_node,FunctionLibraryDefinition * flib_def,FunctionLibraryRuntime * flr,Node * host_compute_key_placeholder_node,const OutsideCompilationNodeMap & outside_compilation_nodes,const std::vector<Node * > & head_tail_outside_compilation_nodes,NodeToNodeReplicasMap * outside_compilation_node_images,Graph * graph,const GraphShapeInfo & shape_info,TPUReplicateDeviceNamesMapping * tpu_replicate_device_names_mapping,int64 autotuner_thresh)4209 /* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
4210     const string& session_handle, const DeviceSet& device_set,
4211     Node* replicate_node, FunctionLibraryDefinition* flib_def,
4212     FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
4213     const OutsideCompilationNodeMap& outside_compilation_nodes,
4214     const std::vector<Node*>& head_tail_outside_compilation_nodes,
4215     NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
4216     const GraphShapeInfo& shape_info,
4217     TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
4218     int64 autotuner_thresh) {
4219   VLOG(2) << "Rewriting node " << replicate_node->name();
4220 
4221   // num_replicas and num_cores_per_replica are the 'virtual' replicas (copies
4222   // of the computation) and cores (virtual cores within computations) specified
4223   // by the user. They will be mapped to physical TPU cores below.
4224   int num_replicas;
4225   int num_cores_per_replica;
4226   int num_tasks;  // Number of tasks.
4227   std::vector<std::vector<string>> tf_device_assignment;
4228   std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
4229   string tpu_compilation_device;
4230   TF_RETURN_IF_ERROR(GetDeviceTopology(
4231       device_set, *replicate_node, &num_replicas, &num_cores_per_replica,
4232       &num_tasks, &tf_device_assignment, &xla_device_assignment,
4233       &tpu_compilation_device));
4234 
4235   TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation(
4236       tf_device_assignment, head_tail_outside_compilation_nodes));
4237 
4238   string replicate;
4239   TF_RETURN_IF_ERROR(
4240       GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate));
4241   tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment);
4242 
4243   NameRangeMap input_name_map;
4244   const NameAttrList* function;
4245   std::unique_ptr<Graph> computation;
4246   DataTypeVector arg_types, retval_types;
4247   ParameterInfo params_info;
4248   TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph,
4249                                 &input_name_map, &function, &computation,
4250                                 &arg_types, &retval_types, &params_info));
4251 
4252   std::vector<InferredShape> arg_shapes, retval_shapes;
4253   TF_RETURN_IF_ERROR(GetArgAndRetvalShapes(
4254       shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes));
4255 
4256   TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica));
4257 
4258   std::vector<xla::OpSharding> arg_sharding;
4259   std::vector<bool> arg_fast_mem;
4260   std::vector<std::string> arg_names;
4261   std::vector<xla::OpSharding> retval_sharding;
4262   TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores(
4263       num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
4264       retval_shapes, *computation, replicate_node, flr,
4265       allow_xla_spmd_partition_, &arg_sharding, &arg_fast_mem, &retval_sharding,
4266       &arg_names));
4267 
4268   VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation,
4269                              flib_def);
4270 
4271   GraphDef graph_def;
4272   graph->ToGraphDef(&graph_def);
4273   FunctionLibraryDefinition reachable_functions =
4274       flib_def->ReachableDefinitions(graph_def);
4275   uint64 library_fingerprint;
4276 
4277   TF_RETURN_IF_ERROR(
4278       FingerprintFunctionLibrary(reachable_functions, &library_fingerprint));
4279   VLOG(1) << "Fingerprint functions: "
4280           << absl::StrJoin(reachable_functions.ListFunctionNames(), ", ");
4281   VLOG(1) << "library_fingerprint: " << library_fingerprint;
4282 
4283   // Builds trigger nodes that put barriers around the expansion of
4284   // TPUReplicate. In particular, we must guarantee:
4285   // a) variable reads happen after all predecessors of the original
4286   //    TPUReplicate.
4287   // b) variable writes happen before all successors of the original
4288   //    TPUReplicate.
4289   // c) all replicas execute, even if output tensors are only requested from
4290   //    a subset of replicas. This is necessary both to ensure that variable
4291   //    updates happen, but also Send/Recv will deadlock if only one half of
4292   //    the communicating pair runs.
4293   Node* host_transfer_sequencer;
4294   Node* control_before;
4295   Node* control_after;
4296   TF_RETURN_IF_ERROR(BuildSequencingNodes(
4297       tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer,
4298       &control_before, &control_after));
4299 
4300   // Build a vector of variable nodes that are inputs.
4301   std::vector<VariableInput> variable_inputs;
4302   TF_RETURN_IF_ERROR(
4303       FindVariableInputs(*replicate_node, input_name_map, &variable_inputs));
4304 
4305   std::vector<Node*> guaranteed_constant_nodes;
4306   std::vector<Node*> variable_reads;
4307   TF_RETURN_IF_ERROR(DealWithConstantsAndVariables(
4308       *replicate_node, input_name_map, graph, host_transfer_sequencer,
4309       control_before, control_after, variable_inputs,
4310       &guaranteed_constant_nodes, &variable_reads));
4311 
4312   // Builds Shape nodes that compute the dynamic shapes of arguments whose
4313   // shapes are not statically known.
4314   std::vector<Node*> dynamic_shape_nodes;
4315   TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes,
4316                                             params_info, variable_reads, graph,
4317                                             &dynamic_shape_nodes));
4318 
4319   // Builds a TPUCompile node that compiles `clusters` on `compile_device`.
4320   Node* compile_node;
4321   TF_RETURN_IF_ERROR(BuildCompileNode(
4322       replicate_node, *function, library_fingerprint, params_info, arg_shapes,
4323       arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
4324       arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
4325       /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
4326       dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
4327 
4328   // Compilation must be sequenced after the control node if the TPU computation
4329   // in a control-flow construct, such as a loop.
4330   graph->AddControlEdge(control_before, compile_node);
4331 
4332   Node* control_after_compilation;
4333   TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes(
4334       replicate_node, compile_node, &control_after_compilation, graph));
4335 
4336   std::vector<VariableWrite> variable_writes;
4337   TF_RETURN_IF_ERROR(BuildExecuteNodes(
4338       params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
4339       arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
4340       tf_device_assignment, compile_node, variable_reads,
4341       control_after_compilation, control_after, &variable_writes, graph));
4342   bool contains_resource_write_op =
4343       ContainsResourceWriteOp(*graph, reachable_functions);
4344 
4345   VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op;
4346   // Skip conditional write if there is no resource writing op inside TPU
4347   // computation.
4348   if (contains_resource_write_op) {
4349     TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after,
4350                                            variable_writes, graph));
4351   }
4352 
4353   if (host_compute_key_placeholder_node != nullptr) {
4354     TF_RETURN_IF_ERROR(ConnectHostComputeNodes(
4355         compile_node, host_compute_key_placeholder_node, graph));
4356   }
4357 
4358   HostComputeCoreMap host_compute_core;
4359   TF_RETURN_IF_ERROR(ParseHostComputeCores(
4360       *replicate_node, outside_compilation_nodes, &host_compute_core));
4361   TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes(
4362       tf_device_assignment, host_compute_core, outside_compilation_nodes,
4363       outside_compilation_node_images, graph));
4364 
4365   graph->RemoveNode(replicate_node);
4366   return Status::OK();
4367 }
4368 
4369 // Adds sharded weight update optimization for each host training loop.
4370 //
4371 // For any host training loop found in the graph, TPUVariableReshard ops
4372 // are inserted to match the best layout chosen by the XLA.
4373 /* static */ Status
PerformHostTrainingLoopOptimization(Graph * graph,FunctionLibraryDefinition * flib_def,FunctionLibraryRuntime * flr)4374 DistributedTPURewritePass::PerformHostTrainingLoopOptimization(
4375     Graph* graph, FunctionLibraryDefinition* flib_def,
4376     FunctionLibraryRuntime* flr) {
4377   std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
4378   Status s = tpu::DetectHostTrainingLoop(
4379       /*current_function_name=*/nullptr,
4380       /*current_function_attr=*/nullptr, flib_def, graph, flr,
4381       &host_training_loops_info);
4382   if (!s.ok()) {
4383     VLOG(2) << "No valid host training loop found. Skipping sharded weight "
4384             << "update optimization.";
4385     return Status::OK();
4386   }
4387 
4388   for (const auto& host_loop : host_training_loops_info) {
4389     const auto& function_name = host_loop.encapsulating_function_name;
4390     // `function_name` has value when host training loop is inside a
4391     // function call node. When host training loop is found inside a function
4392     // call node, then, in addition to adding TPUVariableReshard ops, function
4393     // library definition needs to be updated as well.
4394     if (function_name.has_value()) {
4395       const auto& function_attr = host_loop.encapsulating_function_attrs;
4396       TF_RET_CHECK(function_attr.has_value())
4397           << "Unable to find function attribute for function: "
4398           << *function_name;
4399 
4400       const FunctionDef* function_def = flib_def->Find(*function_name);
4401       TF_RET_CHECK(function_def)
4402           << "Unable to find function : " << *function_name;
4403 
4404       std::unique_ptr<FunctionBody> fbody;
4405       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
4406           *function_def, AttrSlice(&function_attr.value()), flib_def, &fbody));
4407       Graph* function_graph = fbody->graph;
4408       TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop));
4409       TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph,
4410                                                      *function_name, flib_def));
4411     } else {
4412       TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop));
4413     }
4414   }
4415   return Status::OK();
4416 }
4417 
PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph * graph)4418 Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible(
4419     Graph* graph) {
4420   ReverseDFS(*graph, {}, PlaceOpsOnTPU);
4421   return Status::OK();
4422 }
4423 
Run(const GraphOptimizationPassOptions & options)4424 Status DistributedTPURewritePass::Run(
4425     const GraphOptimizationPassOptions& options) {
4426   VLOG(1) << "DistributedTPURewritePass::Run";
4427 
4428   Graph* graph = options.graph->get();
4429 
4430   VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph,
4431                              options.flib_def);
4432 
4433   const auto* config = &options.session_options->config;
4434   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
4435       new ProcessFunctionLibraryRuntime(
4436           nullptr, options.session_options->env, config,
4437           graph->versions().producer(), options.flib_def,
4438           config ? config->graph_options().optimizer_options()
4439                  : OptimizerOptions()));
4440 
4441   FunctionLibraryRuntime* flr =
4442       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
4443 
4444   // This pass can only run in the session master, which should fill
4445   // in the device_set field to the options.
4446   TF_RET_CHECK(options.device_set != nullptr);
4447 
4448   // Find all the replicate nodes before mutating the graph.
4449   std::vector<Node*> replicate_nodes;
4450   // Map from compiled subgraph cluster name to the outside_compilation nodes in
4451   // that cluster.
4452   std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes;
4453   std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes;
4454   TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes,
4455                                      &outside_compilation_nodes,
4456                                      &head_tail_outside_compilation_nodes));
4457 
4458   if (replicate_nodes.empty()) {
4459     // Remove unused TPUPartitionedInput nodes.
4460     for (Node* n : graph->nodes()) {
4461       if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n);
4462     }
4463     return Status::OK();
4464   }
4465 
4466   std::unordered_map<string, Node*> host_compute_key_placeholder_map;
4467   TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes(
4468       graph, replicate_nodes, &host_compute_key_placeholder_map));
4469 
4470   GraphShapeInfo shape_info;
4471   TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{},
4472                                  flr->GetFunctionLibraryDefinition(),
4473                                  &shape_info));
4474   int64 autotuner_thresh = options.session_options->config.experimental()
4475                                .xla_fusion_autotuner_thresh();
4476 
4477   NodeToNodeReplicasMap outside_compilation_node_images;
4478   TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping;
4479   for (Node* node : replicate_nodes) {
4480     TF_RETURN_IF_ERROR(RewriteTPUReplicateNode(
4481         options.session_handle, *options.device_set, node, options.flib_def,
4482         flr, host_compute_key_placeholder_map[node->name()],
4483         outside_compilation_nodes[node->name()],
4484         head_tail_outside_compilation_nodes[node->name()],
4485         &outside_compilation_node_images, graph, shape_info,
4486         &tpu_replicate_device_names_mapping, autotuner_thresh));
4487   }
4488 
4489   // Place the padding nodes generated by dynamic padder on the correct devices.
4490   // TODO(rxsang): Place padding ops on TPUs in
4491   // PlaceUnassignedDeviceNodesOnTPUIfPossible function.
4492   TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph));
4493 
4494   std::unordered_map<string, Node*> outside_compilation_inputs;
4495   for (Node* n : graph->op_nodes()) {
4496     string lifted_arg_inputs_attr;
4497     if (n->type_string() == "IdentityN" &&
4498         GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName,
4499                     &lifted_arg_inputs_attr)
4500             .ok()) {
4501       outside_compilation_inputs[lifted_arg_inputs_attr] = n;
4502     }
4503   }
4504   for (const auto& iter : outside_compilation_nodes) {
4505     TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges(
4506         iter.second, outside_compilation_node_images,
4507         outside_compilation_inputs, graph));
4508   }
4509   TF_RETURN_IF_ERROR(
4510       RemoveOutsideCompilationNodes(outside_compilation_node_images, graph));
4511   TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes(
4512       graph, *options.flib_def, tpu_replicate_device_names_mapping));
4513 
4514   TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph));
4515   VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
4516                              options.flib_def);
4517   VLOG(1) << "DistributedTPURewritePass::Run() finished";
4518 
4519   if (enable_cross_replica_sharding_mirrored_variables_) {
4520     VLOG(1) << "Starting host training loop optimization.";
4521     VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph,
4522                                options.flib_def);
4523     TF_RETURN_IF_ERROR(
4524         PerformHostTrainingLoopOptimization(graph, options.flib_def, flr));
4525     VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph,
4526                                options.flib_def);
4527     VLOG(1) << "Host training loop optimization finished.";
4528   }
4529 
4530   return Status::OK();
4531 }
4532 
4533 bool DistributedTPURewritePass::distribute_vars_ = false;
4534 bool DistributedTPURewritePass::allow_xla_spmd_partition_ = true;
4535 bool DistributedTPURewritePass::
4536     replicate_inputs_outputs_by_default_for_xla_spmd_ = false;
4537 bool DistributedTPURewritePass::
4538     enable_cross_replica_sharding_mirrored_variables_ = true;
4539 bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
4540 bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false;
4541 
SetDistributedTpuRewritePassOptions(bool distribute_vars,bool allow_xla_spmd_partition,bool replicate_inputs_outputs_by_default_for_xla_spmd,bool enable_cross_replica_sharding_mirrored_variables,bool enable_automatic_model_parallelism,bool enable_xla_param_broadcast)4542 /*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
4543     bool distribute_vars, bool allow_xla_spmd_partition,
4544     bool replicate_inputs_outputs_by_default_for_xla_spmd,
4545     bool enable_cross_replica_sharding_mirrored_variables,
4546     bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast) {
4547   distribute_vars_ = distribute_vars;
4548   allow_xla_spmd_partition_ = allow_xla_spmd_partition;
4549   replicate_inputs_outputs_by_default_for_xla_spmd_ =
4550       replicate_inputs_outputs_by_default_for_xla_spmd;
4551   enable_cross_replica_sharding_mirrored_variables_ =
4552       enable_cross_replica_sharding_mirrored_variables;
4553   enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
4554   enable_xla_param_broadcast_ = enable_xla_param_broadcast;
4555 }
4556 
4557 }  // namespace tensorflow
4558