1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Compilation for distributed TPU (TPU_REPLICATED_CORE devices).
17
18 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
19
20 #include <queue>
21 #include <vector>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/strings/escaping.h"
26 #include "tensorflow/compiler/jit/encapsulate_util.h"
27 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
28 #include "tensorflow/compiler/tf2xla/sharding_util.h"
29 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
30 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
31 #include "tensorflow/compiler/xla/array3d.h"
32 #include "tensorflow/compiler/xla/array4d.h"
33 #include "tensorflow/compiler/xla/client/sharding_builder.h"
34 #include "tensorflow/compiler/xla/service/computation_placer.h"
35 #include "tensorflow/compiler/xla/xla.pb.h"
36 #include "tensorflow/core/common_runtime/function.h"
37 #include "tensorflow/core/common_runtime/graph_constructor.h"
38 #include "tensorflow/core/common_runtime/lower_function_call_op.h"
39 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
40 #include "tensorflow/core/common_runtime/lower_if_op.h"
41 #include "tensorflow/core/common_runtime/lower_while_op.h"
42 #include "tensorflow/core/common_runtime/optimization_registry.h"
43 #include "tensorflow/core/framework/function.h"
44 #include "tensorflow/core/framework/graph_to_functiondef.h"
45 #include "tensorflow/core/framework/node_def_builder.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/partial_tensor_shape.h"
48 #include "tensorflow/core/framework/tensor.pb.h"
49 #include "tensorflow/core/framework/types.pb.h"
50 #include "tensorflow/core/framework/versions.pb.h"
51 #include "tensorflow/core/graph/algorithm.h"
52 #include "tensorflow/core/graph/graph.h"
53 #include "tensorflow/core/lib/core/errors.h"
54 #include "tensorflow/core/lib/core/status.h"
55 #include "tensorflow/core/lib/gtl/cleanup.h"
56 #include "tensorflow/core/lib/strings/proto_serialization.h"
57 #include "tensorflow/core/lib/strings/str_util.h"
58 #include "tensorflow/core/platform/fingerprint.h"
59 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
60 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
61 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
62 #include "tensorflow/core/public/session_options.h"
63 #include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
64 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
65 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
66 #include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
67 #include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
68 #include "tensorflow/core/tpu/tpu_compile_interface.h"
69 #include "tensorflow/core/tpu/tpu_defs.h"
70 #include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
71 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
72 #include "tensorflow/core/util/device_name_utils.h"
73 #include "tensorflow/core/util/dump_graph.h"
74 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
75
76 namespace tensorflow {
77
78 namespace {
79
// Rank of the TPU topology: every device coordinate is an (x, y, z, core)
// tuple.
constexpr int kTPUTopologyRank = 4;

// Upper bound on the number of cores the topology may contain.
constexpr int kTPUMaxTopologySize = 4096;

// Name of the attribute holding a serialized xla::OpSharding. The value is
// forwarded to the matching XLA HLO operation and describes how a shape is
// laid out over logical cores (replicated, single-device, or partitioned).
constexpr char kShardingAttribute[] = "_XlaSharding";

// Op type names of the partitioned input / output ops.
constexpr char kTPUPartitionedInput[] = "TPUPartitionedInput";
constexpr char kTPUPartitionedOutput[] = "TPUPartitionedOutput";

// Op type name of the resource-variable handle op.
constexpr char kVarHandleOp[] = "VarHandleOp";

// Node attribute names used by this pass.
constexpr const char* kTPUCompilationResultAttr = "_tpu_compilation_status";
constexpr const char* kPostDeviceRewriteAttr = "_post_device_rewrite";
99
// Position of an element inside an IntrusiveHeap, stored within the element
// itself. A default-constructed link holds kNotMember ("not in a heap").
class IntrusiveHeapLink {
 public:
  using size_type = size_t;
  static constexpr size_type kNotMember = static_cast<size_type>(-1);

  IntrusiveHeapLink() = default;

  // Only IntrusiveHeap and LinkAccess objects should create links that carry
  // an explicit position.
  explicit IntrusiveHeapLink(size_type pos) : pos_(pos) {}

  // Only IntrusiveHeap and LinkAccess should read the stored position.
  size_type get() const { return pos_; }

 private:
  size_type pos_ = kNotMember;
};
116
117 template <typename T, IntrusiveHeapLink T::*M>
118 struct IntrusiveHeapDataMemberLinkAccess {
Gettensorflow::__anonf0ad67bd0111::IntrusiveHeapDataMemberLinkAccess119 IntrusiveHeapLink Get(const T* elem) const { return elem->*M; }
Settensorflow::__anonf0ad67bd0111::IntrusiveHeapDataMemberLinkAccess120 void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; }
121 };
122
123 template <typename T>
124 struct DefaultIntrusiveHeapLinkAccess {
Gettensorflow::__anonf0ad67bd0111::DefaultIntrusiveHeapLinkAccess125 IntrusiveHeapLink Get(const T* elem) const { return elem->heap; }
Settensorflow::__anonf0ad67bd0111::DefaultIntrusiveHeapLinkAccess126 void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; }
127 };
128
// Binary heap of `T*` whose elements record their own heap position.
//
// Each element's current index in the underlying vector is kept in an
// IntrusiveHeapLink that `LinkAccess` reads and writes. This lets Adjust(),
// Remove() and Contains() locate an element directly instead of searching the
// vector. The heap never deletes the pointed-to elements; callers must keep
// them alive while they are in the heap.
//
// Ordering: top() is an element x such that PtrCompare()(y, x) is false for
// every other element y, i.e. this is a min-heap with respect to PtrCompare.
//
// Template parameters:
//   T          - element type.
//   PtrCompare - strict weak ordering over `const T*`, invoked as comp(a, b).
//   LinkAccess - provides Get(const T*) -> IntrusiveHeapLink and
//                Set(T*, IntrusiveHeapLink) for the per-element link.
//   Alloc      - allocator for the underlying std::vector<T*>.
template <typename T, typename PtrCompare,
          typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>,
          typename Alloc = std::allocator<T*>>
class IntrusiveHeap {
 public:
  typedef typename IntrusiveHeapLink::size_type size_type;
  typedef T value_type;
  typedef T* pointer;
  typedef const T* const_pointer;
  typedef PtrCompare pointer_compare_type;
  typedef LinkAccess link_access_type;
  typedef Alloc allocator_type;

  explicit IntrusiveHeap(
      const pointer_compare_type& comp = pointer_compare_type(),
      const link_access_type& link_access = link_access_type(),
      const allocator_type& alloc = allocator_type())
      : rep_(comp, link_access, alloc) {}

  // Number of elements currently in the heap.
  size_type size() const { return heap().size(); }

  bool empty() const { return heap().empty(); }

  // Return the top element, but don't remove it. The heap must be non-empty
  // (only checked in debug builds).
  pointer top() const {
    DCHECK(!empty());
    return heap()[0];
  }

  // Remove the top() pointer from the heap and return it.
  pointer Pop() {
    pointer t = top();
    Remove(t);
    return t;
  }

  // Insert 't' into the heap. No check is made that 't' is not already a
  // member.
  void Push(pointer t) {
    // Place 't' at the end, then sift it up to its ordered position.
    SetPositionOf(t, heap().size());
    heap().push_back(t);
    FixHeapUp(t);
  }

  // Adjust the heap to accommodate changes in '*t'. Call this after mutating
  // any field of '*t' that PtrCompare looks at.
  void Adjust(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    // If 't' now orders before its parent it can only move up; otherwise it
    // can only move down (or stay).
    if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) {
      FixHeapUp(t);
    } else {
      FixHeapDown(t);
    }
  }

  // Remove the specified pointer from the heap. Its link is reset to
  // kNotMember.
  void Remove(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    SetPositionOf(t, IntrusiveHeapLink::kNotMember);
    if (h == heap().size() - 1) {
      // Fast path for removing from back of heap.
      heap().pop_back();
      return;
    }
    // Move the element from the back of the heap to overwrite 't'.
    // `elem` refers to slot h, which stays valid across pop_back() because
    // h < size() - 1.
    pointer& elem = heap()[h];
    elem = heap().back();
    SetPositionOf(elem, h);  // Element has moved, so update its link.
    heap().pop_back();
    Adjust(elem);  // Restore the heap invariant.
  }

  // Drops all elements. Their links are NOT reset, but Contains() still
  // returns false for them because it also checks the slot contents.
  void Clear() { heap().clear(); }

  // True iff 't' is currently stored in this heap: its link must point at an
  // in-range slot that actually holds 't'.
  bool Contains(const_pointer t) const {
    size_type h = GetPositionOf(t);
    return (h != IntrusiveHeapLink::kNotMember) && (h < size()) &&
           heap()[h] == t;
  }

  void reserve(size_type n) { heap().reserve(n); }

  size_type capacity() const { return heap().capacity(); }

  allocator_type get_allocator() const { return rep_.heap_.get_allocator(); }

 private:
  typedef std::vector<pointer, allocator_type> heap_type;

  // Empty base class optimization for pointer_compare and link_access.
  // The heap_ data member retains a copy of the allocator, so it is not
  // stored explicitly.
  struct Rep : pointer_compare_type, link_access_type {
    explicit Rep(const pointer_compare_type& cmp,
                 const link_access_type& link_access,
                 const allocator_type& alloc)
        : pointer_compare_type(cmp),
          link_access_type(link_access),
          heap_(alloc) {}
    heap_type heap_;  // NOLINT
  };

  const pointer_compare_type& compare() const { return rep_; }

  const link_access_type& link_access() const { return rep_; }

  const heap_type& heap() const { return rep_.heap_; }
  heap_type& heap() { return rep_.heap_; }

  // Reads the heap index recorded in '*t'.
  size_type GetPositionOf(const_pointer t) const {
    return link_access().Get(t).get();
  }

  // Records heap index 'pos' in '*t'.
  void SetPositionOf(pointer t, size_type pos) const {
    return link_access().Set(t, IntrusiveHeapLink(pos));
  }

  // Sifts 't' toward the root. Parents that do not order before 't' are
  // shifted down into the current hole, and 't' is written once at the end.
  void FixHeapUp(pointer t) {
    size_type h = GetPositionOf(t);
    while (h != 0) {
      size_type parent = (h - 1) >> 1;
      if (compare()(heap()[parent], t)) {
        break;
      }
      heap()[h] = heap()[parent];
      SetPositionOf(heap()[h], h);
      h = parent;
    }
    heap()[h] = t;
    SetPositionOf(t, h);
  }

  // Sifts 't' toward the leaves. At each level the child that orders first
  // is moved up into the hole until 't' orders before that child.
  void FixHeapDown(pointer t) {
    size_type h = GetPositionOf(t);
    for (;;) {
      size_type kid = (h << 1) + 1;
      if (kid >= heap().size()) {
        break;
      }
      // Pick the smaller (earlier-ordered) of the two children.
      if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) {
        ++kid;
      }
      if (compare()(t, heap()[kid])) {
        break;
      }
      heap()[h] = heap()[kid];
      SetPositionOf(heap()[h], h);
      h = kid;
    }

    heap()[h] = t;
    SetPositionOf(t, h);
  }

  Rep rep_;
};
285
CoreDeviceLabel(int core)286 string CoreDeviceLabel(int core) {
287 return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
288 }
289
290 // Creates a unique node name with a particular prefix.
UniqueNodeName(const StringPiece prefix,Graph * graph)291 string UniqueNodeName(const StringPiece prefix, Graph* graph) {
292 return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId()));
293 }
294
SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,const string & target_device_type,Node * node)295 Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,
296 const string& target_device_type,
297 Node* node) {
298 TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE);
299 TF_RET_CHECK(device.has_id);
300 TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName));
301
302 // Store the device instance as an attr on the Node.
303 TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id));
304
305 // Place the execute Op on the TPU_SYSTEM device so it can access the cache of
306 // compiled protos in the resource manager.
307 device.type = target_device_type;
308 device.id = 0;
309
310 node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device));
311 return Status::OK();
312 }
313
314 // Iterate over the nodes in the original graph and find all the TPUReplicate
315 // nodes, and all the nodes that are part of outside_compilation clusters.
FindTaggedNodes(Graph * graph,std::vector<Node * > * replicate_nodes,std::map<string,DistributedTPURewritePass::OutsideCompilationNodeMap> * outside_compilation_nodes,std::map<string,std::vector<Node * >> * head_tail_outside_compilation_nodes)316 Status FindTaggedNodes(
317 Graph* graph, std::vector<Node*>* replicate_nodes,
318 std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>*
319 outside_compilation_nodes,
320 std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) {
321 for (Node* node : graph->op_nodes()) {
322 if (node->type_string() == "_TPUReplicate") {
323 replicate_nodes->push_back(node);
324 const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
325 if (cluster_attr == nullptr) {
326 return errors::Internal("TPUReplicate node ", node->name(), " has no ",
327 kTPUReplicateAttr, " attr.");
328 } else {
329 const string& cluster = cluster_attr->s();
330 if (cluster.empty()) {
331 return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
332 node->name(), " has no string value.");
333 }
334 if (outside_compilation_nodes->find(cluster) !=
335 outside_compilation_nodes->end()) {
336 return errors::Internal(
337 "TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr,
338 " attr value '", cluster,
339 "' which is a duplicate of another TPUReplicate node in the "
340 "graph.");
341 }
342 (*outside_compilation_nodes)[cluster] =
343 DistributedTPURewritePass::OutsideCompilationNodeMap();
344 (*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>();
345 }
346 }
347 }
348 for (Node* node : graph->op_nodes()) {
349 if (node->type_string() != "_TPUReplicate") {
350 const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
351 const AttrValue* outside_compilation_attr =
352 node->attrs().Find(kOutsideCompilationAttr);
353 if (cluster_attr == nullptr) {
354 if (outside_compilation_attr != nullptr) {
355 return errors::Internal("Node ", node->name(), " has ",
356 kOutsideCompilationAttr, " attr but no ",
357 kTPUReplicateAttr, " attr.");
358 }
359 } else {
360 const string& cluster = cluster_attr->s();
361 if (cluster.empty()) {
362 return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
363 node->name(), " has no string value.");
364 }
365 const auto iter = outside_compilation_nodes->find(cluster);
366 if (iter == outside_compilation_nodes->end()) {
367 return errors::Internal(
368 "Attr ", kTPUReplicateAttr, " on node ", node->name(),
369 " does not correspond to a TPUReplicate node.");
370 }
371 if (outside_compilation_attr == nullptr) {
372 return errors::Internal("Node ", node->name(), " has ",
373 kTPUReplicateAttr, " attr but no ",
374 kOutsideCompilationAttr, " attr.");
375 }
376 const string& oc_cluster = outside_compilation_attr->s();
377 if (oc_cluster.empty()) {
378 return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ",
379 node->name(), " has no string value.");
380 }
381
382 // Outside compilation cluster at head and tail of TPU computation has
383 // already been moved to host and is already replicated. As so, do not
384 // replicate outside compilation nodes with replica id attribute.
385 int replica_id;
386 if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) {
387 const AttrValue* head_attr =
388 node->attrs().Find("_xla_only_arg_or_oc_input");
389 const AttrValue* tail_attr =
390 node->attrs().Find("_xla_only_ret_or_oc_output");
391 if (((head_attr != nullptr) && (head_attr->b())) ||
392 ((tail_attr != nullptr) && (tail_attr->b()))) {
393 // This is safe as this has the same keys as
394 // outside_compilation_nodes which we already know has this key.
395 (*head_tail_outside_compilation_nodes)[cluster].push_back(node);
396 }
397 continue;
398 }
399 iter->second[oc_cluster].push_back(node);
400 }
401 }
402 }
403 return Status::OK();
404 }
405
406 // Helper class to spread TPU computation arguments and return values
407 // across cores.
408 // If all shapes are fully defined, balance by their size.
409 // If some of them are not fully defined, the undefined shapes size will
410 // be estimated with the average size of the fully defined ones.
411 // If none are defined, fall back to round-robin.
412 class TensorDevicePlacer {
413 public:
414 // Creates a TensorDevicePlacer object to distribute arguments or
415 // return values to a set of num_devices devices, where the types and
416 // the inferred shapes of the inputs (arguments or return values) are
417 // passed in types and shapes.
TensorDevicePlacer(int64 num_devices,const DataTypeVector & types,const std::vector<InferredShape> & shapes)418 TensorDevicePlacer(int64 num_devices, const DataTypeVector& types,
419 const std::vector<InferredShape>& shapes)
420 : index_nodes_(num_devices), sizes_(types.size()) {
421 int64 total_size = 0;
422 int64 num_defined = 0;
423 for (int64 i = 0; i < types.size(); ++i) {
424 sizes_[i] = GetInferredShapeSize(shapes[i], types[i]);
425 if (sizes_[i] >= 0) {
426 total_size += sizes_[i];
427 ++num_defined;
428 }
429 }
430 // If a shape is undefined, select a size for it which is the average
431 // of the defined shapes. If no shapes are defined, assign 1 so that we
432 // get round-robin behavior.
433 int64 undefined_shape_size =
434 (num_defined > 0) ? total_size / num_defined : 1;
435 for (int64 i = 0; i < sizes_.size(); ++i) {
436 if (sizes_[i] < 0) {
437 sizes_[i] = undefined_shape_size;
438 }
439 }
440
441 for (int64 i = 0; i < num_devices; ++i) {
442 heap_.Push(&index_nodes_[i]);
443 }
444 }
445
446 // Reports that the argument/return-value at index has been assigned
447 // by the user to a given device.
ReportDeviceAssigned(int64 device,int64 index)448 void ReportDeviceAssigned(int64 device, int64 index) {
449 DeviceNode* node = &index_nodes_.at(device);
450 node->size += sizes_.at(index);
451 heap_.Adjust(node);
452 }
453
454 // Retrieves the device at which the argument/return-value at index
455 // should be assigned to.
RetrieveAssignment(int64 index)456 int64 RetrieveAssignment(int64 index) {
457 DeviceNode* node = heap_.top();
458 int64 device = node - index_nodes_.data();
459 node->size += sizes_.at(index);
460 heap_.Adjust(node);
461 return device;
462 }
463
464 private:
465 struct DeviceNode {
466 struct Compare {
467 // Compare functor to implement a min heap using the ::gtl::IntrusiveHeap
468 // infrastructure.
operator ()tensorflow::__anonf0ad67bd0111::TensorDevicePlacer::DeviceNode::Compare469 bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const {
470 return lhs->size < rhs->size;
471 }
472 };
473
474 IntrusiveHeapLink heap;
475 int64 size = 0;
476 };
477
GetInferredShapeSize(const InferredShape & ishape,DataType dtype)478 static int64 GetInferredShapeSize(const InferredShape& ishape,
479 DataType dtype) {
480 return ishape.shape.IsFullyDefined()
481 ? ishape.shape.num_elements() * DataTypeSize(dtype)
482 : -1;
483 }
484
485 std::vector<DeviceNode> index_nodes_;
486 IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_;
487 std::vector<int64> sizes_;
488 };
489
ValidateCoreNumber(int64 core,int64 num_cores_per_replica)490 Status ValidateCoreNumber(int64 core, int64 num_cores_per_replica) {
491 if (core < 0 || core >= num_cores_per_replica) {
492 return tensorflow::errors::InvalidArgument("Invalid core ID: ", core,
493 ". The valid core IDs are [0..",
494 num_cores_per_replica, ")");
495 }
496 return Status::OK();
497 }
498
FindHostComputeKeyPlaceholderNodes(const Graph * graph,const std::vector<Node * > & replicate_nodes,std::unordered_map<string,Node * > * host_compute_key_placeholder_map)499 Status FindHostComputeKeyPlaceholderNodes(
500 const Graph* graph, const std::vector<Node*>& replicate_nodes,
501 std::unordered_map<string, Node*>* host_compute_key_placeholder_map) {
502 host_compute_key_placeholder_map->clear();
503 for (const auto node : replicate_nodes) {
504 (*host_compute_key_placeholder_map)[node->name()] = nullptr;
505 }
506
507 for (Node* node : graph->op_nodes()) {
508 if (node->type_string() == "Placeholder" &&
509 str_util::EndsWith(node->name(), "_key_placeholder")) {
510 const AttrValue* call_node_attr =
511 node->attrs().Find("_host_compute_call_node");
512 if (call_node_attr != nullptr) {
513 auto iter = host_compute_key_placeholder_map->find(call_node_attr->s());
514 if (iter == host_compute_key_placeholder_map->end()) {
515 return errors::InvalidArgument(
516 "Node ", node->name(), " has _host_compute_call_node attribute '",
517 call_node_attr->s(), "' that doesn't correspond to a call node");
518 }
519 if (iter->second != nullptr) {
520 return errors::InvalidArgument(
521 "Key placeholder node ", iter->second->name(), " for call node ",
522 call_node_attr->s(), " previously found as ",
523 iter->second->name());
524 }
525 iter->second = node;
526 }
527 }
528 }
529
530 return Status::OK();
531 }
532
ReplaceCompilationResultNodeWithIdentity(Graph * graph,Node ** node)533 Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) {
534 Node* old_node = *node;
535 // We want to replace the node with an identity node with the same name.
536 const string& node_name = old_node->name();
537
538 // Create identity node.
539 TF_ASSIGN_OR_RETURN(
540 Node * id_node,
541 BuildIdentityNode(graph, node_name, DT_STRING,
542 /*input=*/nullptr, /*requested_device=*/""));
543
544 // No incoming edges are copied as a new one will be added from compile node
545 // to id_node.
546
547 // Copy outgoing edges to the id node.
548 std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
549 old_node->out_edges().end());
550 for (const Edge* edge : out_edges) {
551 Node* dst = edge->dst();
552 int src_output = edge->src_output();
553 int dst_input = edge->dst_input();
554
555 if (src_output == Graph::kControlSlot) {
556 graph->AddControlEdge(id_node, dst);
557 } else {
558 graph->AddEdge(id_node, src_output, dst, dst_input);
559 }
560 graph->RemoveEdge(edge);
561 }
562 graph->RemoveNode(old_node);
563
564 *node = id_node;
565 return Status::OK();
566 }
567
FillPaddingMap(const Node & replicate_node,protobuf::RepeatedPtrField<tpu::PaddingMap> * padding_maps)568 Status FillPaddingMap(
569 const Node& replicate_node,
570 protobuf::RepeatedPtrField<tpu::PaddingMap>* padding_maps) {
571 std::vector<string> padding_map_strs;
572 TF_RETURN_IF_ERROR(
573 GetNodeAttr(replicate_node.attrs(), "padding_map", &padding_map_strs));
574 padding_maps->Reserve(padding_map_strs.size());
575 for (const string& padding_map_str : padding_map_strs) {
576 tpu::PaddingMap* padding_map = padding_maps->Add();
577 if (!padding_map->ParseFromString(padding_map_str)) {
578 return errors::InvalidArgument(
579 "Malformed padding_map serialized string: ", padding_map_str);
580 }
581 }
582 return Status::OK();
583 }
584
GetStepMarkerLocation(const Node & replicate_node,xla::DebugOptions::StepMarkerLocation * location)585 Status GetStepMarkerLocation(const Node& replicate_node,
586 xla::DebugOptions::StepMarkerLocation* location) {
587 string step_marker_location_attr;
588 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location",
589 &step_marker_location_attr));
590 if (step_marker_location_attr.empty()) {
591 *location = xla::DebugOptions::STEP_MARK_AT_ENTRY;
592 } else {
593 if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr,
594 location)) {
595 return errors::InvalidArgument("Malformed step_marker_location: ",
596 step_marker_location_attr);
597 }
598 }
599 return Status::OK();
600 }
601
602 // Extracts a map of dimension and number of splits for tiled input from xla
603 // sharding attribute.
GetDimensionIndicesAndNumSplitsFromSharding(const xla::OpSharding & sharding,std::map<int,int> * split_dimension_map)604 Status GetDimensionIndicesAndNumSplitsFromSharding(
605 const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
606 int64 tensor_tile_rank = sharding.tile_assignment_dimensions_size();
607 if (sharding.replicate_on_last_tile_dim()) {
608 tensor_tile_rank--;
609 }
610 for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
611 if (sharding.tile_assignment_dimensions(dim_index) > 1) {
612 split_dimension_map->emplace(
613 dim_index, sharding.tile_assignment_dimensions(dim_index));
614 }
615 }
616
617 if (split_dimension_map->empty()) {
618 return errors::InvalidArgument("Arg has unnecessary tiled sharding: ",
619 sharding.DebugString());
620 }
621 return Status::OK();
622 }
623
624 // Updates contents of the function with `function_name` in function library
625 // definition `flib_def` to `new_graph`. This is required when graph
626 // transformation happens inside a function call body.
UpdateFunctionLibDefinition(const Graph & new_graph,const std::string & function_name,FunctionLibraryDefinition * flib_def)627 Status UpdateFunctionLibDefinition(const Graph& new_graph,
628 const std::string& function_name,
629 FunctionLibraryDefinition* flib_def) {
630 FunctionDef graph_fdef;
631 TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef));
632 TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef));
633 return Status::OK();
634 }
635
636 struct NodeOut {
637 Node* node;
638 int index;
639 };
640
// Key identifying the sharded input for one (replica, original argument)
// pair. Ordered lexicographically by (replica_id, argument_index) so it can
// key a std::map.
struct ShardedInputIndex {
  int replica_id;
  int argument_index;

  bool operator<(const ShardedInputIndex& rhs) const {
    if (replica_id != rhs.replica_id) {
      return replica_id < rhs.replica_id;
    }
    return argument_index < rhs.argument_index;
  }
};
650
651 struct ShardedInputInfo {
652 // Split node that would be connected to tiled input Node.
653 Node* split_node;
654 // List of splits nodes and output index of the split node from which sharded
655 // input will be connected to the TPUExecute node. The inputs are ordered by
656 // logical core ids.
657 std::vector<NodeOut> sharded_inputs;
658 };
659
660 // Adds pad node after split node to graph for uneven sharding tiled inputs.
661 // |graph| owns the returned Node* instance.
CreatePadNode(const int padding,const int num_dims,const int split_dim,DataType dtype,Node * control_predecessor,Node * split_node,const int split_index,Graph * graph)662 xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims,
663 const int split_dim, DataType dtype,
664 Node* control_predecessor, Node* split_node,
665 const int split_index, Graph* graph) {
666 // Add paddings node.
667 Status s;
668 NodeDef paddings_def;
669 paddings_def.set_name(
670 graph->NewName(absl::StrCat(split_node->name(), "/paddings")));
671 paddings_def.set_op("Const");
672 AddNodeAttr("dtype", DT_INT32, &paddings_def);
673 paddings_def.set_device(split_node->assigned_device_name());
674 TensorProto sizes_tensor_proto;
675 sizes_tensor_proto.set_dtype(DT_INT32);
676 for (int i = 0; i < num_dims; ++i) {
677 sizes_tensor_proto.add_int_val(0);
678 if (i == split_dim) {
679 sizes_tensor_proto.add_int_val(padding);
680 } else {
681 sizes_tensor_proto.add_int_val(0);
682 }
683 }
684 TensorShape sizes_shape({num_dims, 2});
685 sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
686 AddNodeAttr("value", sizes_tensor_proto, &paddings_def);
687 Node* paddings_node = graph->AddNode(paddings_def, &s);
688 TF_RETURN_IF_ERROR(s);
689
690 // Add Pad node.
691 NodeDef pad_def;
692 pad_def.set_name(graph->NewName(
693 absl::StrCat(split_node->name(), "/pad_shard_", split_index)));
694 pad_def.set_op("Pad");
695 pad_def.set_device(split_node->assigned_device_name());
696 AddNodeAttr("T", dtype, &pad_def);
697 AddNodeAttr("Tpaddings", DT_INT32, &pad_def);
698 pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index));
699 pad_def.add_input(absl::StrCat(paddings_node->name(), ":0"));
700 Node* pad_node = graph->AddNode(pad_def, &s);
701 pad_node->set_assigned_device_name(split_node->assigned_device_name());
702 TF_RETURN_IF_ERROR(s);
703 // Add edges for pad node.
704 graph->AddEdge(split_node, split_index, pad_node, 0);
705 graph->AddEdge(paddings_node, 0, pad_node, 1);
706 graph->AddControlEdge(control_predecessor, pad_node);
707 return pad_node;
708 }
709
// Adds a Split node and its split-dimension Const node to `graph`, splitting
// output `orig_src_output` of `orig_src` (rank `num_dims`, type `dtype`) into
// `num_splits` pieces along dimension `dim`. Used for sharding tiled inputs.
//
// If `padding` > 0, the input is first passed through a Pad node created by
// CreatePadNode, and that node's output 0 is split instead. New node names
// are derived from `name_prefix` via Graph::NewName. The Const and Split nodes
// request `orig_src`'s assigned device, and the Split node is also assigned
// to it and colocated with `orig_src`. The order of node creation is
// significant, since it determines the generated names.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
                                     const int num_dims, const int64 padding,
                                     const int orig_src_output, DataType dtype,
                                     absl::string_view name_prefix,
                                     Node* control_predecessor, Node* orig_src,
                                     Graph* graph) {
  const std::string input_assigned_device = orig_src->assigned_device_name();
  // The tensor actually fed to Split: the original output, or the padded one.
  Node* to_split_node = orig_src;
  int to_split_index = orig_src_output;
  if (padding > 0) {
    TF_ASSIGN_OR_RETURN(
        Node * pad_node,
        CreatePadNode(padding, num_dims, dim, dtype, control_predecessor,
                      orig_src, orig_src_output, graph));
    to_split_node = pad_node;
    to_split_index = 0;
  }

  // Add a split dimension node: a scalar int32 Const holding `dim`.
  NodeDef split_dim_def;
  split_dim_def.set_name(
      graph->NewName(absl::StrCat(name_prefix, "/split_dim")));
  split_dim_def.set_op("Const");
  split_dim_def.set_device(input_assigned_device);
  AddNodeAttr("dtype", DT_INT32, &split_dim_def);
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  tensor_proto.add_int_val(dim);
  TensorShape shape({});
  shape.AsProto(tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", tensor_proto, &split_dim_def);
  Status s;
  Node* split_dim_node = graph->AddNode(split_dim_def, &s);
  TF_RETURN_IF_ERROR(s);
  // Add a split node with inputs (split_dim, tensor to split).
  NodeDef split_def;
  split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split")));
  split_def.set_op("Split");
  split_def.set_device(input_assigned_device);
  AddNodeAttr("num_split", num_splits, &split_def);
  AddNodeAttr("T", dtype, &split_def);
  split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
  split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
  Node* split_node = graph->AddNode(split_def, &s);
  TF_RETURN_IF_ERROR(s);

  split_node->set_assigned_device_name(input_assigned_device);

  // Colocate the newly created split op with the source node of the input to
  // the TPU computation.
  split_node->AddAttr(kColocationAttrName,
                      std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                       orig_src->name())});

  graph->AddEdge(split_dim_node, 0, split_node, 0);
  graph->AddEdge(to_split_node, to_split_index, split_node, 1);

  // Add a control dependency from `control_predecessor` to newly created
  // constant node. This ensures that newly added split/split dim
  // nodes are placed inside correct while loop frames when TPUExecute
  // node is inside a host training loop.
  graph->AddControlEdge(control_predecessor, split_dim_node);
  return split_node;
}
776
GetPadding(const int split_dim,const int num_splits,const PartialTensorShape & partial_tensor_shape)777 int64 GetPadding(const int split_dim, const int num_splits,
778 const PartialTensorShape& partial_tensor_shape) {
779 // If dim dimension is not defined, no uneven sharding support.
780 if (partial_tensor_shape.dim_size(split_dim) <= 0) {
781 return 0;
782 }
783 int64 per_split_size = tensorflow::MathUtil::CeilOfRatio<int64>(
784 partial_tensor_shape.dim_size(split_dim), num_splits);
785 int64 total_padding =
786 per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim);
787 return total_padding;
788 }
789
790 // Creates a set of splits nodes that shards tiled input node in graph.
CreateOrGetSplitNodesForInputSharding(const xla::OpSharding & sharding,int orig_arg_num,DataType dtype,const PartialTensorShape & partial_tensor_shape,int replica_id,int orig_src_output,Node * orig_src,Node * control_predecessor,Graph * graph,std::map<ShardedInputIndex,ShardedInputInfo> * arg_index_to_sharded_input_map)791 xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
792 const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
793 const PartialTensorShape& partial_tensor_shape, int replica_id,
794 int orig_src_output, Node* orig_src, Node* control_predecessor,
795 Graph* graph,
796 std::map<ShardedInputIndex, ShardedInputInfo>*
797 arg_index_to_sharded_input_map) {
798 ShardedInputIndex input_index{replica_id, orig_arg_num};
799 auto iter = arg_index_to_sharded_input_map->find(input_index);
800 if (iter != arg_index_to_sharded_input_map->end()) {
801 return iter->second;
802 }
803 // Maps input dimension and number of splits with which the
804 // dimension sharded.
805 std::map<int, int> split_dimension_map;
806 TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
807 sharding, &split_dimension_map));
808 TF_RET_CHECK(!split_dimension_map.empty())
809 << "Unnecessary sharding attribute found.";
810
811 // For v1 while loop, nodes inside the loop body must either
812 // 1) Have data edges from while loop input node.
813 // or
814 // 2) Have direct control dependency from while loop input control
815 // node.
816 //
817 // As so, if we are adding Split node inside, while loop body,
818 // we must manually add a control dependency to a node inside
819 // a while loop (i.e. `control_predecessor`) to constant nodes
820 // without data in-edges to make sure that added split nodes
821 // have correct frame name. Else, placer will complain when
822 // `BuildControlFlow()` is invoked.
823
824 auto sharding_it = split_dimension_map.begin();
825 std::queue<Node*> split_nodes_for_dimension;
826 absl::flat_hash_map<Node*, int> node_to_split_dim;
827 int split_dimension = sharding_it->first;
828 int num_split = sharding_it->second;
829
830 // Creates a tree of split nodes for sharding tiled inputs. Splits nodes
831 // are created such that input data is sharded in row major order.
832 // Split nodes at ith depth from the original input node represent nodes
833 // that split the input data at ith dimension.
834 TF_ASSIGN_OR_RETURN(
835 Node * root_split_node,
836 CreateSplitNode(
837 num_split, split_dimension, partial_tensor_shape.dims(),
838 GetPadding(split_dimension, num_split, partial_tensor_shape),
839 orig_src_output, dtype,
840 absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
841 split_dimension),
842 control_predecessor, orig_src, graph));
843 sharding_it++;
844
845 split_nodes_for_dimension.emplace(root_split_node);
846 node_to_split_dim[root_split_node] = split_dimension;
847
848 while (sharding_it != split_dimension_map.end()) {
849 split_dimension = sharding_it->first;
850 num_split = sharding_it->second;
851 int num_split_nodes_in_dimension = split_nodes_for_dimension.size();
852 for (int i = 0; i < num_split_nodes_in_dimension; ++i) {
853 Node* input_split_node = split_nodes_for_dimension.front();
854 split_nodes_for_dimension.pop();
855 for (int src_output_index = 0;
856 src_output_index < input_split_node->num_outputs();
857 ++src_output_index) {
858 TF_ASSIGN_OR_RETURN(
859 Node * split_node,
860 CreateSplitNode(
861 num_split, split_dimension, partial_tensor_shape.dims(),
862 GetPadding(split_dimension, num_split, partial_tensor_shape),
863 src_output_index, dtype,
864 absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
865 split_dimension),
866 control_predecessor, input_split_node, graph));
867 split_nodes_for_dimension.emplace(split_node);
868 node_to_split_dim[split_node] = split_dimension;
869 }
870 }
871 sharding_it++;
872 }
873
874 // `split_nodes_for_dimension` now includes final split nodes
875 // from which sharded data will be fed into TPUExcute nodes -- sorted by
876 // row major order.
877 std::vector<NodeOut> sharded_inputs_list(
878 sharding.tile_assignment_devices_size());
879 int64 next_core_tile_index = 0;
880 while (!split_nodes_for_dimension.empty()) {
881 Node* split_node = split_nodes_for_dimension.front();
882 split_nodes_for_dimension.pop();
883 int num_splits;
884 TF_RETURN_IF_ERROR(
885 GetNodeAttr(split_node->def(), "num_split", &num_splits));
886 for (int out_index = 0; out_index < num_splits; ++out_index) {
887 int64 repeat_count = sharding.replicate_on_last_tile_dim()
888 ? *sharding.tile_assignment_dimensions().rbegin()
889 : 1;
890 for (int64 i = 0; i < repeat_count; ++i) {
891 int64 next_core =
892 sharding.tile_assignment_devices(next_core_tile_index++);
893 sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
894 }
895 }
896 }
897
898 ShardedInputInfo sharded_input_info{root_split_node,
899 std::move(sharded_inputs_list)};
900 (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
901 return sharded_input_info;
902 }
903
904 // Creates a concat node to be used for aggregating sharded retvals across
905 // logical cores.
CreateConcatNode(int dim,int num_splits,DataType dtype,absl::string_view name_prefix,const std::vector<NodeOut> & inputs,Graph * graph,absl::string_view device)906 xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
907 absl::string_view name_prefix,
908 const std::vector<NodeOut>& inputs,
909 Graph* graph, absl::string_view device) {
910 // Add a Concat dim node.
911 NodeDef concat_dim_def;
912 concat_dim_def.set_name(
913 graph->NewName(absl::StrCat(name_prefix, "/concat_dim")));
914 concat_dim_def.set_op("Const");
915 AddNodeAttr("dtype", DT_INT32, &concat_dim_def);
916 concat_dim_def.set_device(std::string(device));
917 TensorProto tensor_proto;
918 tensor_proto.set_dtype(DT_INT32);
919 tensor_proto.add_int_val(dim);
920 TensorShape shape({});
921 shape.AsProto(tensor_proto.mutable_tensor_shape());
922 AddNodeAttr("value", tensor_proto, &concat_dim_def);
923 Status s;
924 Node* concat_dim_node = graph->AddNode(concat_dim_def, &s);
925 TF_RETURN_IF_ERROR(s);
926
927 // Add a Concat node.
928 NodeDef concat_def;
929 concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat")));
930 concat_def.set_op("Concat");
931 AddNodeAttr("N", num_splits, &concat_def);
932 AddNodeAttr("T", dtype, &concat_def);
933 concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0"));
934 concat_def.set_device(std::string(device));
935 for (const auto& i : inputs) {
936 concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index));
937 }
938 Node* concat_node = graph->AddNode(concat_def, &s);
939 TF_RETURN_IF_ERROR(s);
940
941 graph->AddEdge(concat_dim_node, 0, concat_node, 0);
942
943 // 0th input to concat node is a concat dim node. So we start from 1st input
944 // and add all input edges.
945 int dst_input = 1;
946 for (const auto& i : inputs) {
947 graph->AddEdge(i.node, i.index, concat_node, dst_input);
948 ++dst_input;
949 }
950 return concat_node;
951 }
952
953 // Adds slice node after concat node to graph for uneven sharding tiled inputs.
CreateSliceNode(DataType dtype,const PartialTensorShape & shape,Node * concat_node,const int concat_out_index,Graph * graph,absl::string_view device)954 xla::StatusOr<Node*> CreateSliceNode(DataType dtype,
955 const PartialTensorShape& shape,
956 Node* concat_node,
957 const int concat_out_index, Graph* graph,
958 absl::string_view device) {
959 Status s;
960 // Add begin node for concat.
961 NodeDef begin_def;
962 begin_def.set_name(
963 graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin")));
964 begin_def.set_op("Const");
965 AddNodeAttr("dtype", DT_INT32, &begin_def);
966 begin_def.set_device(std::string(device));
967 TensorProto begin_tensor_proto;
968 begin_tensor_proto.set_dtype(DT_INT32);
969 for (int i = 0; i < shape.dims(); ++i) {
970 begin_tensor_proto.add_int_val(0);
971 }
972 TensorShape begin_shape({shape.dims()});
973 begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape());
974 AddNodeAttr("value", begin_tensor_proto, &begin_def);
975 Node* begin_node = graph->AddNode(begin_def, &s);
976 TF_RETURN_IF_ERROR(s);
977
978 // Add size node.
979 NodeDef size_def;
980 size_def.set_name(
981 graph->NewName(absl::StrCat(concat_node->name(), "/slice_size")));
982 size_def.set_op("Const");
983 AddNodeAttr("dtype", DT_INT32, &size_def);
984 size_def.set_device(std::string(device));
985 TensorProto sizes_tensor_proto;
986 sizes_tensor_proto.set_dtype(DT_INT32);
987 for (int i = 0; i < shape.dims(); ++i) {
988 sizes_tensor_proto.add_int_val(shape.dim_size(i));
989 }
990 TensorShape sizes_shape({shape.dims()});
991 sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
992 AddNodeAttr("value", sizes_tensor_proto, &size_def);
993 Node* size_node = graph->AddNode(size_def, &s);
994 TF_RETURN_IF_ERROR(s);
995
996 // Add Slice node.
997 NodeDef slice_def;
998 slice_def.set_name(
999 graph->NewName(absl::StrCat(concat_node->name(), "/slice")));
1000 slice_def.set_op("Slice");
1001 slice_def.set_device(std::string(device));
1002 AddNodeAttr("T", dtype, &slice_def);
1003 AddNodeAttr("Index", DT_INT32, &slice_def);
1004 slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index));
1005 slice_def.add_input(absl::StrCat(begin_node->name(), ":0"));
1006 slice_def.add_input(absl::StrCat(size_node->name(), ":0"));
1007 Node* slice_node = graph->AddNode(slice_def, &s);
1008 TF_RETURN_IF_ERROR(s);
1009 // Add edges for slice node.
1010 graph->AddEdge(concat_node, concat_out_index, slice_node, 0);
1011 graph->AddEdge(begin_node, 0, slice_node, 1);
1012 graph->AddEdge(size_node, 0, slice_node, 2);
1013 return slice_node;
1014 }
1015
1016 // Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute
1017 // nodes into a single output. Sharded outputs are concatenated along row major
1018 // order. That is, tiled output along 0th dimension will be concatenated last.
CreateConcatNodesForRetval(const xla::OpSharding & sharding,DataType dtype,const PartialTensorShape & inferred_shape,int replica_id,const std::vector<NodeOut> & orig_inputs,Graph * graph,absl::string_view device)1019 xla::StatusOr<Node*> CreateConcatNodesForRetval(
1020 const xla::OpSharding& sharding, DataType dtype,
1021 const PartialTensorShape& inferred_shape, int replica_id,
1022 const std::vector<NodeOut>& orig_inputs, Graph* graph,
1023 absl::string_view device) {
1024 std::map<int, int> split_dimension_map;
1025 TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
1026 sharding, &split_dimension_map));
1027 std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs;
1028 bool has_paddings = false;
1029
1030 for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend();
1031 it++) {
1032 auto dim = it->first;
1033 auto num_splits = it->second;
1034
1035 int num_concat_nodes = inputs_to_sharded_retval.size() / num_splits;
1036 int input_index_to_concat_node = 0;
1037
1038 std::vector<NodeOut> new_concat_nodes;
1039 for (int i = 0; i < num_concat_nodes; ++i) {
1040 auto concat_input_it =
1041 inputs_to_sharded_retval.begin() + input_index_to_concat_node;
1042 std::vector<NodeOut> inputs(concat_input_it,
1043 concat_input_it + num_splits);
1044 input_index_to_concat_node += num_splits;
1045
1046 TF_ASSIGN_OR_RETURN(
1047 Node * concat_node,
1048 CreateConcatNode(
1049 dim, num_splits, dtype,
1050 absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
1051 inputs, graph, device));
1052 int64 paddings = GetPadding(dim, num_splits, inferred_shape);
1053 has_paddings |= paddings > 0;
1054 new_concat_nodes.emplace_back(NodeOut{concat_node, 0});
1055 }
1056 inputs_to_sharded_retval = new_concat_nodes;
1057 }
1058
1059 TF_RET_CHECK(inputs_to_sharded_retval.size() == 1);
1060 if (has_paddings) {
1061 TF_ASSIGN_OR_RETURN(Node * slice_node,
1062 CreateSliceNode(dtype, inferred_shape,
1063 inputs_to_sharded_retval.at(0).node,
1064 /*concat_out_index*/ 0, graph, device));
1065 return slice_node;
1066 }
1067 return inputs_to_sharded_retval.at(0).node;
1068 }
1069
1070 // Set the padding ops the same devices as the original inputs. If the original
1071 // inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
1072 // mode will be triggered, so we don't need to copy the data back to the host
1073 // to do the padding.
SetPaddingNodesDevices(Graph * graph)1074 Status SetPaddingNodesDevices(Graph* graph) {
1075 for (Node* n : graph->op_nodes()) {
1076 bool tpu_padding_attr;
1077 if (n->type_string() == "Pad" &&
1078 GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr)
1079 .ok()) {
1080 Node* unpadded_input;
1081 TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input));
1082
1083 const string& requested_device = unpadded_input->requested_device();
1084 const string& assigned_device = unpadded_input->assigned_device_name();
1085 if (!requested_device.empty() || !assigned_device.empty()) {
1086 // The output nodes of the original unpadded inputs include the padded
1087 // inputs and real shapes of inputs, we assign those to the same device
1088 // as the original inputs.
1089 for (Node* out : unpadded_input->out_nodes()) {
1090 if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr,
1091 &tpu_padding_attr)
1092 .ok()) {
1093 out->set_requested_device(requested_device);
1094 out->set_assigned_device_name(assigned_device);
1095 }
1096 }
1097 // There might be a tf.shape node added before TPUCompileOp, we need to
1098 // set its device as well.
1099 for (Node* out : n->out_nodes()) {
1100 if (n->type_string() == "Shape") {
1101 out->set_requested_device(requested_device);
1102 out->set_assigned_device_name(assigned_device);
1103 }
1104 }
1105 }
1106 }
1107 }
1108 return Status::OK();
1109 }
1110
AssignedOrRequestedDevice(const Node * node)1111 const string& AssignedOrRequestedDevice(const Node* node) {
1112 if (!node->assigned_device_name().empty()) {
1113 return node->assigned_device_name();
1114 }
1115 return node->requested_device();
1116 }
1117
IsTpuDevice(const string & device_string)1118 bool IsTpuDevice(const string& device_string) {
1119 DeviceNameUtils::ParsedName device;
1120 return DeviceNameUtils::ParseFullName(device_string, &device) &&
1121 device.type == DEVICE_TPU_NODE;
1122 }
1123
1124 // Returns a set of device ops can be placed on TPU. There is no strict rule of
1125 // thumb to decide which ops should be in the list, but empirically they are
1126 // mostly dummy ops like Identity-like ops or control flow related ops. However
1127 // people can add also add other ops like Pad to allow data stay on TPU.
PlaceOnTPUOpList()1128 const absl::flat_hash_set<std::string>& PlaceOnTPUOpList() {
1129 static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
1130 {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
1131 "NextIteration", "Shape", "_Retval"});
1132 return *place_on_tpu_ops;
1133 }
1134
1135 // If an op satisfies the following conditions, it will be placed on the same
1136 // TPU device as its inputs:
1137 // (1) The op can be placed on TPU (in the PlaceOnTPUOpList)
1138 // (2) The op itself has no requested or assigned devices.
1139 // (3) All the data inputs of this op are placed on the same device on TPUs.
1140 // There are exceptions like the NextIterations input of Switch node can
1141 // be placed on CPU as it is just a boolean.
1142 //
1143 // Returns true if the node device has been changed, otherwise returns false.
PlaceOpsOnTPU(Node * node)1144 bool PlaceOpsOnTPU(Node* node) {
1145 if (!AssignedOrRequestedDevice(node).empty() ||
1146 !PlaceOnTPUOpList().contains(node->type_string())) {
1147 return false;
1148 }
1149 string src_tpu_device = "";
1150 Node* src_node;
1151 for (const Edge* e : node->in_edges()) {
1152 if (e->IsControlEdge()) {
1153 continue;
1154 }
1155 Node* src = e->src();
1156 const string& src_device = AssignedOrRequestedDevice(src);
1157
1158 // Make exceptions that we don't force the some inputs to place on TPUs.
1159 if (node->IsSwitch() && src->IsLoopCond()) {
1160 continue;
1161 }
1162
1163 if (!IsTpuDevice(src_device) ||
1164 (!src_tpu_device.empty() && src_device != src_tpu_device)) {
1165 return false;
1166 }
1167 if (src_tpu_device.empty()) {
1168 src_tpu_device = src_device;
1169 src_node = src;
1170 }
1171 }
1172 node->set_assigned_device_name(src_node->assigned_device_name());
1173 node->set_requested_device(src_node->requested_device());
1174 return true;
1175 }
1176
CreateOpMetadataFromNode(const Node & node)1177 xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
1178 xla::OpMetadata metadata;
1179 metadata.set_op_type(node.type_string());
1180 metadata.set_op_name(node.name());
1181 return metadata;
1182 }
1183
1184 // Helper struct holding node (nullable) and associated sharding.
1185 struct NodeAndSharding {
NodeAndShardingtensorflow::__anonf0ad67bd0111::NodeAndSharding1186 explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding)
1187 : node(node), sharding(sharding) {}
1188
1189 const Node* node;
1190 xla::OpSharding sharding;
1191 };
1192
1193 // Validate sharding configuration derived from XlaSharding attribute.
1194 // Infer the core id from the OpSharding, if necessary.
ParseAndValidateSharding(const NodeAndSharding & node_and_sharding,const int num_cores_per_replica,int64 * inferred_core_id,absl::optional<NodeAndSharding> * result)1195 Status ParseAndValidateSharding(const NodeAndSharding& node_and_sharding,
1196 const int num_cores_per_replica,
1197 int64* inferred_core_id,
1198 absl::optional<NodeAndSharding>* result) {
1199 if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) {
1200 int64 core_annotation =
1201 node_and_sharding.sharding.tile_assignment_devices(0);
1202 TF_RETURN_IF_ERROR(
1203 ValidateCoreNumber(core_annotation, num_cores_per_replica));
1204 if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) {
1205 *inferred_core_id = core_annotation;
1206 result->emplace(node_and_sharding);
1207 }
1208 } else {
1209 if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) {
1210 for (int64 core : node_and_sharding.sharding.tile_assignment_devices()) {
1211 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
1212 }
1213 }
1214
1215 if (!result->has_value()) {
1216 *result = node_and_sharding;
1217 } else {
1218 std::string result_value_serialized;
1219 xla::OpSharding result_value = result->value().sharding;
1220 result_value.clear_metadata();
1221 SerializeToStringDeterministic(result_value, &result_value_serialized);
1222
1223 std::string sharding_serialized;
1224 xla::OpSharding sharding = node_and_sharding.sharding;
1225 sharding.clear_metadata();
1226 SerializeToStringDeterministic(sharding, &sharding_serialized);
1227
1228 // TODO(lyandy): Choose the more granular sharding instead of always
1229 // assigning to core 0 (maximal).
1230 if (result_value_serialized != sharding_serialized) {
1231 // We see different shardings, assign to core 0.
1232 auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
1233 DCHECK_NE(node_and_sharding.node, nullptr);
1234 *core_zero_sharding.add_metadata() =
1235 CreateOpMetadataFromNode(*node_and_sharding.node);
1236 result->emplace(
1237 NodeAndSharding(node_and_sharding.node, core_zero_sharding));
1238 }
1239 }
1240 }
1241 return Status::OK();
1242 }
1243
1244 // As XlaSharding node may be followed by Cast op or an Identity op,
1245 // recursively walk the graph and aggregate nodes connectd to
1246 // |input_node| or Cast/Identity op following the |input_node|.
FindNodesMaybeContainingShardingInfo(const Node & input_node,std::vector<const Node * > * nodes)1247 void FindNodesMaybeContainingShardingInfo(const Node& input_node,
1248 std::vector<const Node*>* nodes) {
1249 if (input_node.IsIdentity() || input_node.type_string() == "Cast") {
1250 for (const Node* connected_node : input_node.out_nodes())
1251 FindNodesMaybeContainingShardingInfo(*connected_node, nodes);
1252 }
1253 nodes->emplace_back(&input_node);
1254 }
1255
1256 // Parse sharding configuration from |node| or it's adjacent nodes.
1257 // XlaSharding configuration may be derived from
1258 // a) Connected Identity op node.
1259 // b) Connected Cast op node.
1260 xla::StatusOr<absl::optional<NodeAndSharding>>
ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,const Node & node)1261 ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
1262 const Node& node) {
1263 // If |node| has `device` attribute or is a XlaSharding op,
1264 // return the parsed OpSharding.
1265 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
1266 ParseShardingFromDevice(node, num_cores_per_replica,
1267 /*add_metadata=*/true));
1268 if (sharding.has_value()) {
1269 return absl::optional<NodeAndSharding>(NodeAndSharding(&node, *sharding));
1270 }
1271
1272 // XlaShardingOp may be followed by an identity or followed by identity
1273 // and a Cast op.
1274 std::vector<const Node*> potential_nodes_with_input_sharding;
1275 FindNodesMaybeContainingShardingInfo(node,
1276 &potential_nodes_with_input_sharding);
1277 for (const Node* maybe_node_with_sharding_info :
1278 potential_nodes_with_input_sharding) {
1279 if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
1280
1281 TF_ASSIGN_OR_RETURN(
1282 absl::optional<xla::OpSharding> sharding_config,
1283 ParseShardingFromDevice(*maybe_node_with_sharding_info,
1284 num_cores_per_replica, /*add_metadata=*/true));
1285 if (sharding_config.has_value()) {
1286 return absl::optional<NodeAndSharding>(
1287 NodeAndSharding(maybe_node_with_sharding_info, *sharding_config));
1288 }
1289 }
1290 return absl::optional<NodeAndSharding>();
1291 }
1292
1293 // Walk the graph from an argument node to find OpSharding configuration
1294 // from its neighbor nodes. Sharding configuration may be inferred from
1295 // 1) Parsing XlaSharding attribute from neighboring node.
1296 // 2) If argument node is a resource, then by parsing adjacent nodes
1297 // of the connected ReadVariable op.
ParseAndValidateShardingFromNeighbors(const int num_cores_per_replica,const std::string & arg_node_name,const Node & neighbor_node,int64 * inferred_core_id,bool * is_fast_mem,absl::optional<NodeAndSharding> * result)1298 Status ParseAndValidateShardingFromNeighbors(
1299 const int num_cores_per_replica, const std::string& arg_node_name,
1300 const Node& neighbor_node, int64* inferred_core_id, bool* is_fast_mem,
1301 absl::optional<NodeAndSharding>* result) {
1302 if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1303 *is_fast_mem = true;
1304 VLOG(2) << "place " << neighbor_node.name() << " on fast memory because "
1305 << arg_node_name << " has " << TPU_FAST_MEM_ATTR << " attribute";
1306 }
1307
1308 // XlaSharding information may be encoded on node directly connected to the
1309 // argument node.
1310 TF_ASSIGN_OR_RETURN(
1311 absl::optional<NodeAndSharding> node_and_sharding,
1312 ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
1313 if (node_and_sharding.has_value()) {
1314 TF_RETURN_IF_ERROR(ParseAndValidateSharding(
1315 *node_and_sharding, num_cores_per_replica, inferred_core_id, result));
1316 return Status::OK();
1317 }
1318
1319 // When we use variable in TPU computation, we always have a
1320 // XlaSharding op followed by a ReadVariableOp. As so, correctly parse
1321 // the users of ReadVariableOp for potential sharding configuration.
1322 if (neighbor_node.type_string() == "ReadVariableOp") {
1323 for (const Edge* e : neighbor_node.out_edges()) {
1324 if (e->IsControlEdge()) continue;
1325
1326 if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1327 *is_fast_mem = true;
1328 VLOG(2) << "place " << arg_node_name << " on fast memory because "
1329 << e->dst()->name() << TPU_FAST_MEM_ATTR << " attribute";
1330 }
1331
1332 TF_ASSIGN_OR_RETURN(
1333 absl::optional<NodeAndSharding> node_and_sharding,
1334 ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
1335 if (node_and_sharding.has_value()) {
1336 TF_RETURN_IF_ERROR(ParseAndValidateSharding(*node_and_sharding,
1337 num_cores_per_replica,
1338 inferred_core_id, result));
1339 return Status::OK();
1340 }
1341 }
1342 }
1343 return Status::OK();
1344 }
1345
1346 } // namespace
1347
1348 // Inputs:
1349 // replication_spec_string: the device to which the TPUReplicate node was
1350 // assigned.
1351 // device_set: the set of TF devices.
1352 // Outputs:
1353 // tpu_compilation_device: the name of the TPU compilation device.
1354 // num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks
1355 // have the same number of TPU devices.
1356 // tpu_devices: the TPU devices, indexed by [task][device].
GetTPUDeviceNames(const string & replication_spec_string,const DeviceSet & device_set,string * tpu_compilation_device,int * num_tpus_per_task,std::vector<std::vector<Device * >> * tpu_devices)1357 static Status GetTPUDeviceNames(
1358 const string& replication_spec_string, const DeviceSet& device_set,
1359 string* tpu_compilation_device, int* num_tpus_per_task,
1360 std::vector<std::vector<Device*>>* tpu_devices) {
1361 // TODO(b/110910013) GetSystemDevice parses the spec and returns the name of
1362 // the tpu_system device, which we replace by the cpu device. We do this
1363 // replacement because we want to place the TPUCompileOp (and the compile
1364 // assert op) explicitly on cpu devices on the same job as the tpu_system
1365 // device.
1366 DeviceNameUtils::ParsedName replication_spec;
1367 Device* replication_device;
1368 TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
1369 replication_spec_string, device_set, &replication_spec,
1370 &replication_device));
1371 *tpu_compilation_device =
1372 str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM,
1373 DEVICE_CPU, /*replace_all=*/true);
1374
1375 // Finds the set of TPU devices attached to the tasks in the job.
1376 TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices(
1377 replication_spec, device_set, num_tpus_per_task, tpu_devices));
1378
1379 return Status::OK();
1380 }
1381
1382 // Parses the topology attribute of TPUReplicate, and populates *topology with
1383 // a physical mesh coordinate to (task, device) mapping.
ParseTopologyAttr(const string & topology_attr,const tpu::TpuTopologyExternal & tpu_topology,int num_tasks,int num_tpus_per_task,xla::Array4D<std::pair<int,int>> * topology)1384 static Status ParseTopologyAttr(const string& topology_attr,
1385 const tpu::TpuTopologyExternal& tpu_topology,
1386 int num_tasks, int num_tpus_per_task,
1387 xla::Array4D<std::pair<int, int>>* topology) {
1388 static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1389 tpu::TopologyProto proto;
1390 proto.ParseFromString(topology_attr);
1391 if (proto.mesh_shape_size() != kTPUTopologyRank) {
1392 return errors::InvalidArgument("TPU topology must be rank ",
1393 kTPUTopologyRank);
1394 }
1395 if (proto.num_tasks() != num_tasks) {
1396 return errors::InvalidArgument("Mismatched number of TPU tasks");
1397 }
1398 if (proto.num_tpu_devices_per_task() != num_tpus_per_task) {
1399 return errors::InvalidArgument("Mismatched number of TPUs per task (",
1400 proto.num_tpu_devices_per_task(),
1401 " != ", num_tpus_per_task, ").");
1402 }
1403 if (proto.device_coordinates_size() !=
1404 num_tasks * num_tpus_per_task * kTPUTopologyRank) {
1405 return errors::InvalidArgument(
1406 "device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x",
1407 kTPUTopologyRank, "; got ", proto.device_coordinates_size());
1408 }
1409
1410 int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1411 *topology = xla::Array4D<std::pair<int, int>>(
1412 tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1413 tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1});
1414 int pos = 0;
1415 for (int task = 0; task < num_tasks; ++task) {
1416 for (int device = 0; device < num_tpus_per_task; ++device) {
1417 int32 x = proto.device_coordinates(pos++);
1418 int32 y = proto.device_coordinates(pos++);
1419 int32 z = proto.device_coordinates(pos++);
1420 int32 core = proto.device_coordinates(pos++);
1421
1422 if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1423 core >= devices_per_chip) {
1424 return errors::InvalidArgument(
1425 "Mesh coordinates (", x, ",", y, ",", z, ",", core,
1426 ") are not valid for the current TPU topology");
1427 }
1428 if ((*topology)(x, y, z, core).first != -1) {
1429 return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1430 ",", z, ",", core, ") in TPU topology");
1431 }
1432 (*topology)(x, y, z, core) = {task, device};
1433 }
1434 }
1435 return Status::OK();
1436 }
1437
1438 // Parses the value of the device_assignment attribute to TPUReplicate.
1439 // Populates *device_assignment; *device_assignment must be a 2D array with
1440 // shape (num_replicas, num_cores_per_replica).
ParseDeviceAssignmentAttr(absl::Span<const int> device_assignment_attr,const tpu::TpuTopologyExternal & tpu_topology,int num_replicas,int num_cores_per_replica,xla::Array2D<tpu::TpuCoreLocationExternal> * device_assignment)1441 static Status ParseDeviceAssignmentAttr(
1442 absl::Span<const int> device_assignment_attr,
1443 const tpu::TpuTopologyExternal& tpu_topology, int num_replicas,
1444 int num_cores_per_replica,
1445 xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) {
1446 static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1447
1448 const int64 device_assignment_attr_size =
1449 num_replicas * num_cores_per_replica * kTPUTopologyRank;
1450 if (device_assignment_attr.size() != device_assignment_attr_size) {
1451 return errors::InvalidArgument(
1452 "Length of device_assignment attribute must be equal to num_replicas (",
1453 num_replicas, ") * num_cores_per_replica (", num_cores_per_replica,
1454 ") * ", kTPUTopologyRank, " got ", device_assignment_attr.size());
1455 }
1456 for (int core : device_assignment_attr) {
1457 if (core < 0 || core >= kTPUMaxTopologySize) {
1458 return errors::InvalidArgument(
1459 "Invalid core number in device assignment: ", core);
1460 }
1461 }
1462
1463 *device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>(
1464 num_replicas, num_cores_per_replica);
1465 int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1466 xla::Array4D<int> replica_assignment(
1467 tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1468 tpu_topology.chip_bounds().z, devices_per_chip, -1);
1469 int pos = 0;
1470 for (int replica = 0; replica < num_replicas; ++replica) {
1471 for (int logical_core = 0; logical_core < num_cores_per_replica;
1472 ++logical_core) {
1473 int32 x = device_assignment_attr[pos++];
1474 int32 y = device_assignment_attr[pos++];
1475 int32 z = device_assignment_attr[pos++];
1476 int32 core = device_assignment_attr[pos++];
1477
1478 if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1479 core >= devices_per_chip) {
1480 return errors::InvalidArgument(
1481 "Mesh coordinates (", x, ",", y, ",", core,
1482 ") are not valid for the current TPU topology");
1483 }
1484 tpu::TpuCoreLocationExternal core_location =
1485 tpu_topology.Core(kTensorCore, x, y, z, core);
1486
1487 if (replica_assignment(x, y, z, core) != -1) {
1488 return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1489 ",", z, ",", core,
1490 ") in TPU device assignment");
1491 }
1492 replica_assignment(x, y, z, core) = replica;
1493 (*device_assignment)(replica, logical_core) = core_location;
1494 }
1495 }
1496 return Status::OK();
1497 }
1498
1499 // Builds TensorFlow device assignments for the special case of a single core
1500 // computation that is replicated to every core in the mesh.
1501 // LINT.IfChange
BuildFullMeshDeviceAssignment(int num_replicas,const std::vector<std::vector<Device * >> & tpu_devices,int num_tasks,int num_tpus_per_task,std::vector<std::vector<string>> * tf_device_assignment)1502 static Status BuildFullMeshDeviceAssignment(
1503 int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices,
1504 int num_tasks, int num_tpus_per_task,
1505 std::vector<std::vector<string>>* tf_device_assignment) {
1506 // Assign TensorFlow devices to replicas arbitrarily.
1507 for (int i = 0; i < num_replicas; ++i) {
1508 int task = i / num_tpus_per_task;
1509 int device = i % num_tpus_per_task;
1510 TF_RET_CHECK(task >= 0 && task < num_tasks);
1511 TF_RET_CHECK(device >= 0 && device < num_tpus_per_task);
1512
1513 // We don't actually know which TF device corresponds to which physical
1514 // device, but it doesn't matter—they're all identical.
1515 (*tf_device_assignment)[i] = {tpu_devices[task][device]->name()};
1516 }
1517 return Status::OK();
1518 }
1519 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1520
// Builds TensorFlow device assignments for a replicated computation and converts
// device_assignment into xla_device_assignment.
BuildGeneralDeviceAssignment(int num_replicas,int num_cores_per_replica,const std::vector<std::vector<Device * >> & tpu_devices,const xla::Array2D<tpu::TpuCoreLocationExternal> & device_assignment,const xla::Array4D<std::pair<int,int>> & topology,std::vector<std::vector<string>> * tf_device_assignment,std::unique_ptr<xla::DeviceAssignment> * xla_device_assignment)1523 static Status BuildGeneralDeviceAssignment(
1524 int num_replicas, int num_cores_per_replica,
1525 const std::vector<std::vector<Device*>>& tpu_devices,
1526 const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment,
1527 const xla::Array4D<std::pair<int, int>>& topology,
1528 std::vector<std::vector<string>>* tf_device_assignment,
1529 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1530 // Assign TensorFlow devices to each computation's replicas according to
1531 // device_assignment and 'topology'.
1532 *xla_device_assignment = absl::make_unique<xla::DeviceAssignment>(
1533 num_replicas, num_cores_per_replica);
1534 for (int replica = 0; replica < num_replicas; ++replica) {
1535 for (int computation = 0; computation < num_cores_per_replica;
1536 ++computation) {
1537 const tpu::TpuCoreLocationExternal& core_location =
1538 device_assignment(replica, computation);
1539
1540 int task;
1541 int device;
1542 std::tie(task, device) =
1543 topology(core_location.chip_coordinates().x,
1544 core_location.chip_coordinates().y,
1545 core_location.chip_coordinates().z, core_location.index());
1546
1547 CHECK_LT(computation, num_cores_per_replica);
1548 (**xla_device_assignment)(replica, computation) = core_location.Id();
1549
1550 // The communication pattern between replicas will be determined later by
1551 // BuildAllReduceRing.
1552 TF_RET_CHECK(task >= 0 && task < tpu_devices.size());
1553 TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size());
1554 (*tf_device_assignment)[replica].push_back(
1555 tpu_devices[task][device]->name());
1556 }
1557 }
1558 return Status::OK();
1559 }
1560
BuildDeviceAssignment(const tpu::TpuTopologyExternal & tpu_topology,int num_tpus_per_task,const std::vector<std::vector<Device * >> & tpu_devices,int num_replicas,int num_cores_per_replica,const string & topology_attr,absl::Span<const int> device_assignment_attr,std::vector<std::vector<string>> * tf_device_assignment,std::unique_ptr<xla::DeviceAssignment> * xla_device_assignment)1561 /*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment(
1562 const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task,
1563 const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
1564 int num_cores_per_replica, const string& topology_attr,
1565 absl::Span<const int> device_assignment_attr,
1566 std::vector<std::vector<string>>* tf_device_assignment,
1567 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1568 const int num_tasks = tpu_devices.size();
1569 const int num_tpu_devices = num_tasks * num_tpus_per_task;
1570 VLOG(2) << "num_tasks=" << num_tasks
1571 << " num_tpus_per_task=" << num_tpus_per_task;
1572
1573 // Checks num_replicas is sane first to avoid integer overflow.
1574 if (num_replicas > num_tpu_devices) {
1575 #ifdef PLATFORM_CLOUD_TPU
1576 return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1577 " but there are only ", num_tpu_devices,
1578 " cores in the TPU topology.");
1579 #else
1580 return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1581 " but there are only ", num_tpu_devices,
1582 " cores in the TPU topology.");
1583 #endif
1584 }
1585 if (num_replicas * num_cores_per_replica > num_tpu_devices) {
1586 return errors::InvalidArgument(
1587 "Requested num_replicas=", num_replicas, " with ",
1588 num_cores_per_replica, " cores per replica, but there are only ",
1589 num_tpu_devices, " cores in the TPU topology");
1590 }
1591
1592 tf_device_assignment->clear();
1593 tf_device_assignment->resize(num_replicas);
1594
1595 // Special case: we allow the user to omit the topology and device assignment
1596 // information in two cases:
1597 // * there is only one replica and one core per replica. In this case, we
1598 // don't need to know topology information because we don't communicate with
1599 // other cores.
1600 // * the number of replicas is equal to the number of cores in the slice. In
1601 // this case, all cores are running the same program so we don't need to
1602 // know which is which.
1603 if (topology_attr.empty()) {
1604 // LINT.IfChange
1605 if (num_replicas != 1 && num_replicas != num_tpu_devices) {
1606 return errors::InvalidArgument(
1607 "TPUReplicate asked to create ", num_replicas,
1608 " replicas, but the number of cores in the TPU topology is ",
1609 num_tpu_devices,
1610 " and no TPU device assignment was supplied. "
1611 "A TPU device assignment is required if the number of replicas is "
1612 "not 1 or the number of cores in the topology (",
1613 num_tpu_devices, ")");
1614 }
1615
1616 if (num_cores_per_replica != 1) {
1617 return errors::InvalidArgument(
1618 "A TPU topology must be provided if num_cores_per_replica != 1");
1619 }
1620
1621 if (!device_assignment_attr.empty()) {
1622 return errors::InvalidArgument(
1623 "A TPU topology must be provided if device_assignment_attr is "
1624 "non-empty");
1625 }
1626
1627 // If there is only one replica, assign the Tensorflow computation to task 0
1628 // device 0, and leave the XLA device assignment empty. We don't know which
1629 // core this is in the TPU topology, but it doesn't matter—we don't need to
1630 // communicate with any other cores.
1631 if (num_replicas == 1) {
1632 (*tf_device_assignment)[0] = {tpu_devices[0][0]->name()};
1633 return Status::OK();
1634 }
1635
1636 // Otherwise, num_replicas is equal to the number of cores, and we build a
1637 // device assignment that covers the entire mesh. We do not need to know
1638 // the topology to do so because all cores are identical.
1639 return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks,
1640 num_tpus_per_task,
1641 tf_device_assignment);
1642 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1643 }
1644
1645 // Array that maps mesh coordinates to {TF task, TF TPU device #} pairs.
1646 xla::Array4D<std::pair<int, int>> topology;
1647 TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks,
1648 num_tpus_per_task, &topology));
1649
1650 // Array that maps logical (replica, core) pairs to physical mesh coordinates.
1651 xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment;
1652 TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr(
1653 device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica,
1654 &device_assignment));
1655
1656 return BuildGeneralDeviceAssignment(
1657 num_replicas, num_cores_per_replica, tpu_devices, device_assignment,
1658 topology, tf_device_assignment, xla_device_assignment);
1659 }
1660
GetComputationForTPUReplicateOp(const NameAttrList & function,FunctionLibraryRuntime * flr,Graph * computation,DataTypeVector * arg_types,DataTypeVector * retval_types)1661 Status DistributedTPURewritePass::GetComputationForTPUReplicateOp(
1662 const NameAttrList& function, FunctionLibraryRuntime* flr,
1663 Graph* computation, DataTypeVector* arg_types,
1664 DataTypeVector* retval_types) {
1665 FunctionLibraryRuntime::Handle handle;
1666
1667 TF_RETURN_IF_ERROR(
1668 flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
1669
1670 const FunctionBody* fbody = flr->GetFunctionBody(handle);
1671
1672 CopyGraph(*fbody->graph, computation);
1673 *arg_types = fbody->arg_types;
1674 *retval_types = fbody->ret_types;
1675 return Status::OK();
1676 }
1677
1678 // Grab the InferredShape corresponding to an edge input.
GetEdgeShape(const GraphShapeInfo & shape_info,const Edge & edge,const InferredShape ** info)1679 static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge,
1680 const InferredShape** info) {
1681 auto it = shape_info.find(edge.src()->name());
1682 if (it == shape_info.end()) {
1683 return errors::InvalidArgument(
1684 "Input to replicated TPU computation is missing InferredShape: ",
1685 edge.src()->name());
1686 }
1687 TF_RET_CHECK(it->second.size() > edge.src_output());
1688 *info = &it->second[edge.src_output()];
1689 return Status::OK();
1690 }
1691
// Computes the argument shapes of the single-replica computation represented
// by the replicated node `node`, plus the shapes of `node`'s outputs.
//
// `node`'s input edges are consumed strictly in this order:
//   per-replica args (replica-major: all of replica 0's, then replica 1's, ...)
//   distributed args, broadcast args, variables, guaranteed constants.
// `arg_shapes` is filled in the same order, but holds only one replica's worth
// of per-replica args (NumInputsToEachReplica() entries). Per-replica and
// distributed arg shapes are merged across every edge feeding them. If any
// such edge has a not-fully-defined shape (or handle shape, for edges with a
// handle type), the merged shape and handle shape become unknown.
// `retval_shapes` receives the inferred shape of each of `node`'s outputs.
//
// Returns InvalidArgument if shapes for the same input cannot be merged, if an
// input's source is missing from `shape_info`, or if `node` has outputs but no
// entry in `shape_info`.
Status DistributedTPURewritePass::GetArgAndRetvalShapes(
    const GraphShapeInfo& shape_info, const Node& node,
    const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
    std::vector<InferredShape>* retval_shapes) {
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(node.input_edges(&input_edges));

  // If any replica's arg shape is unknown, we will mark the computation's arg
  // shape as being unknown. If the shapes differ the TpuExecute Op will raise a
  // runtime error.
  std::vector<bool> any_replica_shape_unknown(
      params_info.NumInputsToEachReplica());
  arg_shapes->clear();
  arg_shapes->resize(params_info.NumInputsToEachReplica());
  TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost());
  // Determines the shapes of the per-replica arguments and checks that all
  // replicas have identical shapes.
  // `edge_pos` is the cursor into `input_edges`; every section below advances
  // it, so the sections must run in exactly this order.
  int64 edge_pos = 0;
  // Merges the shape of the next input edge into (*arg_shapes)[input_index]
  // and advances `edge_pos`.
  auto check_shape = [&](int input_index) -> Status {
    const InferredShape* info;
    TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
    ++edge_pos;

    // For edges with a handle type, the handle shape is what must be known;
    // otherwise the tensor shape itself.
    if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) ||
        (info->handle_type != DT_INVALID &&
         !info->handle_shape.IsFullyDefined())) {
      any_replica_shape_unknown[input_index] = true;
    }
    xla::StatusOr<InferredShape> status =
        MergeInferredShapes((*arg_shapes)[input_index], *info);
    if (!status.ok()) {
      return errors::InvalidArgument(
          "Mismatched shapes for input ", input_index, ": ",
          (*arg_shapes)[input_index].shape.DebugString(), " vs. ",
          info->shape.DebugString());
    }
    (*arg_shapes)[input_index] = status.ValueOrDie();
    return Status::OK();
  };

  // Every replica's per-replica args map onto the same slots [0, NumPerReplicaArgs()).
  for (int64 i = 0; i < params_info.NumReplicas(); ++i) {
    for (int64 j = 0; j < params_info.NumPerReplicaArgs(); ++j) {
      TF_RETURN_IF_ERROR(check_shape(j));
    }
  }

  // Distributed args come once each, right after the per-replica slots.
  for (int64 i = 0; i < params_info.NumDistributedArgs(); ++i) {
    TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i));
  }

  // Erase the merged shape of any per-replica/distributed arg that was not
  // fully known on at least one feeding edge.
  for (int64 i = 0;
       i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs();
       ++i) {
    if (any_replica_shape_unknown[i]) {
      (*arg_shapes)[i].shape = PartialTensorShape();
      (*arg_shapes)[i].handle_shape = PartialTensorShape();
    }
  }

  // Determines the shape of the broadcast arguments.
  for (int64 i = 0; i < params_info.NumBroadcastArgs(); ++i) {
    TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
    const InferredShape* info;
    TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
    (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
                  params_info.NumDistributedArgs()]
        .shape = info->shape;
    ++edge_pos;
  }

  // Determines the handle shape and handle type of the resource variable
  // arguments.
  for (int64 i = 0; i < params_info.NumVariables(); ++i) {
    TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE);
    const InferredShape* info;
    TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
    InferredShape& arg_shape =
        (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
                      params_info.NumDistributedArgs() +
                      params_info.NumBroadcastArgs()];
    arg_shape.shape = TensorShape();  // Variables are always scalars.
    arg_shape.handle_shape = info->handle_shape;
    arg_shape.handle_type = info->handle_type;
    TF_RET_CHECK(arg_shape.handle_type != DT_INVALID)
        << " input edge: " << input_edges[edge_pos]->DebugString();
    ++edge_pos;
  }

  // Determines the shape of the guaranteed constants.
  // TODO(vinuraja): Can be removed because they are not required for any
  // calculations. Leaving them here for symmetry with other structures like
  // arg_types, arg_sharding, etc.
  for (int64 i = 0; i < params_info.NumGuaranteedConstants(); ++i) {
    TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
    const InferredShape* info;
    TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
    (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
                  params_info.NumDistributedArgs() +
                  params_info.NumBroadcastArgs() + params_info.NumVariables()]
        .shape = info->shape;
    ++edge_pos;
  }

  // Extract the return value shapes.
  // Only `.shape` is copied; handle shapes/types of outputs are not propagated.
  auto it = shape_info.find(node.name());
  retval_shapes->clear();
  if (it != shape_info.end()) {
    TF_RET_CHECK(it->second.size() >= node.num_outputs());
    retval_shapes->resize(node.num_outputs());
    for (int i = 0; i < node.num_outputs(); ++i) {
      (*retval_shapes)[i].shape = it->second[i].shape;
    }
  } else if (node.num_outputs() > 0) {
    return errors::InvalidArgument(
        "Replicated TPU computation is missing InferredShape: ",
        FormatNodeForError(node));
  }
  return Status::OK();
}
1811
1812 // Verifies that all nodes have legal sharding.
ValidateCoreNumbers(const Graph & graph,int num_cores_per_replica)1813 static Status ValidateCoreNumbers(const Graph& graph,
1814 int num_cores_per_replica) {
1815 for (Node* n : graph.nodes()) {
1816 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
1817 ParseShardingFromDevice(*n, num_cores_per_replica,
1818 /*add_metadata=*/true));
1819 }
1820 return Status::OK();
1821 }
1822
InferXlaShardingFromNeighbors(const Node & n,int num_cores_per_replica,FunctionLibraryRuntime * flr,CachedFunctionHandles * cached_function_handles,absl::optional<NodeAndSharding> * output_node_and_sharding,bool * is_fast_mem)1823 static Status InferXlaShardingFromNeighbors(
1824 const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr,
1825 CachedFunctionHandles* cached_function_handles,
1826 absl::optional<NodeAndSharding>* output_node_and_sharding,
1827 bool* is_fast_mem) {
1828 int64 core = -1;
1829 absl::optional<NodeAndSharding> result;
1830 // We assume the variable has been allocated on fast memory if any consuming
1831 // op has TPU_FAST_MEM_ATTR attribute. This is a protocol between runtime and
1832 // compiler.
1833 *is_fast_mem = false;
1834 for (const Edge* edge : n.out_edges()) {
1835 if (edge->IsControlEdge()) continue;
1836
1837 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
1838 num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem,
1839 &result));
1840
1841 if (!flr) continue;
1842
1843 // The nodes deciding this arg's device assignment might be in
1844 // FunctionDef. Instantiate FunctionDefs associated with this node
1845 // and check nodes using this arg.
1846 std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
1847 [&](const Edge* call_edge) {
1848 auto associated_functions = GetAssociatedFunctions(
1849 *call_edge->dst(), flr->GetFunctionLibraryDefinition());
1850 for (auto& associated_function : associated_functions) {
1851 FunctionLibraryRuntime::Handle handle;
1852 TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
1853 associated_function.func_name(),
1854 AttrSlice(&associated_function.attrs()), &handle));
1855 const FunctionBody* body = flr->GetFunctionBody(handle);
1856 Graph* g = body->graph;
1857
1858 for (Node* body_node : g->nodes()) {
1859 if (!body_node->IsArg()) continue;
1860
1861 int index;
1862 TF_RETURN_IF_ERROR(
1863 GetNodeAttr(body_node->attrs(), "index", &index));
1864 if (index != call_edge->dst_input()) continue;
1865
1866 for (const Edge* out_edge : body_node->out_edges()) {
1867 if (out_edge->IsControlEdge()) continue;
1868
1869 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
1870 num_cores_per_replica, n.name(), *out_edge->dst(), &core,
1871 is_fast_mem, &result));
1872
1873 TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
1874 }
1875 }
1876 }
1877 return Status::OK();
1878 };
1879 TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
1880 }
1881 *output_node_and_sharding = result;
1882 return Status::OK();
1883 }
1884
UseSpmdForXlaPartitioning(const Node * replicate_node)1885 bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
1886 bool spmd_attr;
1887 if (!replicate_node ||
1888 !TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
1889 &spmd_attr)) {
1890 spmd_attr = false;
1891 }
1892 return spmd_attr;
1893 }
1894
FormatNodeAndShardingMsg(const absl::optional<NodeAndSharding> & node_and_sharding)1895 std::string FormatNodeAndShardingMsg(
1896 const absl::optional<NodeAndSharding>& node_and_sharding) {
1897 DCHECK(node_and_sharding.has_value());
1898
1899 xla::OpSharding sharding_no_metadata = node_and_sharding->sharding;
1900 sharding_no_metadata.clear_metadata();
1901 std::string escaped_sharding_str =
1902 absl::CEscape(sharding_no_metadata.SerializeAsString());
1903 if (node_and_sharding->node == nullptr) {
1904 return absl::StrCat(" via default sharding '", escaped_sharding_str, "'");
1905 }
1906
1907 return absl::StrCat(" via node ", node_and_sharding->node->DebugString(),
1908 " sharding '", escaped_sharding_str, "'");
1909 }
1910
AssignArgsAndRetvalsToCores(int num_cores_per_replica,const ParameterInfo & params_info,const DataTypeVector & arg_types,const std::vector<InferredShape> & arg_shapes,const DataTypeVector & retval_types,const std::vector<InferredShape> & retval_shapes,const Graph & graph,const Node * replicate_node,FunctionLibraryRuntime * flr,bool allow_parameter_replication_for_spmd,std::vector<xla::OpSharding> * arg_sharding,std::vector<bool> * arg_fast_mem,std::vector<xla::OpSharding> * retval_sharding,std::vector<std::string> * arg_names)1911 Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
1912 int num_cores_per_replica, const ParameterInfo& params_info,
1913 const DataTypeVector& arg_types,
1914 const std::vector<InferredShape>& arg_shapes,
1915 const DataTypeVector& retval_types,
1916 const std::vector<InferredShape>& retval_shapes, const Graph& graph,
1917 const Node* replicate_node, FunctionLibraryRuntime* flr,
1918 bool allow_parameter_replication_for_spmd,
1919 std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
1920 std::vector<xla::OpSharding>* retval_sharding,
1921 std::vector<std::string>* arg_names) {
1922 // Builds vectors of the argument and return nodes.
1923 std::vector<Node*> args(arg_types.size());
1924 std::vector<Node*> retvals(retval_types.size());
1925 absl::flat_hash_map<int, Node*> partitioned_output_nodes;
1926 for (Node* node : graph.op_nodes()) {
1927 if (node->IsArg()) {
1928 int index;
1929 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
1930 TF_RET_CHECK(index >= 0 && index < args.size());
1931 args[index] = node;
1932 } else if (node->IsRetval()) {
1933 int index;
1934 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
1935 TF_RET_CHECK(index >= 0 && index < retvals.size());
1936 retvals[index] = node;
1937 }
1938 }
1939 for (const Edge* edge : replicate_node->out_edges()) {
1940 int num_partitioned_outputs = 0;
1941 for (const Edge* out_edge : edge->dst()->out_edges()) {
1942 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
1943 partitioned_output_nodes[edge->src_output()] = out_edge->dst();
1944 num_partitioned_outputs++;
1945 }
1946 }
1947 if (num_partitioned_outputs > 1) {
1948 return errors::InvalidArgument(
1949 "More than one TPUPartitionedOutput per replciated output.");
1950 }
1951 }
1952
1953 // Verifies there are no missing arguments/return values.
1954 for (int i = 0; i < args.size(); ++i) {
1955 if (args[i] == nullptr) {
1956 return errors::Internal("Missing function argument: ", i);
1957 }
1958 }
1959 for (int i = 0; i < retvals.size(); ++i) {
1960 if (retvals[i] == nullptr) {
1961 return errors::Internal("Missing function return value: ", i);
1962 }
1963 }
1964
1965 // Assigns a core to each _Arg. Chooses the lowest-numbered core that
1966 // consumes the argument. We choose the lowest-numbered core so the
1967 // assignment is deterministic.
1968 TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types,
1969 arg_shapes);
1970 arg_sharding->resize(args.size());
1971 arg_names->resize(args.size());
1972 arg_fast_mem->resize(args.size());
1973 CachedFunctionHandles cached_function_handles(flr);
1974 const bool use_spmd = (UseSpmdForXlaPartitioning(replicate_node) ||
1975 replicate_inputs_outputs_by_default_for_xla_spmd_) &&
1976 allow_parameter_replication_for_spmd;
1977
1978 // Offset _TPUReplicate non per replica argument indices by
1979 // (num_replicas - 1) * num_per_replica_args as _TPUReplicate nodes are
1980 // constructed with all per replica args across all replicas while the
1981 // encapsulated function only has 1 replica's per replica args. Per replica
1982 // args are ordered by replica first, so the index here does not require an
1983 // offset and the first replica's input nodes is sufficient for determining
1984 // argument sharding.
1985 const int index_offset =
1986 (params_info.NumReplicas() - 1) * params_info.NumPerReplicaArgs();
1987 for (int i = 0; i < args.size(); ++i) {
1988 const Node* n = args[i];
1989 absl::optional<int64> assigned_core;
1990 absl::optional<NodeAndSharding> node_and_sharding;
1991 bool is_fast_mem;
1992 TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors(
1993 *n, num_cores_per_replica, flr, &cached_function_handles,
1994 &node_and_sharding, &is_fast_mem));
1995
1996 const bool is_per_replica_arg = params_info.IsPerReplicaArg(i);
1997 if (is_per_replica_arg || params_info.IsDistributedArg(i)) {
1998 Node* input_node;
1999 TF_RETURN_IF_ERROR(replicate_node->input_node(
2000 i + (is_per_replica_arg ? 0 : index_offset), &input_node));
2001 if (input_node->type_string() == kTPUPartitionedInput) {
2002 TF_ASSIGN_OR_RETURN(
2003 absl::optional<xla::OpSharding> parsed_sharding,
2004 GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2005 if (!parsed_sharding.has_value())
2006 return errors::InvalidArgument("Missing _XlaSharding attr from: ",
2007 input_node->DebugString());
2008 node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2009 VLOG(1) << "Arg " << i << " parsed sharding information from "
2010 << input_node->DebugString() << " : "
2011 << parsed_sharding->DebugString();
2012 }
2013 }
2014
2015 if (params_info.IsVariableArg(i)) {
2016 Node* input_node;
2017 TF_RETURN_IF_ERROR(
2018 replicate_node->input_node(i + index_offset, &input_node));
2019 if (input_node->type_string() == kVarHandleOp) {
2020 TF_ASSIGN_OR_RETURN(
2021 absl::optional<xla::OpSharding> parsed_sharding,
2022 GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2023 if (parsed_sharding.has_value()) {
2024 node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2025 VLOG(1) << "Arg " << i << " parsed sharding information from "
2026 << input_node->DebugString() << " : "
2027 << parsed_sharding->DebugString();
2028 }
2029 }
2030 }
2031
2032 if (node_and_sharding.has_value() && enable_automatic_model_parallelism_) {
2033 return tensorflow::errors::InvalidArgument(
2034 "Specifying manual sharding is not allowed when automatic "
2035 "model parallelism is enabled.",
2036 node_and_sharding->sharding.DebugString());
2037 }
2038
2039 if (!node_and_sharding.has_value()) {
2040 if (use_spmd &&
2041 (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2042 ((params_info.IsPerReplicaArg(i) ||
2043 params_info.IsDistributedArg(i)) &&
2044 arg_types[i] != DT_RESOURCE))) {
2045 // Use replication for host variables or non-variable per-replica
2046 // inputs.
2047 node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2048 xla::sharding_builder::Replicate());
2049 } else {
2050 // TODO(dlibenzi): Distributing variables to cores other than 0 makes
2051 // learning/brain/research/babelfish/trainer:trainer_tpu_test fail.
2052 // For now distribute only per replica arguments, unless
2053 // tf_jf_distribute_vars is set, to allow debugging the issue.
2054 if (((params_info.IsPerReplicaArg(i) ||
2055 params_info.IsDistributedArg(i)) &&
2056 arg_types[i] != DT_RESOURCE) ||
2057 (distribute_vars_ && params_info.IsVariableArg(i))) {
2058 assigned_core = args_device_selector.RetrieveAssignment(i);
2059 } else {
2060 assigned_core = 0;
2061 }
2062 node_and_sharding = NodeAndSharding(
2063 /*node=*/nullptr,
2064 xla::sharding_builder::AssignDevice(*assigned_core));
2065 }
2066 *node_and_sharding->sharding.add_metadata() =
2067 CreateOpMetadataFromNode(*replicate_node);
2068 } else if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2069 assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2070 } else if (node_and_sharding->sharding.type() !=
2071 xla::OpSharding::REPLICATED &&
2072 node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2073 return tensorflow::errors::InvalidArgument(
2074 "Unsupported argument sharding (for arg ", n->DebugString(),
2075 "): ", node_and_sharding->sharding.DebugString());
2076 }
2077 if (assigned_core.has_value()) {
2078 args_device_selector.ReportDeviceAssigned(*assigned_core, i);
2079 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2080 << ") to core " << *assigned_core
2081 << FormatNodeAndShardingMsg(node_and_sharding);
2082 args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2083 } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2084 for (int64 core : node_and_sharding->sharding.tile_assignment_devices()) {
2085 args_device_selector.ReportDeviceAssigned(core, i);
2086 }
2087 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2088 << ") with tiled sharding to cores "
2089 << absl::StrJoin(
2090 node_and_sharding->sharding.tile_assignment_devices(), ",")
2091 << " " << FormatNodeAndShardingMsg(node_and_sharding);
2092 } else {
2093 DCHECK_EQ(node_and_sharding->sharding.type(),
2094 xla::OpSharding::REPLICATED);
2095 for (int64 core = 0; core < num_cores_per_replica; ++core) {
2096 args_device_selector.ReportDeviceAssigned(core, i);
2097 }
2098 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2099 << ") to all cores"
2100 << FormatNodeAndShardingMsg(node_and_sharding);
2101 }
2102 (*arg_sharding)[i] = node_and_sharding->sharding;
2103 (*arg_fast_mem)[i] = is_fast_mem;
2104 (*arg_names)[i] = n->name();
2105 if (is_fast_mem) {
2106 VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to "
2107 << args[i]->name();
2108 }
2109 args[i]->AddAttr(kShardingAttribute,
2110 node_and_sharding->sharding.SerializeAsString());
2111 }
2112 TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles());
2113
2114 // Assigns each _Retval node to the core that produces its value.
2115 TensorDevicePlacer retvals_device_selector(num_cores_per_replica,
2116 retval_types, retval_shapes);
2117 retval_sharding->resize(retvals.size());
2118 for (int i = 0; i < retvals.size(); ++i) {
2119 const Edge* edge;
2120 TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge));
2121
2122 TF_ASSIGN_OR_RETURN(
2123 absl::optional<xla::OpSharding> edge_sharding,
2124 ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
2125 /*add_metadata=*/true));
2126
2127 absl::optional<NodeAndSharding> node_and_sharding;
2128 if (edge_sharding.has_value()) {
2129 node_and_sharding.emplace(NodeAndSharding(edge->src(), *edge_sharding));
2130 }
2131
2132 if (partitioned_output_nodes.contains(i)) {
2133 Node* output_node = partitioned_output_nodes[i];
2134 TF_ASSIGN_OR_RETURN(
2135 absl::optional<xla::OpSharding> parsed_sharding,
2136 GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
2137 if (parsed_sharding.has_value()) {
2138 node_and_sharding = NodeAndSharding(output_node, *parsed_sharding);
2139 VLOG(1) << "Retval " << i << " parsed sharding information from "
2140 << output_node->DebugString() << " : "
2141 << parsed_sharding->DebugString();
2142 }
2143 }
2144 absl::optional<int64> assigned_core;
2145 if (node_and_sharding.has_value()) {
2146 if (enable_automatic_model_parallelism_) {
2147 return tensorflow::errors::InvalidArgument(
2148 "Specifying manual sharding is not allowed when automatic "
2149 "model parallelism is enabled.",
2150 node_and_sharding->sharding.DebugString());
2151 }
2152
2153 if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2154 assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2155 TF_RETURN_IF_ERROR(
2156 ValidateCoreNumber(*assigned_core, num_cores_per_replica));
2157 } else if (node_and_sharding->sharding.type() !=
2158 xla::OpSharding::REPLICATED &&
2159 node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2160 return tensorflow::errors::InvalidArgument(
2161 "Unsupported argument sharding for retval ",
2162 retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
2163 node_and_sharding->sharding.DebugString());
2164 }
2165 } else {
2166 if (use_spmd) {
2167 node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2168 xla::sharding_builder::Replicate());
2169 } else {
2170 if (distribute_vars_) {
2171 assigned_core = retvals_device_selector.RetrieveAssignment(i);
2172 } else {
2173 assigned_core = 0;
2174 }
2175 node_and_sharding = NodeAndSharding(
2176 /*node=*/nullptr,
2177 xla::sharding_builder::AssignDevice(*assigned_core));
2178 }
2179 *node_and_sharding->sharding.add_metadata() =
2180 CreateOpMetadataFromNode(*replicate_node);
2181 }
2182 if (assigned_core.has_value()) {
2183 retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2184 retvals_device_selector.ReportDeviceAssigned(*assigned_core, i);
2185 VLOG(3) << "Assigning return value " << i << " ("
2186 << retvals[i]->DebugString() << ") to core " << *assigned_core
2187 << FormatNodeAndShardingMsg(node_and_sharding);
2188 } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2189 for (int64 core : node_and_sharding->sharding.tile_assignment_devices()) {
2190 retvals_device_selector.ReportDeviceAssigned(core, i);
2191 }
2192 VLOG(3) << "Assigning return value " << i << " ("
2193 << retvals[i]->DebugString() << ") with tiled sharding to cores "
2194 << absl::StrJoin(
2195 node_and_sharding->sharding.tile_assignment_devices(), ",")
2196 << " " << FormatNodeAndShardingMsg(node_and_sharding);
2197 } else {
2198 DCHECK_EQ(node_and_sharding->sharding.type(),
2199 xla::OpSharding::REPLICATED);
2200 for (int64 core = 0; core < num_cores_per_replica; ++core) {
2201 retvals_device_selector.ReportDeviceAssigned(core, i);
2202 }
2203 VLOG(3) << "Assigning return value " << i << " ("
2204 << retvals[i]->DebugString() << ") to all cores"
2205 << FormatNodeAndShardingMsg(node_and_sharding);
2206 }
2207 retvals[i]->AddAttr(kShardingAttribute,
2208 node_and_sharding->sharding.SerializeAsString());
2209 (*retval_sharding)[i] = node_and_sharding->sharding;
2210 }
2211 if (use_spmd &&
2212 (absl::c_any_of(*arg_sharding,
2213 [](const xla::OpSharding& s) {
2214 return s.type() == xla::OpSharding::MAXIMAL;
2215 }) ||
2216 absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) {
2217 return s.type() == xla::OpSharding::MAXIMAL;
2218 }))) {
2219 LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
2220 "exist on every partition (sharded or replicated). Fall "
2221 "back to MPMD.";
2222 return AssignArgsAndRetvalsToCores(
2223 num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
2224 retval_shapes, graph, replicate_node, flr,
2225 /*allow_parameter_replication_for_spmd=*/false, arg_sharding,
2226 arg_fast_mem, retval_sharding, arg_names);
2227 }
2228 return Status::OK();
2229 }
2230
2231 // Builds Shape nodes that compute the shapes of arguments whose shapes are not
2232 // statically known.
BuildDynamicShapeNodes(const Node & replicate_node,const std::vector<InferredShape> & arg_shapes,const ParameterInfo & params_info,const std::vector<Node * > & variable_reads,Graph * graph,std::vector<Node * > * dynamic_shape_nodes)2233 /* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes(
2234 const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
2235 const ParameterInfo& params_info, const std::vector<Node*>& variable_reads,
2236 Graph* graph, std::vector<Node*>* dynamic_shape_nodes) {
2237 dynamic_shape_nodes->clear();
2238
2239 std::vector<const Edge*> replicate_input_edges;
2240 TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
2241
2242 // The compiler determines the shape of each constant by inspecting the value
2243 // of its corresponding host-memory tensor; this happens when a step is run.
2244 // As a result, the shapes of constants are not needed at graph rewrite time.
2245 const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants();
2246 TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() +
2247 params_info.NumDistributedArgs() +
2248 params_info.NumBroadcastArgs() +
2249 params_info.NumVariables());
2250
2251 for (int i = 0; i < num_args; ++i) {
2252 const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID
2253 ? &arg_shapes[i].shape
2254 : &arg_shapes[i].handle_shape;
2255 if (!shape->IsFullyDefined()) {
2256 Node* src;
2257 int src_output;
2258 if (params_info.IsPerReplicaArg(i)) {
2259 TF_RET_CHECK(i < replicate_input_edges.size());
2260 // All replicas must have the same input shapes. Uses the shape of the
2261 // inputs from the first replica.
2262 src = replicate_input_edges[i]->src();
2263 src_output = replicate_input_edges[i]->src_output();
2264 } else if (params_info.IsDistributedArg(i) ||
2265 params_info.IsBroadcastArg(i)) {
2266 int64 input_num =
2267 params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i -
2268 params_info.NumPerReplicaArgs();
2269 TF_RET_CHECK(0 <= input_num &&
2270 input_num < replicate_input_edges.size());
2271 src = replicate_input_edges[input_num]->src();
2272 src_output = replicate_input_edges[input_num]->src_output();
2273 } else {
2274 int64 var_num = i - params_info.NumPerReplicaArgs() -
2275 params_info.NumDistributedArgs() -
2276 params_info.NumBroadcastArgs();
2277 TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size());
2278 src = variable_reads[var_num];
2279 src_output = 0;
2280 }
2281
2282 NodeDef def;
2283 def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape")));
2284 def.set_op("Shape");
2285 def.set_device(src->assigned_device_name());
2286 AddNodeAttr("T", src->output_type(src_output), &def);
2287 AddNodeAttr("out_type", DT_INT64, &def);
2288 MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
2289
2290 Status status;
2291 Node* shape_node = graph->AddNode(def, &status);
2292 if (!status.ok()) return status;
2293 dynamic_shape_nodes->push_back(shape_node);
2294
2295 shape_node->set_assigned_device_name(src->assigned_device_name());
2296 graph->AddEdge(src, src_output, shape_node, 0);
2297 }
2298 }
2299 return Status::OK();
2300 }
2301
2302 // Builds a TPUCompile node that compiles the bodies of the function call
2303 // `nodes`.
BuildCompileNode(const Node * replicate_node,const NameAttrList & function,uint64 library_fingerprint,const ParameterInfo & params_info,const std::vector<InferredShape> & arg_shapes,const DataTypeVector & arg_types,const std::vector<Node * > & guaranteed_constant_nodes,const string & session_handle,const std::vector<xla::OpSharding> & arg_sharding,const std::vector<bool> & arg_fast_mem,const std::vector<std::string> & arg_names,const std::vector<xla::OpSharding> & retval_sharding,int num_cores_per_replica,const string & compile_device,const xla::DeviceAssignment * xla_device_assignment,const std::vector<Node * > & dynamic_shape_nodes,Graph * graph,Node ** compile_node,int64 autotuner_thresh)2304 Status DistributedTPURewritePass::BuildCompileNode(
2305 const Node* replicate_node, const NameAttrList& function,
2306 uint64 library_fingerprint, const ParameterInfo& params_info,
2307 const std::vector<InferredShape>& arg_shapes,
2308 const DataTypeVector& arg_types,
2309 const std::vector<Node*>& guaranteed_constant_nodes,
2310 const string& session_handle,
2311 const std::vector<xla::OpSharding>& arg_sharding,
2312 const std::vector<bool>& arg_fast_mem,
2313 const std::vector<std::string>& arg_names,
2314 const std::vector<xla::OpSharding>& retval_sharding,
2315 int num_cores_per_replica, const string& compile_device,
2316 const xla::DeviceAssignment* xla_device_assignment,
2317 const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
2318 Node** compile_node, int64 autotuner_thresh) {
2319 VLOG(1) << "BuildCompileNode";
2320
2321 tpu::TPUCompileMetadataProto proto;
2322 proto.set_num_replicas(params_info.NumReplicas());
2323 proto.set_num_cores_per_replica(num_cores_per_replica);
2324 proto.set_function_library_fingerprint(library_fingerprint);
2325 proto.set_enable_automatic_model_parallelism(
2326 enable_cross_replica_sharding_mirrored_variables_);
2327 const bool use_spmd =
2328 UseSpmdForXlaPartitioning(replicate_node) && allow_xla_spmd_partition_ &&
2329 !absl::c_any_of(arg_sharding,
2330 [](const xla::OpSharding& s) {
2331 return s.type() == xla::OpSharding::MAXIMAL;
2332 }) &&
2333 !absl::c_any_of(retval_sharding, [](const xla::OpSharding& s) {
2334 return s.type() == xla::OpSharding::MAXIMAL;
2335 });
2336 proto.set_use_spmd_for_xla_partitioning(use_spmd);
2337 proto.set_broadcast_replicated_parameters_via_collectives(
2338 enable_xla_param_broadcast_);
2339
2340 // Get and fill padding map.
2341 if (replicate_node != nullptr) {
2342 TF_RETURN_IF_ERROR(
2343 FillPaddingMap(*replicate_node, proto.mutable_padding_maps()));
2344 xla::DebugOptions::StepMarkerLocation location;
2345 TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location));
2346 proto.set_step_marker_location(location);
2347 }
2348
2349 if (xla_device_assignment != nullptr) {
2350 TF_RETURN_IF_ERROR(
2351 xla_device_assignment->Serialize(proto.mutable_device_assignment()));
2352 }
2353
2354 const int num_args = arg_types.size();
2355 const int num_guaranteed_constants = guaranteed_constant_nodes.size();
2356 const int guaranteed_const_start_index = num_args - num_guaranteed_constants;
2357 TF_RET_CHECK(num_args == arg_shapes.size());
2358 TF_RET_CHECK(num_args == arg_sharding.size())
2359 << num_args << " != " << arg_sharding.size();
2360
2361 for (int i = 0; i < num_args; ++i) {
2362 tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args();
2363 DataType type = arg_types[i];
2364 const InferredShape& arg_shape = arg_shapes[i];
2365 arg->set_name(arg_names[i]);
2366 if (type == DT_RESOURCE) {
2367 TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i;
2368 arg->set_dtype(arg_shape.handle_type);
2369 arg_shape.handle_shape.AsProto(arg->mutable_shape());
2370 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE);
2371 arg->set_fast_mem(arg_fast_mem[i]);
2372 } else {
2373 arg->set_dtype(type);
2374 arg_shape.shape.AsProto(arg->mutable_shape());
2375 if (i >= guaranteed_const_start_index) {
2376 const DataType edge_type =
2377 guaranteed_constant_nodes[i - guaranteed_const_start_index]
2378 ->output_type(0);
2379 TF_RET_CHECK(type == edge_type)
2380 << "Arg type: " << type << " but edge type: " << edge_type;
2381 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT);
2382 } else {
2383 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
2384 }
2385 }
2386 // As long as the argument is not a per-replica one, it should have the same
2387 // value for all replicas. For clarity, we keep the (redundant) checks for
2388 // variable, broadcast and constant types, to prevent bugs in case new types
2389 // with different semantics are introduced in the future.
2390 arg->set_is_same_data_across_replicas(
2391 !params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) &&
2392 (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2393 params_info.IsConstantArg(i)));
2394 if (params_info.mirrored_variable_indices().count(i) > 0) {
2395 CHECK_EQ(type, DT_RESOURCE);
2396 arg->set_is_same_data_across_replicas(true);
2397 // 64-bit type is not shardable by XLA:TPU yet.
2398 bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 &&
2399 arg_shape.handle_type != DT_INT64 &&
2400 arg_shape.handle_type != DT_UINT64 &&
2401 arg_shape.handle_type != DT_DOUBLE);
2402 arg->set_enable_xla_sharding(
2403 sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE
2404 : tpu::TPUCompileMetadataProto::Arg::DISALLOWED);
2405 }
2406 *arg->mutable_sharding() = arg_sharding[i];
2407 }
2408
2409 const int num_retvals = retval_sharding.size();
2410 for (int i = 0; i < num_retvals; ++i) {
2411 *proto.add_retvals()->mutable_sharding() = retval_sharding[i];
2412 }
2413 proto.set_session_handle(session_handle);
2414
2415 DataTypeVector constant_arg_types;
2416 constant_arg_types.reserve(num_guaranteed_constants);
2417 for (int i = 0; i < num_guaranteed_constants; ++i) {
2418 constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]);
2419 }
2420 proto.set_xla_fusion_autotuner_thresh(autotuner_thresh);
2421
2422 string metadata;
2423 proto.SerializeToString(&metadata);
2424
2425 NodeDef def;
2426 def.set_name(UniqueNodeName("TPUReplicate/_compile", graph));
2427 def.set_op("TPUCompile");
2428 def.set_device(compile_device);
2429 if (replicate_node) {
2430 MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
2431 }
2432
2433 AddNodeAttr("function", function, &def);
2434 AddNodeAttr("num_computations", num_cores_per_replica, &def);
2435 AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()),
2436 &def);
2437 AddNodeAttr("metadata", metadata, &def);
2438 AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def);
2439
2440 Status status;
2441 *compile_node = graph->AddNode(def, &status);
2442 TF_RETURN_IF_ERROR(status);
2443
2444 (*compile_node)->set_assigned_device_name(compile_device);
2445
2446 for (int i = 0; i < dynamic_shape_nodes.size(); ++i) {
2447 graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i);
2448 }
2449
2450 for (int i = 0; i < num_guaranteed_constants; ++i) {
2451 graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node,
2452 dynamic_shape_nodes.size() + i);
2453 }
2454 VLOG(1) << "BuildCompileNode(): " << status;
2455 return status;
2456 }
2457
FindGuaranteedConstantInputs(const Node & node,const NameRangeMap & input_range_map,std::vector<Node * > * guaranteed_constants)2458 Status DistributedTPURewritePass::FindGuaranteedConstantInputs(
2459 const Node& node, const NameRangeMap& input_range_map,
2460 std::vector<Node*>* guaranteed_constants) {
2461 std::vector<const Edge*> input_edges;
2462 TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2463 std::pair<int, int> variables_limits =
2464 input_range_map.at("guaranteed_constants");
2465 for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2466 guaranteed_constants->push_back(input_edges[i]->src());
2467 }
2468 return Status::OK();
2469 }
2470
FindVariableInputs(const Node & node,const NameRangeMap & input_range_map,std::vector<VariableInput> * variables)2471 Status DistributedTPURewritePass::FindVariableInputs(
2472 const Node& node, const NameRangeMap& input_range_map,
2473 std::vector<VariableInput>* variables) {
2474 std::vector<const Edge*> input_edges;
2475 TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2476 std::pair<int, int> variables_limits = input_range_map.at("variables");
2477 for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2478 Node* node = input_edges[i]->src();
2479
2480 // Find the type of the VarHandleOp that feeds this node, looking through
2481 // any wrapping Enter or Switch nodes.
2482 while (node->IsEnter() || node->IsSwitch()) {
2483 TF_RETURN_IF_ERROR(node->input_node(0, &node));
2484 }
2485 // Fix the variable device assignment if it is requested with a full name.
2486 if (!node->has_assigned_device_name() &&
2487 !node->requested_device().empty()) {
2488 DeviceNameUtils::ParsedName var_device;
2489 TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(),
2490 &var_device));
2491 if (var_device.has_job && var_device.has_replica && var_device.has_task &&
2492 var_device.has_type && var_device.has_id) {
2493 node->set_assigned_device_name(node->requested_device());
2494 if (node != input_edges[i]->src() &&
2495 !input_edges[i]->src()->has_assigned_device_name()) {
2496 input_edges[i]->src()->set_assigned_device_name(
2497 node->requested_device());
2498 }
2499 }
2500 }
2501 if (node->type_string() == kVarHandleOp) {
2502 DataType dtype;
2503 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype));
2504 variables->push_back(VariableInput{input_edges[i]->src(),
2505 input_edges[i]->src_output(), dtype});
2506 } else if (node->type_string() == "_Arg") {
2507 std::vector<DataType> dtypes;
2508 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes));
2509 if (dtypes.empty()) {
2510 return errors::Internal(
2511 "_Arg node with resource output must have non-empty _handle_dtypes "
2512 "attribute: ",
2513 node->DebugString());
2514 }
2515 variables->push_back(VariableInput{
2516 input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]});
2517 } else {
2518 return errors::Internal(
2519 "Cannot handle variable input with node type other than VarHandleOp "
2520 "and _Arg: ",
2521 node->DebugString());
2522 }
2523 }
2524 return Status::OK();
2525 }
2526
2527 // Builds a NoOp node, used for building control dependencies.
BuildNoopNode(const Node & source,StringPiece name,const string & device,Graph * graph,Node ** node)2528 static Status BuildNoopNode(const Node& source, StringPiece name,
2529 const string& device, Graph* graph, Node** node) {
2530 NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
2531 if (!device.empty()) {
2532 builder.Device(device);
2533 }
2534 NodeDef def;
2535 TF_RETURN_IF_ERROR(builder.Finalize(&def));
2536
2537 Status status;
2538 *node = graph->AddNode(def, &status);
2539 if (!device.empty()) {
2540 (*node)->set_assigned_device_name(device);
2541 }
2542 return status;
2543 }
2544
ConnectHostComputeNodes(Node * compile_node,Node * key_placeholder_node,Graph * graph)2545 Status DistributedTPURewritePass::ConnectHostComputeNodes(
2546 Node* compile_node, Node* key_placeholder_node, Graph* graph) {
2547 // First find all the downstream nodes of the key placeholder node, since we
2548 // want to delete the connecting edges from key_placeholder_node which would
2549 // invalidate the out_nodes iterator.
2550 std::vector<Node*> host_transfer_nodes;
2551 for (Node* node : key_placeholder_node->out_nodes()) {
2552 host_transfer_nodes.push_back(node);
2553 }
2554 for (Node* node : host_transfer_nodes) {
2555 int input_index = -1;
2556 for (int i = 0; i < node->num_inputs(); i++) {
2557 const Edge* e;
2558 TF_RETURN_IF_ERROR(node->input_edge(i, &e));
2559 if (e->src() == key_placeholder_node) {
2560 if (input_index != -1) {
2561 return errors::Internal(
2562 "Node ", node->name(),
2563 " has multiple input edges from key placeholder node");
2564 }
2565 input_index = e->dst_input();
2566 }
2567 }
2568 if (input_index == -1) {
2569 return errors::Internal("Node ", node->name(),
2570 " has no input edge from key placeholder node");
2571 }
2572 const Edge* key_edge;
2573 TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge));
2574 graph->RemoveEdge(key_edge);
2575 graph->AddEdge(compile_node, 1, node, input_index);
2576 }
2577 graph->RemoveNode(key_placeholder_node);
2578 return Status::OK();
2579 }
2580
BuildVariableReads(absl::Span<const VariableInput> variables,Node * control_predecessor,Graph * graph,std::vector<Node * > * variable_reads)2581 Status DistributedTPURewritePass::BuildVariableReads(
2582 absl::Span<const VariableInput> variables, Node* control_predecessor,
2583 Graph* graph, std::vector<Node*>* variable_reads) {
2584 variable_reads->resize(variables.size());
2585 for (int i = 0; i < variables.size(); ++i) {
2586 string name =
2587 graph->NewName(strings::StrCat(variables[i].node->name(), "/read"));
2588 NodeDefBuilder builder(name, "ReadVariableOp",
2589 NodeDebugInfo(*variables[i].node));
2590
2591 builder.Attr("dtype", variables[i].dtype);
2592 builder.Device(variables[i].node->assigned_device_name());
2593 builder.Input(variables[i].node->name(), 0, DT_RESOURCE);
2594 NodeDef def;
2595 TF_RETURN_IF_ERROR(builder.Finalize(&def));
2596
2597 Status status;
2598 Node* read_node;
2599 (*variable_reads)[i] = read_node = graph->AddNode(def, &status);
2600 if (!status.ok()) return status;
2601
2602 read_node->set_requested_device(variables[i].node->requested_device());
2603 read_node->set_assigned_device_name(
2604 variables[i].node->assigned_device_name());
2605 graph->AddEdge(variables[i].node, variables[i].index, read_node, 0);
2606
2607 graph->AddControlEdge(control_predecessor, read_node);
2608 }
2609 return Status::OK();
2610 }
2611
ContainsResourceWriteOp(const Graph & graph,const FunctionLibraryDefinition & fld)2612 bool DistributedTPURewritePass::ContainsResourceWriteOp(
2613 const Graph& graph, const FunctionLibraryDefinition& fld) {
2614 for (const Node* n : graph.nodes()) {
2615 const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
2616 if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2617 VLOG(2) << "Found write resource op inside computation";
2618 return true;
2619 }
2620 }
2621 for (const string& func_name : fld.ListFunctionNames()) {
2622 const FunctionDef* func_def = fld.Find(func_name);
2623 for (const NodeDef& n : func_def->node_def()) {
2624 const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op());
2625 if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2626 VLOG(2) << "Found write resource op inside " << func_name;
2627 return true;
2628 }
2629 }
2630 }
2631 return false;
2632 }
2633
BuildVariableWrites(absl::Span<const VariableInput> variables,Node * control_successor,absl::Span<const VariableWrite> variable_writes,Graph * graph)2634 Status DistributedTPURewritePass::BuildVariableWrites(
2635 absl::Span<const VariableInput> variables, Node* control_successor,
2636 absl::Span<const VariableWrite> variable_writes, Graph* graph) {
2637 CHECK_EQ(variables.size(), variable_writes.size());
2638 for (int i = 0; i < variables.size(); ++i) {
2639 const VariableWrite& write = variable_writes[i];
2640 NodeDebugInfo debug_info(*variables[i].node);
2641
2642 auto name = [&](string suffix) {
2643 return graph->NewName(
2644 strings::StrCat(variables[i].node->name(), "/", suffix));
2645 };
2646
2647 Node* write_node;
2648 TF_RETURN_IF_ERROR(
2649 IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info)
2650 .AddAttr("dtype", variables[i].dtype)
2651 .Device(variables[i].node->assigned_device_name())
2652 .Build(graph, &write_node));
2653
2654 // Colocate the control flow with the variable.
2655 CondBuilder cb(variables[i].node->name(),
2656 variables[i].node->assigned_device_name(), debug_info,
2657 graph);
2658
2659 // Inputs to conditional.
2660 Node* switch_val;
2661 TF_RETURN_IF_ERROR(
2662 cb.AddInput("switch_val", variables[i].dtype,
2663 /*device=*/write.value->assigned_device_name(), debug_info,
2664 &switch_val));
2665 Node* switch_var;
2666 TF_RETURN_IF_ERROR(
2667 cb.AddInput("switch_var", DT_RESOURCE,
2668 /*device=*/variables[i].node->assigned_device_name(),
2669 debug_info, &switch_var));
2670 // Conditionally write the value back.
2671 graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0);
2672 graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0);
2673 graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1);
2674 // Add control edge from the write to value that will be merged. There is no
2675 // output from the write so this control edge ensures the write completes.
2676 graph->AddControlEdge(write_node, cb.switch_t());
2677
2678 graph->AddControlEdge(cb.control_successor(), control_successor);
2679
2680 graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0);
2681 graph->AddEdge(write.value, write.value_output, switch_val, 0);
2682 }
2683 return Status::OK();
2684 }
2685
2686 namespace {
2687
2688 // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
CreatePerHostDummyArgs(const InferredShape & raw_var_shape,const string & host_cpu_device,Node * var_read,absl::string_view name_prefix,Graph * graph)2689 xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
2690 const string& host_cpu_device,
2691 Node* var_read,
2692 absl::string_view name_prefix,
2693 Graph* graph) {
2694 Status status;
2695 DataType dtype;
2696 TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
2697
2698 if (!(dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
2699 dtype == DT_BOOL)) {
2700 return var_read;
2701 }
2702
2703 TensorShape var_shape;
2704 if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) &&
2705 !raw_var_shape.shape.AsTensorShape(&var_shape)) {
2706 return Status(error::FAILED_PRECONDITION, "Failed to read arg shape.");
2707 }
2708
2709 // Const - shape_as_tensor
2710 NodeDef shape_tensor_def;
2711 shape_tensor_def.set_op("Const");
2712 shape_tensor_def.set_name(graph->NewName(
2713 strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
2714 AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
2715 TensorProto tensorshape_proto;
2716 tensorshape_proto.set_dtype(DT_INT32);
2717 for (int i = 0; i < var_shape.dims(); ++i) {
2718 tensorshape_proto.add_int_val(var_shape.dim_size(i));
2719 }
2720 TensorShape shape_shape({var_shape.dims()});
2721 shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
2722 AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
2723 Node* shape_as_tensor_node = graph->AddNode(shape_tensor_def, &status);
2724 TF_RETURN_IF_ERROR(status);
2725
2726 // Const - initializer value
2727 NodeDef init_val_def;
2728 init_val_def.set_op("Const");
2729 init_val_def.set_name(graph->NewName(
2730 strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
2731 TensorProto tensor_proto;
2732 tensor_proto.set_dtype(dtype);
2733 if (dtype == DT_FLOAT) {
2734 tensor_proto.add_float_val(0.0f);
2735 } else if (dtype == DT_BFLOAT16) {
2736 tensor_proto.add_half_val(0);
2737 } else if (dtype == DT_INT32) {
2738 tensor_proto.add_int_val(0);
2739 } else if (dtype == DT_BOOL) {
2740 tensor_proto.add_bool_val(false);
2741 } else {
2742 return errors::Internal(
2743 "Unable to create zero-init dummy arg tensor for type ", dtype);
2744 }
2745 TensorShape scalar_shape({});
2746 scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
2747 AddNodeAttr("value", tensor_proto, &init_val_def);
2748 AddNodeAttr("dtype", dtype, &init_val_def);
2749 Node* init_val_node = graph->AddNode(init_val_def, &status);
2750 TF_RETURN_IF_ERROR(status);
2751
2752 // Fill node
2753 NodeDef fill_def;
2754 fill_def.set_op("Fill");
2755 fill_def.set_device(host_cpu_device);
2756 fill_def.set_name(
2757 graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
2758 AddNodeAttr("T", dtype, &fill_def);
2759 AddNodeAttr("index_type", DT_INT32, &fill_def);
2760 Node* fill_node = graph->AddNode(fill_def, &status);
2761 TF_RETURN_IF_ERROR(status);
2762 graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
2763 graph->AddEdge(init_val_node, 0, fill_node, 1);
2764
2765 return fill_node;
2766 }
2767
2768 // Helper that creates an IdentityN node containing all of the variables
2769 // values on CPU device 'device', except for those that will be split across
2770 // cores. (For split variables, this may cause additional cross-host data
2771 // transfers if more than 1 devices share the same variable partition on a
2772 // remote host.)
2773 //
2774 // A previous iteration of this code built one Identity node per TPU core per
2775 // variable, but this can rapidly become hundreds of thousands of nodes. This
2776 // formulation creates a single IdentityN node containing all of the variables
2777 // on each host. This may cause some unnecessary variable copies if only a
2778 // subset of hosts consume a given variable, but has the virtue of being
2779 // simple, and most models use pure replication where all cores want all the
2780 // variables.
2781 //
2782 // If enable_xla_param_broadcast is set to true, then per-host dummy
2783 // tensor args are created on all hosts except for the primary host. In this
2784 // scheme, the dummy args feed the IdentityN node on their local host. All
2785 // are zero-initialized.
2786 //
2787 // Returns the node and its output index to be consumed by TPUExecute for the
2788 // requested variable index.
CreateOrGetPerHostVariableCopy(const string & host_cpu_device,int64 var_index,const std::vector<Node * > & variable_reads,const DistributedTPURewritePass::ParameterInfo & params_info,const std::vector<xla::OpSharding> & arg_shardings,const Node & replicate_node,const bool enable_xla_param_broadcast,const int num_cores_per_replica,const std::vector<InferredShape> & arg_shapes,absl::flat_hash_map<string,std::vector<NodeOut>> * per_host_var_copies,Graph * graph)2789 xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
2790 const string& host_cpu_device, int64 var_index,
2791 const std::vector<Node*>& variable_reads,
2792 const DistributedTPURewritePass::ParameterInfo& params_info,
2793 const std::vector<xla::OpSharding>& arg_shardings,
2794 const Node& replicate_node, const bool enable_xla_param_broadcast,
2795 const int num_cores_per_replica,
2796 const std::vector<InferredShape>& arg_shapes,
2797 absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
2798 Graph* graph) {
2799 auto it = per_host_var_copies->find(host_cpu_device);
2800 if (it != per_host_var_copies->end()) {
2801 return it->second[var_index];
2802 }
2803
2804 // Variable replication relies on identification of a master.
2805 DeviceNameUtils::ParsedName parsed_device;
2806 TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
2807 TF_RET_CHECK(parsed_device.has_task);
2808 VLOG(1) << "Creating per-host IdentityN node for task " << parsed_device.task;
2809
2810 DataTypeVector dtypes;
2811 // Per-variable data source for TPUExecute.
2812 std::vector<NodeOut> index_mapping;
2813 index_mapping.reserve(variable_reads.size());
2814 dtypes.reserve(variable_reads.size());
2815 for (int64 i = 0; i < variable_reads.size(); ++i) {
2816 Node* read = variable_reads[i];
2817 int64 orig_arg_num =
2818 i + params_info.NumPerReplicaArgs() + params_info.NumBroadcastArgs();
2819 if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
2820 // We haven't built the IdentityN node yet, so temporarily use nullptr.
2821 index_mapping.push_back(
2822 NodeOut{nullptr, static_cast<int>(dtypes.size())});
2823 dtypes.push_back(read->output_type(0));
2824 } else {
2825 // Do not copy the full tensor of partitioned variables.
2826 index_mapping.push_back(NodeOut{read, 0});
2827 }
2828 }
2829 NodeDef ndef;
2830 ndef.set_name(
2831 graph->NewName(absl::StrCat(replicate_node.name(), "/_variable_copy")));
2832 ndef.set_op("IdentityN");
2833 ndef.set_device(host_cpu_device);
2834 AddNodeAttr("T", dtypes, &ndef);
2835 // TF meta-optimizer should skip this node for constant folding.
2836 AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
2837 Status s;
2838 Node* id_node = graph->AddNode(ndef, &s);
2839 TF_RETURN_IF_ERROR(s);
2840 id_node->set_assigned_device_name(host_cpu_device);
2841
2842 for (int64 i = 0; i < variable_reads.size(); ++i) {
2843 if (index_mapping[i].node == nullptr) {
2844 // Fill index_mapping with the actual IdentityN node.
2845 index_mapping[i].node = id_node;
2846 if (parsed_device.task == 0 || !enable_xla_param_broadcast) {
2847 // XLA broadcast mode is not enabled, so use the variable reads as args
2848 // to TPUExecuteOp. For task 0, variable reads are always used
2849 // regardless of XLA broadcast.
2850
2851 // Add the variable read edge to id_node.
2852 graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
2853 } else {
2854 // XLA broadcast mode is enabled. Create zero-valued dummy tensors to
2855 // use as variable args in the TPUExecuteOp.
2856 int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
2857 params_info.NumBroadcastArgs();
2858 if (num_cores_per_replica > 1) {
2859 LOG(WARNING) << "XLA parameter broadcast is only supported for "
2860 "replicated parameters. Falling back to "
2861 "non-broadcast mode for the parameter associated "
2862 "with the following variable read: "
2863 << variable_reads[i]->name();
2864 graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
2865 continue;
2866 }
2867 string dummy_name =
2868 strings::StrCat(variable_reads[i]->name(),
2869 absl::StrFormat("/dummy_%d", parsed_device.task));
2870 TF_ASSIGN_OR_RETURN(
2871 Node * var_read,
2872 CreatePerHostDummyArgs(arg_shapes[orig_arg_num], host_cpu_device,
2873 variable_reads[i], dummy_name, graph));
2874 graph->AddEdge(var_read, 0, id_node, index_mapping[i].index);
2875 }
2876 }
2877 }
2878
2879 auto result = index_mapping[var_index];
2880 (*per_host_var_copies)[host_cpu_device] = std::move(index_mapping);
2881 return result;
2882 }
2883
2884 } // namespace
2885
BuildExecuteNodes(const ParameterInfo & params_info,int num_tasks,int num_cores_per_replica,const Node & replicate_node,const std::vector<std::string> & arg_names,const DataTypeVector & arg_types,const std::vector<InferredShape> & arg_shapes,const DataTypeVector & retval_types,const std::vector<xla::OpSharding> & arg_shardings,const std::vector<xla::OpSharding> & retval_shardings,const std::vector<std::vector<string>> & tpu_device_names,Node * compile_node,const std::vector<Node * > & variable_reads,Node * control_predecessor,Node * control_successor,std::vector<VariableWrite> * variable_writes,Graph * graph)2886 Status DistributedTPURewritePass::BuildExecuteNodes(
2887 const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
2888 const Node& replicate_node, const std::vector<std::string>& arg_names,
2889 const DataTypeVector& arg_types,
2890 const std::vector<InferredShape>& arg_shapes,
2891 const DataTypeVector& retval_types,
2892 const std::vector<xla::OpSharding>& arg_shardings,
2893 const std::vector<xla::OpSharding>& retval_shardings,
2894 const std::vector<std::vector<string>>& tpu_device_names,
2895 Node* compile_node, const std::vector<Node*>& variable_reads,
2896 Node* control_predecessor, Node* control_successor,
2897 std::vector<VariableWrite>* variable_writes, Graph* graph) {
2898 VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString();
2899 TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size());
2900
2901 const int num_variables = variable_reads.size();
2902 const int num_retvals_per_replica = retval_types.size();
2903
2904 variable_writes->resize(num_variables);
2905
2906 std::vector<const Edge*> replicate_input_edges;
2907 TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
2908
2909 // Map from replicate input index to the fan_in node;
2910 absl::flat_hash_map<int, std::vector<Node*>> replicate_input_fan_in_nodes;
2911 absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes;
2912 absl::flat_hash_map<int, std::vector<int>>
2913 replicate_output_fan_out_dst_inputs;
2914 std::vector<Node*> to_be_removed_nodes;
2915
2916 for (const Edge* e : replicate_input_edges) {
2917 if (e->src()->type_string() == kTPUPartitionedInput) {
2918 int num_users = 0;
2919 for (const auto& ue : e->src()->out_edges()) {
2920 if (!ue->IsControlEdge()) ++num_users;
2921 }
2922 if (num_users != 1) {
2923 return tensorflow::errors::InvalidArgument(
2924 e->src()->name(), " must only have one user. Found ", num_users);
2925 }
2926 to_be_removed_nodes.push_back(e->src());
2927 std::vector<Node*>& nodes = replicate_input_fan_in_nodes[e->dst_input()];
2928 nodes.resize(num_cores_per_replica, nullptr);
2929 VLOG(2) << "allocate " << num_cores_per_replica
2930 << " for replicate_input_fan_in_nodes[" << e->dst_input() << "]";
2931 std::vector<const Edge*> fan_in_edges;
2932 TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges));
2933 TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica);
2934
2935 for (const Edge* fe : fan_in_edges) {
2936 nodes[fe->dst_input()] = fe->src();
2937 VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "]["
2938 << fe->dst_input() << "] = " << fe->src()->name();
2939 }
2940 }
2941 }
2942
2943 // Replicate output edges are sorted by replica id and then by outputs for
2944 // each replica. For example, if TPU Computation has outputs (output_1,
2945 // output_2, and output_3) and number of replicas is 2, then
2946 // replicate_output_edges order would be:
2947 // output_1_replica_1, output_2_replica_1, output_3_replica_1,
2948 // output_1_replica_2, output_2_replica_2, output_3_replica_2.
2949 std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
2950 nullptr);
2951 for (const Edge* edge : replicate_node.out_edges()) {
2952 if (edge->IsControlEdge()) continue;
2953
2954 int num_partitioned_outputs = 0;
2955
2956 for (const Edge* out_edge : edge->dst()->out_edges()) {
2957 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
2958 num_partitioned_outputs++;
2959 // Paths between replicate_node and replicate_output_fan_out_nodes:
2960 // ReplicateNode->TpuOutIdenity->kTPUPartitionedOutput->fan-out-nodes
2961 TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
2962 to_be_removed_nodes.push_back(edge->dst());
2963 to_be_removed_nodes.push_back(out_edge->dst());
2964 // Get the right replicated id from the replicate_output_edge.
2965 std::vector<Node*>& nodes =
2966 replicate_output_fan_out_nodes[edge->src_output()];
2967 std::vector<int>& dst_inputs =
2968 replicate_output_fan_out_dst_inputs[edge->src_output()];
2969 nodes.resize(num_cores_per_replica, nullptr);
2970 dst_inputs.resize(num_cores_per_replica, 0);
2971 TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
2972 num_cores_per_replica);
2973
2974 for (const Edge* fe : out_edge->dst()->out_edges()) {
2975 nodes[fe->src_output()] = fe->dst();
2976 dst_inputs[fe->src_output()] = fe->dst_input();
2977 VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
2978 << "][" << fe->src_output()
2979 << "] = " << fe->dst()->DebugString() << " with dst_input "
2980 << fe->dst_input();
2981 }
2982 }
2983 }
2984 replicate_output_edges[edge->src_output()] = edge;
2985 if (num_partitioned_outputs > 1) {
2986 return errors::InvalidArgument(
2987 "More than one TPUPartitionedOutput per replciated output.");
2988 }
2989 }
2990
2991 const int num_execute_args =
2992 arg_shardings.size() - params_info.NumGuaranteedConstants();
2993 // Inverts the arg_shardings and retval_shardings mappings to
2994 // form core -> {argument number} maps.
2995 std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
2996 for (int i = 0; i < num_execute_args; ++i) {
2997 const auto& sharding = arg_shardings[i];
2998 if (sharding.type() == xla::OpSharding::MAXIMAL) {
2999 int core = sharding.tile_assignment_devices(0);
3000 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3001 core_arg_nums[core].push_back(i);
3002 } else if (sharding.type() == xla::OpSharding::OTHER) {
3003 for (int64 core : sharding.tile_assignment_devices()) {
3004 core_arg_nums[core].push_back(i);
3005 }
3006 } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3007 for (int core = 0; core < num_cores_per_replica; ++core) {
3008 core_arg_nums[core].push_back(i);
3009 }
3010 } else {
3011 return tensorflow::errors::InvalidArgument(
3012 "Unsupported argument sharding for arg=", arg_names[i],
3013 " shape=", arg_shapes[i].shape.DebugString(), ": ",
3014 sharding.DebugString());
3015 }
3016 }
3017 std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
3018 for (int i = 0; i < retval_shardings.size(); ++i) {
3019 const auto& sharding = retval_shardings[i];
3020 if (sharding.type() == xla::OpSharding::MAXIMAL) {
3021 int core = sharding.tile_assignment_devices(0);
3022 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3023 core_retval_nums[core].push_back(i);
3024 } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3025 for (int core = 0; core < num_cores_per_replica; ++core) {
3026 core_retval_nums[core].push_back(i);
3027 }
3028 } else if (sharding.type() == xla::OpSharding::OTHER) {
3029 for (int64 core : sharding.tile_assignment_devices()) {
3030 core_retval_nums[core].push_back(i);
3031 }
3032 } else {
3033 return tensorflow::errors::InvalidArgument(
3034 "Unsupported argument sharding: ", sharding.DebugString());
3035 }
3036 }
3037
3038 // Maps host device name to a list of per-variable pairs (variable_copy_node,
3039 // output_index_of_copy_node).
3040 absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies;
3041
3042 // Mapping from original resource arg number to a second level map. Second
3043 // level map is from core id to output index of updated variable value.
3044 absl::flat_hash_map<int, absl::flat_hash_map<int, int>>
3045 orig_arg_num_to_output_index_mapping;
3046 // Mapping from retval index to a second level map. Second level map is from
3047 // core id to output index of sharded output value.
3048 std::unordered_map<int, std::unordered_map<int, int>>
3049 retval_index_to_output_index_mapping;
3050
3051 // Represents mapping of argument index of sharded input to each
3052 // TPUExecute node to its corresponding Split node and its output index
3053 // from which sharded input will be fed into TPUExecute node.
3054 std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs;
3055
3056 // Builds one TPUExecute node per core per replica.
3057 std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas());
3058 for (int core = 0; core < num_cores_per_replica; ++core) {
3059 DataTypeVector core_retval_types;
3060 for (int output : core_retval_nums[core]) {
3061 core_retval_types.push_back(retval_types[output]);
3062 }
3063 DataTypeVector core_arg_types;
3064 std::vector<int> core_variable_writes;
3065 for (int input : core_arg_nums[core]) {
3066 // Resource variables can be passed either by reference (as a DT_RESOURCE)
3067 // tensor or by value (as the variable's current value). Per-replica or
3068 // distributed resource arguments are always passed by reference and
3069 // broadcast variables are always passed by value.
3070 if (arg_types[input] == DT_RESOURCE &&
3071 !params_info.IsPerReplicaArg(input) &&
3072 !params_info.IsDistributedArg(input)) {
3073 DataType handle_type = arg_shapes[input].handle_type;
3074 TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type);
3075 core_arg_types.push_back(handle_type);
3076 int base = input - params_info.NumPerReplicaArgs() -
3077 params_info.NumDistributedArgs() -
3078 params_info.NumBroadcastArgs();
3079 // Variables passed by value will have a corresponding additional output
3080 // containing an updated value for the variable.
3081 core_variable_writes.push_back(base);
3082 core_retval_types.push_back(handle_type);
3083 } else {
3084 core_arg_types.push_back(arg_types[input]);
3085 }
3086 }
3087
3088 NodeDef def;
3089 def.set_op("TPUExecute");
3090 MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
3091 AddNodeAttr("Targs", core_arg_types, &def);
3092 AddNodeAttr("Tresults", core_retval_types, &def);
3093
3094 for (int64 replica = 0; replica < params_info.NumReplicas(); ++replica) {
3095 def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica,
3096 "_", core));
3097
3098 Status status;
3099 Node* node = graph->AddNode(def, &status);
3100 if (!status.ok()) return status;
3101 execute_nodes[replica].push_back(node);
3102
3103 node->set_assigned_device_name(tpu_device_names[replica][core]);
3104
3105 // Add control edges to ensure that execution happens after
3106 // `control_predecessor`, happens before `control_successor`, and is
3107 // triggered by evaluating any operator that depends on the original
3108 // TPUReplicate operator. See the comment at the top of the header file
3109 // for more details.
3110 graph->AddControlEdge(control_predecessor, node);
3111 graph->AddControlEdge(node, control_successor);
3112
3113 // Add data input edges.
3114 for (int64 i = 0; i < core_arg_nums[core].size(); ++i) {
3115 int64 orig_arg_num = core_arg_nums[core][i];
3116 VLOG(2) << " replica " << replica << " core " << core << " i " << i
3117 << " orig_arg_num " << orig_arg_num;
3118 if (params_info.IsPerReplicaArg(orig_arg_num) ||
3119 params_info.IsDistributedArg(orig_arg_num)) {
3120 // Per-replica input and distributed input
3121 int64 input_num = params_info.IsPerReplicaArg(orig_arg_num)
3122 ? replica * params_info.NumPerReplicaArgs() +
3123 core_arg_nums[core][i]
3124 : params_info.NumReplicas() *
3125 params_info.NumPerReplicaArgs() +
3126 core_arg_nums[core][i] -
3127 params_info.NumPerReplicaArgs();
3128
3129 const Edge* edge = replicate_input_edges[input_num];
3130 VLOG(2) << "replicate_input_edges[" << input_num << "]";
3131 DataType dtype = edge->src()->output_type(edge->src_output());
3132 if (dtype == DT_RESOURCE) {
3133 DataType handle_dtype = arg_shapes[orig_arg_num].handle_type;
3134 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(),
3135 handle_dtype) == kTpuAllTypes.end()) {
3136 return errors::InvalidArgument(
3137 "Unsupported resource variable data type for TPU: ",
3138 DataTypeString(handle_dtype), ", caused by output ",
3139 edge->src()->name(), ":", edge->src_output());
3140 }
3141 } else {
3142 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3143 kTpuAllTypes.end()) {
3144 return errors::InvalidArgument(
3145 "Unsupported data type for TPU: ", DataTypeString(dtype),
3146 ", caused by output ", edge->src()->name(), ":",
3147 edge->src_output());
3148 }
3149 }
3150 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3151 // Don't automatically add a split node when input node is
3152 // kTPUPartitionedInput
3153 if (edge->src()->type_string() == kTPUPartitionedInput) {
3154 VLOG(2) << "Connect "
3155 << replicate_input_fan_in_nodes[input_num][core]->name()
3156 << " to " << node->name() << " at " << i;
3157 graph->AddEdge(replicate_input_fan_in_nodes[input_num][core], 0,
3158 node, i);
3159 } else {
3160 if (dtype == DT_RESOURCE) {
3161 return errors::InvalidArgument(
3162 "Tiled sharding for per-replica DT_RESOURCE input must",
3163 "be TPUPartitionedInput. Here got ",
3164 edge->src()->type_string());
3165 }
3166 const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
3167
3168 // Create or get the Split node.
3169 TF_ASSIGN_OR_RETURN(
3170 ShardedInputInfo sharded_input_info,
3171 CreateOrGetSplitNodesForInputSharding(
3172 sharding, orig_arg_num, dtype,
3173 arg_shapes[orig_arg_num].handle_shape, replica,
3174 edge->src_output(), edge->src(), control_predecessor,
3175 graph, &input_index_to_sharded_inputs));
3176 NodeOut split_node_and_index =
3177 sharded_input_info.sharded_inputs.at(core);
3178 // Connect with Split node output.
3179 graph->AddEdge(split_node_and_index.node,
3180 split_node_and_index.index, node, i);
3181 }
3182 } else if (edge->src()->type_string() == kTPUPartitionedInput &&
3183 arg_shardings[orig_arg_num].type() ==
3184 xla::OpSharding::REPLICATED) {
3185 graph->AddEdge(replicate_input_fan_in_nodes[input_num][core], 0,
3186 node, i);
3187 } else {
3188 graph->AddEdge(edge->src(), edge->src_output(), node, i);
3189 }
3190 } else if (params_info.IsBroadcastArg(orig_arg_num)) {
3191 // Broadcast input.
3192 int64 input_num = params_info.FirstBroadcastArgFromHost() +
3193 core_arg_nums[core][i] -
3194 params_info.NumPerReplicaArgs() -
3195 params_info.NumDistributedArgs();
3196 const Edge* edge = replicate_input_edges[input_num];
3197 DataType dtype = edge->src()->output_type(edge->src_output());
3198 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3199 kTpuAllTypes.end()) {
3200 return errors::InvalidArgument(
3201 "Unsupported data type for TPU: ", DataTypeString(dtype),
3202 ", caused by output ", edge->src()->name(), ":",
3203 edge->src_output());
3204 }
3205 graph->AddEdge(edge->src(), edge->src_output(), node, i);
3206 } else {
3207 // Variable input.
3208 int64 variable_num = orig_arg_num - params_info.NumPerReplicaArgs() -
3209 params_info.NumDistributedArgs() -
3210 params_info.NumBroadcastArgs();
3211 TF_RET_CHECK(variable_num < num_variables);
3212
3213 Node* variable_read = variable_reads[variable_num];
3214 DataType dtype = variable_read->output_type(0);
3215 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3216 kTpuAllTypes.end()) {
3217 return errors::InvalidArgument(
3218 "Unsupported resource variable data type for TPU: ",
3219 DataTypeString(dtype), ", caused by ReadVariableOp ",
3220 variable_read->DebugString());
3221 }
3222 DeviceNameUtils::ParsedName requested_device;
3223 string requested = variable_read->requested_device();
3224 TF_RET_CHECK(
3225 DeviceNameUtils::ParseFullName(requested, &requested_device));
3226 if (requested_device.type != "TPU") {
3227 // Stage the value via the CPU device on the remote host. The graph
3228 // partitioner will introduce an intermediate copy rather than
3229 // copying the same tensor multiple times across the network, and we
3230 // would prefer that intermediate copy to be in host memory to avoid
3231 // running out of memory if the TPUExecute op on the staging device
3232 // starts running before the _Send ops to the other TPU devices on
3233 // the same host complete. We don't do this if the variables are
3234 // already placed on TPU, otherwise it will cause an unnecessary
3235 // round trip copy.
3236 // TODO(b/79580121): give each replica its own on-device variable
3237 // replica and then delete this code.
3238 string device;
3239 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3240 tpu_device_names[replica][core], &device));
3241 TF_ASSIGN_OR_RETURN(
3242 auto var_data,
3243 CreateOrGetPerHostVariableCopy(
3244 device, variable_num, variable_reads, params_info,
3245 arg_shardings, replicate_node, enable_xla_param_broadcast_,
3246 num_cores_per_replica, arg_shapes, &per_host_var_copies,
3247 graph));
3248
3249 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3250 const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
3251 // Create or get the Split node.
3252 TF_ASSIGN_OR_RETURN(
3253 ShardedInputInfo sharded_input_info,
3254 CreateOrGetSplitNodesForInputSharding(
3255 sharding, orig_arg_num,
3256 arg_shapes[orig_arg_num].handle_type,
3257 arg_shapes[orig_arg_num].handle_shape, replica,
3258 var_data.index, var_data.node, control_predecessor, graph,
3259 &input_index_to_sharded_inputs));
3260 NodeOut split_node_and_index =
3261 sharded_input_info.sharded_inputs[core];
3262 // Connect with Split node output.
3263 graph->AddEdge(split_node_and_index.node,
3264 split_node_and_index.index, node, i);
3265
3266 } else {
3267 graph->AddEdge(var_data.node, var_data.index, node, i);
3268 }
3269 } else {
3270 graph->AddEdge(variable_reads[variable_num], 0, node, i);
3271 }
3272 }
3273 }
3274
3275 // Adds a program input edge from the compiler.
3276 graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1);
3277
3278 // Add data output edges.
3279 int num_outputs = core_retval_nums[core].size();
3280 for (int i = 0; i < num_outputs; ++i) {
3281 int output_num =
3282 replica * num_retvals_per_replica + core_retval_nums[core][i];
3283 const auto& sharding = retval_shardings[core_retval_nums[core][i]];
3284 if (sharding.type() == xla::OpSharding::OTHER) {
3285 int retval_index = core_retval_nums[core][i];
3286 retval_index_to_output_index_mapping[retval_index][core] = i;
3287 bool is_last_core =
3288 core ==
3289 *std::max_element(sharding.tile_assignment_devices().begin(),
3290 sharding.tile_assignment_devices().end());
3291 bool isPartitionOutNode = false;
3292
3293 const Edge* e = replicate_output_edges[output_num];
3294 const Edge* e_out;
3295 for (const Edge* out_edge : e->dst()->out_edges()) {
3296 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3297 isPartitionOutNode = true;
3298 e_out = out_edge;
3299 }
3300 }
3301 if (isPartitionOutNode) {
3302 graph->AddEdge(
3303 node, i, replicate_output_fan_out_nodes[output_num][core],
3304 replicate_output_fan_out_dst_inputs[output_num][core]);
3305 VLOG(2) << "Connect " << node->name() << " at " << i << " to "
3306 << replicate_output_fan_out_nodes[output_num][core]->name()
3307 << " at "
3308 << replicate_output_fan_out_dst_inputs[output_num][core];
3309 if (is_last_core) {
3310 graph->RemoveEdge(e);
3311 graph->RemoveEdge(e_out);
3312 }
3313 continue;
3314 }
3315
3316 // Do this in the iteration of last core in tile assignment, so all
3317 // TPUExecute nodes have been created.
3318 if (!is_last_core) {
3319 continue;
3320 }
3321
3322 // Add a Concat node.
3323 std::vector<NodeOut> orig_inputs;
3324 for (int64 tile_index = 0;
3325 tile_index < sharding.tile_assignment_devices_size();
3326 ++tile_index) {
3327 int64 last_tile_dim_size =
3328 *sharding.tile_assignment_dimensions().rbegin();
3329 if (sharding.replicate_on_last_tile_dim() &&
3330 tile_index % last_tile_dim_size != 0) {
3331 continue;
3332 }
3333 int64 core_id = sharding.tile_assignment_devices(tile_index);
3334 int core_retval_index =
3335 retval_index_to_output_index_mapping[retval_index][core_id];
3336 orig_inputs.push_back(
3337 NodeOut{execute_nodes[replica][core_id],
3338 static_cast<int>(
3339 core_retval_nums[core_id][core_retval_index])});
3340 }
3341 DataType dtype = e->src()->output_type(e->src_output());
3342 TF_ASSIGN_OR_RETURN(
3343 Node * concat_node,
3344 CreateConcatNodesForRetval(
3345 sharding, dtype, /*inferred_shape*/ PartialTensorShape(),
3346 replica, orig_inputs, graph, /*device=*/""));
3347
3348 const Edge* edge = replicate_output_edges[output_num];
3349 Node* dst = edge->dst();
3350 int dst_input = edge->dst_input();
3351 graph->RemoveEdge(edge);
3352 graph->AddEdge(concat_node, 0, dst, dst_input);
3353
3354 continue;
3355 }
3356
3357 // If this is a replicated output, outputs on all cores will be the
3358 // same, and we only take the output from core 0.
3359 if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3360 continue;
3361 }
3362
3363 // If output has maximal sharding, make sure we only use output from
3364 // TPUExecute node with logical core id equal to core id defined by the
3365 // xla sharding.
3366 if (sharding.type() == xla::OpSharding::MAXIMAL &&
3367 core != sharding.tile_assignment_devices(0)) {
3368 continue;
3369 }
3370
3371 const Edge* replicate_edge_to_replace =
3372 replicate_output_edges[output_num];
3373 Node* dst = replicate_edge_to_replace->dst();
3374 int dst_input = replicate_edge_to_replace->dst_input();
3375 graph->RemoveEdge(replicate_edge_to_replace);
3376 graph->AddEdge(node, i, dst, dst_input);
3377 }
3378
3379 // Feed the updated variable values from the first replica to the
3380 // variable write nodes.
3381 if (replica == 0) {
3382 for (int i = 0; i < core_variable_writes.size(); ++i) {
3383 int orig_arg_num =
3384 core_variable_writes[i] + params_info.NumPerReplicaArgs() +
3385 params_info.NumDistributedArgs() + params_info.NumBroadcastArgs();
3386 const auto& sharding = arg_shardings[orig_arg_num];
3387 // If this is a tiling sharded variable, concat variable updates from
3388 // all cores.
3389 if (sharding.type() == xla::OpSharding::OTHER) {
3390 orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i;
3391
3392 // Do this in the iteration of last core in tile assignment, so all
3393 // TPUExecute nodes have been created.
3394 if (core !=
3395 *std::max_element(sharding.tile_assignment_devices().begin(),
3396 sharding.tile_assignment_devices().end())) {
3397 continue;
3398 }
3399
3400 // Add a Concat node.
3401 std::vector<NodeOut> orig_inputs;
3402 for (int64 tile_index = 0;
3403 tile_index < sharding.tile_assignment_devices_size();
3404 ++tile_index) {
3405 int64 last_tile_dim_size =
3406 *sharding.tile_assignment_dimensions().rbegin();
3407 if (sharding.replicate_on_last_tile_dim() &&
3408 tile_index % last_tile_dim_size != 0) {
3409 continue;
3410 }
3411 int64 core_id = sharding.tile_assignment_devices(tile_index);
3412 int core_retval_num =
3413 orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
3414 orig_inputs.push_back(
3415 NodeOut{execute_nodes[0][core_id],
3416 static_cast<int>(core_retval_nums[core_id].size() +
3417 core_retval_num)});
3418 }
3419
3420 // Use the variable read's device for the concat. They should both
3421 // be collocated with the variable.
3422 absl::string_view device =
3423 variable_reads[core_variable_writes[i]]->assigned_device_name();
3424 TF_ASSIGN_OR_RETURN(
3425 Node * concat_node,
3426 CreateConcatNodesForRetval(
3427 sharding, arg_shapes[orig_arg_num].handle_type,
3428 arg_shapes[orig_arg_num].handle_shape, replica, orig_inputs,
3429 graph, device));
3430 // Populate VariableWrite.
3431 VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3432 write.value = concat_node;
3433 write.value_output = 0;
3434 write.predicate = compile_node;
3435 write.predicate_output = num_cores_per_replica + core + 1;
3436
3437 continue;
3438 }
3439
3440 // If this is a replicated variable, outputs on all cores will be the
3441 // same, and we only take the output from core 0 for the varialbe
3442 // update.
3443 if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3444 continue;
3445 }
3446 VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3447 write.value = node;
3448 write.value_output = num_outputs + i;
3449 write.predicate = compile_node;
3450 write.predicate_output = num_cores_per_replica + core + 1;
3451 }
3452 }
3453 }
3454 }
3455
3456 for (Node* node : to_be_removed_nodes) {
3457 graph->RemoveNode(node);
3458 }
3459 return Status::OK();
3460 }
3461
CopyOutsideCompilationNodes(int replica_index,const std::vector<Node * > & outside_compilation_nodes,const DeviceNameUtils::ParsedName & tpu_device,const DeviceNameUtils::ParsedName & partial_device,NodeToNodeReplicasMap * node_images,Graph * graph)3462 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
3463 int replica_index, const std::vector<Node*>& outside_compilation_nodes,
3464 const DeviceNameUtils::ParsedName& tpu_device,
3465 const DeviceNameUtils::ParsedName& partial_device,
3466 NodeToNodeReplicasMap* node_images, Graph* graph) {
3467 for (Node* node : outside_compilation_nodes) {
3468 NodeDef image_def = node->def();
3469 MergeDebugInfo(NodeDebugInfo(node->def()), &image_def);
3470 const string suffix = strings::StrCat("/R", replica_index);
3471 // In addition to node name, make the frame name unique to avoid multiple
3472 // LoopCond nodes in one frame.
3473 TF_RETURN_IF_ERROR(
3474 AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def));
3475 Status status;
3476 Node* image = graph->AddNode(image_def, &status);
3477 image->AddAttr(kXlaReplicaIdAttrName, replica_index);
3478 TF_RETURN_IF_ERROR(status);
3479 if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) {
3480 TF_RETURN_IF_ERROR(
3481 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image));
3482 } else {
3483 const string& original_device_string =
3484 node->assigned_device_name().empty() ? node->requested_device()
3485 : node->assigned_device_name();
3486 DeviceNameUtils::ParsedName device;
3487 TF_RET_CHECK(
3488 DeviceNameUtils::ParseFullName(original_device_string, &device));
3489 // If the requested device can be merged with the replica's host device,
3490 // then do so. For example, if the requested device is "/CPU:0" or
3491 // "/GPU:0" then it will be placed on the CPU/GPU of the host where this
3492 // replica is running. But if the requested device is
3493 // "/task:3/replica:2/CPU:0" then it will be placed on that task/replica.
3494 if (DeviceNameUtils::IsSpecification(device, partial_device)) {
3495 TF_RETURN_IF_ERROR(
3496 DeviceNameUtils::MergeDevNames(&device, partial_device));
3497 }
3498 image->set_requested_device(DeviceNameUtils::ParsedNameToString(device));
3499 }
3500 std::vector<Node*>& node_image_vector = (*node_images)[node];
3501 node_image_vector.resize(replica_index + 1);
3502 node_image_vector[replica_index] = image;
3503 }
3504 return Status::OK();
3505 }
3506
ReplicateOutsideCompilationNodes(const std::vector<std::vector<string>> & tf_device_assignment,const HostComputeCoreMap & host_compute_core,const OutsideCompilationNodeMap & outside_compilation_nodes,NodeToNodeReplicasMap * node_images,Graph * graph)3507 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes(
3508 const std::vector<std::vector<string>>& tf_device_assignment,
3509 const HostComputeCoreMap& host_compute_core,
3510 const OutsideCompilationNodeMap& outside_compilation_nodes,
3511 NodeToNodeReplicasMap* node_images, Graph* graph) {
3512 // Iterate over replicas.
3513 for (int i = 0; i < tf_device_assignment.size(); ++i) {
3514 const auto& core_devices = tf_device_assignment[i];
3515 for (const auto& oc_cluster_iter : outside_compilation_nodes) {
3516 const string& oc_cluster_name = oc_cluster_iter.first;
3517 const auto& oc_cluster_nodes = oc_cluster_iter.second;
3518 // We previously validated that host_compute_core contains an entry for
3519 // each cluster.
3520 int core = host_compute_core.at(oc_cluster_name);
3521 TF_RET_CHECK(core >= 0 && core < core_devices.size());
3522 // tpu_device is the device the HostCompute XLA Op for this cluster runs
3523 // on.
3524 DeviceNameUtils::ParsedName tpu_device;
3525 TF_RET_CHECK(
3526 DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device));
3527 // partial_device contains the replica and task but not the type.
3528 DeviceNameUtils::ParsedName partial_device = tpu_device;
3529 partial_device.has_type = false;
3530 partial_device.has_id = false;
3531
3532 if (tf_device_assignment.size() == 1) {
3533 // With a single replica don't copy any nodes just put the original
3534 // nodes into the image map. We leave the device placement alone, except
3535 // that we have to fill in the correct core for the host send and
3536 // receive nodes.
3537 for (Node* node : oc_cluster_nodes) {
3538 (*node_images)[node] = {node};
3539 node->AddAttr(kXlaReplicaIdAttrName, 0);
3540 if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
3541 TF_RETURN_IF_ERROR(
3542 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node));
3543 }
3544 }
3545 } else {
3546 // Iterate over outside_compilation clusters in this computation, adding
3547 // all the nodes with appropriate device assignments.
3548 TF_RETURN_IF_ERROR(
3549 CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device,
3550 partial_device, node_images, graph));
3551 }
3552 }
3553 }
3554 return Status::OK();
3555 }
3556
CopyOutsideCompilationEdges(const std::vector<Node * > & outside_compilation_nodes,const NodeToNodeReplicasMap & node_images,const std::unordered_map<string,Node * > outside_compilation_inputs,Graph * graph)3557 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges(
3558 const std::vector<Node*>& outside_compilation_nodes,
3559 const NodeToNodeReplicasMap& node_images,
3560 const std::unordered_map<string, Node*> outside_compilation_inputs,
3561 Graph* graph) {
3562 for (Node* node : outside_compilation_nodes) {
3563 const auto& images = node_images.at(node);
3564 // Make a copy of all edges and iterate on "in_edges", because we might
3565 // remove edges when iteratating through them.
3566 std::vector<const Edge*> in_edges(node->in_edges().begin(),
3567 node->in_edges().end());
3568 for (const Edge* edge : in_edges) {
3569 Node* src = edge->src();
3570 const auto iter = node_images.find(src);
3571 if (iter == node_images.end()) {
3572 if (images.size() > 1) {
3573 // The source node is a 'normal' node not part of any
3574 // rewrite. Broadcast the value to all replicas. (If images.size() ==
3575 // 1 the cluster is not replicated and we can leave the original edge
3576 // in place.)
3577 for (Node* dst : images) {
3578 graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
3579 }
3580 }
3581 continue;
3582 }
3583
3584 // The source node is a replicated outside_compilation node.
3585 const auto& src_images = iter->second;
3586 if (src_images.size() != images.size()) {
3587 return errors::InvalidArgument(
3588 "Graph contains an edge from node ", src->name(),
3589 " in an outside_compilation block replicated ", src_images.size(),
3590 " ways to node ", node->name(),
3591 " in an outside_compilation block replicated ", images.size(),
3592 " ways. Replication factors must match. Leave a comment on "
3593 "tracking bug b/76419636 if you need this to be supported.");
3594 }
3595 bool is_lifted_arg;
3596 string outside_compilation_cluster;
3597 if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg)
3598 .ok() &&
3599 GetNodeAttr(src->def(), kOutsideCompilationAttr,
3600 &outside_compilation_cluster)
3601 .ok()) {
3602 const auto input_iter =
3603 outside_compilation_inputs.find(outside_compilation_cluster);
3604 TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
3605 TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
3606 int dst_input = edge->dst_input();
3607 if (src_images.size() == 1) {
3608 graph->RemoveEdge(edge);
3609 }
3610 for (int i = 0; i < src_images.size(); ++i) {
3611 graph->AddEdge(input_iter->second, i, images[i], dst_input);
3612 }
3613 continue;
3614 }
3615
3616 bool is_placeholder_for_arg;
3617 string outside_compilation_input_attr;
3618 if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg,
3619 &is_placeholder_for_arg)
3620 .ok() &&
3621 GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName,
3622 &outside_compilation_input_attr)
3623 .ok()) {
3624 const auto input_iter =
3625 outside_compilation_inputs.find(outside_compilation_input_attr);
3626 TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
3627 TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
3628 int dst_input = edge->dst_input();
3629 if (src_images.size() == 1) {
3630 graph->RemoveEdge(edge);
3631 }
3632 for (int i = 0; i < src_images.size(); ++i) {
3633 graph->AddEdge(input_iter->second, i, images[i], dst_input);
3634 }
3635 continue;
3636 }
3637
3638 if (images.size() > 1) {
3639 // If images.size() == 1 neither cluster is replicated and we can
3640 // leave the original edges in place.
3641 for (int i = 0; i < src_images.size(); ++i) {
3642 graph->AddEdge(src_images[i], edge->src_output(), images[i],
3643 edge->dst_input());
3644 }
3645 }
3646 }
3647 for (const Edge* edge : node->out_edges()) {
3648 Node* dst = edge->dst();
3649 const auto iter = node_images.find(dst);
3650 if (iter == node_images.end()) {
3651 // The source node is a 'normal' node not part of any rewrite.
3652 if (edge->IsControlEdge()) {
3653 // Make the dst node have a control dependency on every replica.
3654 if (images.size() > 1) {
3655 for (int i = 0; i < images.size(); ++i) {
3656 graph->AddControlEdge(images[i], dst);
3657 }
3658 }
3659 // else the cluster is not replicated so we can leave the original
3660 // edge in place.
3661 } else {
3662 // The edge
3663 // is only valid if the outside_compilation block is not replicated.
3664 if (images.size() > 1) {
3665 return errors::InvalidArgument(
3666 "Graph contains an edge from node ", node->name(),
3667 " in an outside_compilation block replicated ", images.size(),
3668 " ways to node ", dst->name(),
3669 " that is not part of an outside_compilation block. Edges from "
3670 "outside_compilation to regular graph nodes are only supported "
3671 "for replication factors of 1. Leave a comment on tracking bug "
3672 "b/76419636 if you need this to be supported.");
3673 }
3674 // else the cluster is not replicated so we can leave the original
3675 // edge in place.
3676 }
3677 }
3678 // The case where src and dst are both in node_images is covered elsewhere
3679 // when iterating over in_edges of dst.
3680 }
3681 }
3682 return Status::OK();
3683 }
3684
ReplicateOutsideCompilationEdges(const OutsideCompilationNodeMap & outside_compilation_nodes,const NodeToNodeReplicasMap & node_images,const std::unordered_map<string,Node * > outside_compilation_inputs,Graph * graph)3685 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges(
3686 const OutsideCompilationNodeMap& outside_compilation_nodes,
3687 const NodeToNodeReplicasMap& node_images,
3688 const std::unordered_map<string, Node*> outside_compilation_inputs,
3689 Graph* graph) {
3690 for (const auto& oc_cluster_iter : outside_compilation_nodes) {
3691 TF_RETURN_IF_ERROR(
3692 CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images,
3693 outside_compilation_inputs, graph));
3694 }
3695 return Status::OK();
3696 }
3697
RemoveOutsideCompilationNodes(const NodeToNodeReplicasMap & node_images,Graph * graph)3698 /* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes(
3699 const NodeToNodeReplicasMap& node_images, Graph* graph) {
3700 for (const auto& iter : node_images) {
3701 if (iter.second.size() > 1) {
3702 // The cluster was replicated so remove the original node.
3703 Node* node = iter.first;
3704 graph->RemoveNode(node);
3705 }
3706 }
3707 return Status::OK();
3708 }
3709
// Lowers outside-compilation functional ops (While, If and function calls)
// that have at least one DT_RESOURCE data input, then finalizes the host side
// of outside compilation:
//   1. Repeatedly lowers such nodes until a pass finds none left. Functional
//      nodes created by a lowering inherit kOutsideCompilationAttr and
//      kXlaReplicaIdAttrName, so they are considered on the next pass.
//   2. Assigns every _XlaRecvAtHost/_XlaSendFromHost node to the host CPU
//      device derived from tpu_device_names[replica_id][0] of its cluster, and
//      clears its kTPUReplicateAttr/kOutsideCompilationAttr.
//   3. Removes the IdentityN nodes tagged with
//      kXlaOutsideCompilationInputsAttrName, connecting their consumers
//      directly to the corresponding producers.
//
// Args:
//   g: graph rewritten in place.
//   flib_def: function library used to recognize and lower function calls.
//   tpu_replicate_device_names_mapping: per-cluster (kTPUReplicateAttr value)
//     device names, indexed [replica_id][...].
// Returns an error if lowering fails, a required attribute is missing, or a
// cluster/replica id is not present in the mapping.
/* static */ Status
DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition& flib_def,
    const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) {
  // Iterate to a fixed point: lowering can create new functional nodes that
  // are themselves candidates for lowering.
  bool modified = false;
  do {
    std::vector<Node*> nodes_to_lower;
    for (Node* n : g->op_nodes()) {
      if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
        continue;
      }

      if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) {
        // Only lower functional ops with DT_RESOURCE input, because otherwise
        // placer will complain. For normal cases, lowering will cause slowdown
        // when related functions are huge (b/139037679).
        bool has_resource_input = false;
        for (const Edge* e : n->in_edges()) {
          if (!e->IsControlEdge() &&
              e->src()->output_type(e->src_output()) == DT_RESOURCE) {
            has_resource_input = true;
            break;
          }
        }
        if (has_resource_input) {
          nodes_to_lower.push_back(n);
        }
      }
    }

    modified = !nodes_to_lower.empty();

    // Lowers a single functional node `n` and propagates its outside
    // compilation cluster and replica id onto the nodes the lowering created.
    auto lower_functional_node = [&flib_def, &g](Node* n) -> Status {
      // Clear device assignment. Otherwise all lowered nodes will have
      // device assignment, which is not what we want.
      n->set_requested_device("");

      int replica_id;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));

      string outside_compilation_attr;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr,
                                     &outside_compilation_attr));

      // There are two different kinds of functional outside compilation nodes:
      // 1. Nodes that are in outside compilation blocks already. They are
      //    generated by FunctionalizeControlFlowForXlaPass, and only have
      //    attribute kOutsideCompilationAttr.
      // 2. Mirrored control flow built for outside compilation in functional
      //    nodes. They are generated by ExtractOutsideCompilationPass, and have
      //    both kOutsideCompilationAttr and kXlaHasHostTransferAttrName.
      // When lowering them, they need to be treated differently.
      // For 1), their body functions are always V1 functions written by users,
      // and their "control outputs" are control inputs of _Retval nodes. They
      // should be lowered as V1 functions.
      // For 2), we always add necessary "control outputs"
      // (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their
      // FunctionDef's. They should be lowered as V2 functions.
      bool is_host_side_mirrored_control_flow =
          HasNodeAttr(n->def(), kXlaHasHostTransferAttrName);

      // Node ids at or above this value (after the rewrite) are treated by
      // the loop below as nodes created by the rewrite.
      int num_node_ids = g->num_node_ids();
      bool is_call_node = IsFunctionCall(flib_def, *n);
      if (n->IsWhileNode()) {
        TF_RETURN_IF_ERROR(RewriteWhileNode(n, g,
                                            /*keep_node_fetchable=*/false));
      } else if (n->IsIfNode()) {
        TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false));
      } else {
        TF_RET_CHECK(is_call_node);
        // See comments for "is_host_side_mirrored_control_flow" above.
        // If this is a node that's in outside compilation block, lower it as
        // V1 function. This is controlled by removing
        // kLowerAsMultiDeviceFunctionAttr from the node.
        if (!is_host_side_mirrored_control_flow) {
          n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
        } else {
          // Mirrored control flow: force the attribute to true so the call is
          // lowered as a V2 (multi-device) function.
          n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
          n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
                     true);
        }
        TF_RETURN_IF_ERROR(
            RewriteFunctionCallNode(n, g, flib_def,
                                    /*keep_caller_fetchable=*/false));
      }

      // Visit every node created by the rewrite above.
      for (int i = num_node_ids; i < g->num_node_ids(); i++) {
        Node* node = g->FindNodeId(i);
        // FindNodeId can return null for an id whose node was removed.
        if (!node) {
          continue;
        }

        if (!is_call_node && is_host_side_mirrored_control_flow &&
            IsFunctionCall(flib_def, *node)) {
          // For If/While nodes, if they are host side mirrored control flow,
          // mark their body function calls with kXlaHasHostTransferAttrName
          // attribute to make sure we lower them as V2 function.
          node->AddAttr(kXlaHasHostTransferAttrName, true);
        }

        if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() ||
            node->IsIfNode()) {
          // Set kOutsideCompilationAttr attribute so we lower these
          // nested function call nodes later.
          node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr);
          // Set kXlaReplicaIdAttrName attribute so we know replica id when we
          // lower this function call node.
          node->AddAttr(kXlaReplicaIdAttrName, replica_id);
        } else if (node->type_string() == "_XlaRecvAtHost" ||
                   node->type_string() == "_XlaSendFromHost") {
          // For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they
          // have kXlaReplicaIdAttrName attribute so later we know which host
          // device to assign.
          node->AddAttr(kXlaReplicaIdAttrName, replica_id);
        }
      }
      return Status::OK();
    };

    for (Node* n : nodes_to_lower) {
      TF_RETURN_IF_ERROR(lower_functional_node(n));
    }
  } while (modified);

  // Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes.
  for (Node* n : g->op_nodes()) {
    if (n->type_string() != "_XlaRecvAtHost" &&
        n->type_string() != "_XlaSendFromHost") {
      continue;
    }

    // Look up the device names of the replicate cluster this node belongs to.
    string replicate;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate));
    auto iter = tpu_replicate_device_names_mapping.find(replicate);
    TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end());
    const auto& tpu_device_names = iter->second;

    // Place the node on the CPU device of the host that owns the first device
    // listed for its replica.
    int replica_id;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
    TF_RET_CHECK(replica_id < tpu_device_names.size());
    const string& tpu_device_name = tpu_device_names[replica_id][0];
    string host_device_name;
    TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
        tpu_device_name, &host_device_name));
    n->set_assigned_device_name(host_device_name);
    // We may run TPU rewrite passes again on the subgraphs of the resulting
    // graph. Clear kTPUReplicateAttr and kOutsideCompilationAttr for
    // "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make
    // sure that TPU rewrite passes take no effect on host-side subgraphs for
    // outside compilation.
    n->ClearAttr(kTPUReplicateAttr);
    n->ClearAttr(kOutsideCompilationAttr);
  }

  // Remove IdentityN nodes generated for outside compilation. IdentityN is
  // exempt from resource edge colocation, but here we do need input and output
  // for these IdentityN nodes to be colocated.
  // Collect them first so the graph is not mutated while iterating op_nodes().
  std::vector<Node*> identityn_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->type_string() == "IdentityN" &&
        HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) {
      identityn_nodes.push_back(n);
    }
  }
  for (Node* n : identityn_nodes) {
    // Snapshot the out-edges, since RemoveEdge below mutates n->out_edges().
    std::vector<const Edge*> out_edges(n->out_edges().begin(),
                                       n->out_edges().end());
    for (const Edge* e : out_edges) {
      // Only data edges are rewired. NOTE(review): control out-edges are not
      // re-created and disappear with the node below; confirm that is
      // intended.
      if (e->IsControlEdge()) {
        continue;
      }

      // IdentityN forwards input k to output k, so connect the consumer of
      // output k directly to whatever feeds input k.
      int src_output = e->src_output();
      const Edge* input_edge;
      TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge));
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      g->RemoveEdge(e);
      g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
    }
    g->RemoveNode(n);
  }

  return Status::OK();
}
3897
ParseHostComputeCores(const Node & replicate_node,const OutsideCompilationNodeMap & outside_compilation_nodes,HostComputeCoreMap * host_compute_core)3898 /* static */ Status DistributedTPURewritePass::ParseHostComputeCores(
3899 const Node& replicate_node,
3900 const OutsideCompilationNodeMap& outside_compilation_nodes,
3901 HostComputeCoreMap* host_compute_core) {
3902 std::vector<string> hc_core_string;
3903 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core",
3904 &hc_core_string));
3905 TF_RETURN_IF_ERROR(
3906 ParseHostComputeCoreList(hc_core_string, host_compute_core));
3907 for (const auto& iter : outside_compilation_nodes) {
3908 const string& oc_cluster_name = iter.first;
3909 if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) {
3910 // By default put host compute Ops on replicated core 0.
3911 (*host_compute_core)[oc_cluster_name] = 0;
3912 }
3913 }
3914 return Status::OK();
3915 }
3916
GetDeviceTopology(const DeviceSet & device_set,const Node & replicate_node,int * num_replicas,int * num_cores_per_replica,int * num_tasks,std::vector<std::vector<string>> * tf_device_assignment,std::unique_ptr<xla::DeviceAssignment> * xla_device_assignment,string * tpu_compilation_device)3917 /* static */ Status DistributedTPURewritePass::GetDeviceTopology(
3918 const DeviceSet& device_set, const Node& replicate_node, int* num_replicas,
3919 int* num_cores_per_replica, int* num_tasks,
3920 std::vector<std::vector<string>>* tf_device_assignment,
3921 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
3922 string* tpu_compilation_device) {
3923 TF_RETURN_IF_ERROR(
3924 GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas));
3925 if (*num_replicas < 1) {
3926 return errors::InvalidArgument("num_replicas must be >= 1, got ",
3927 *num_replicas);
3928 }
3929
3930 // Find the set of TPU devices in the TF job.
3931 // Indexed by [task number][tpu device number].
3932 std::vector<std::vector<Device*>> tpu_devices;
3933 int num_tpus_per_task;
3934 TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
3935 device_set, tpu_compilation_device,
3936 &num_tpus_per_task, &tpu_devices));
3937 *num_tasks = tpu_devices.size();
3938
3939 string topology;
3940 TF_RETURN_IF_ERROR(
3941 GetNodeAttr(replicate_node.attrs(), "topology", &topology));
3942 TF_RETURN_IF_ERROR(GetNodeAttr(
3943 replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica));
3944 std::vector<int> device_assignment;
3945 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment",
3946 &device_assignment));
3947
3948 // TODO(cwhipkey): since we can control multiple pods of different shapes
3949 // from a single worker, it may be desirable to propagate the remote device
3950 // information around (e.g., in DeviceAttributes). This can lead to the mesh
3951 // topology proto being leaked to cloud TPU users (e.g. through GetStatus
3952 // calls); this may be okay, but to be conservative, just assume that the
3953 // master session has the proper flags set.
3954
3955 // We do not initialize platform right now, but we can still retrieve the
3956 // TPU topology even with an uninitialized platform.
3957 auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
3958 /*initialize_platform=*/false);
3959 TF_RET_CHECK(tpu_platform);
3960 tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr());
3961 TF_RET_CHECK(num_tpus_per_task ==
3962 tpu_topology.LogicalDevicesPerHost(kTensorCore));
3963 TF_RETURN_IF_ERROR(BuildDeviceAssignment(
3964 tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas,
3965 *num_cores_per_replica, topology, device_assignment, tf_device_assignment,
3966 xla_device_assignment));
3967
3968 return Status::OK();
3969 }
3970
GetIOTypes(int num_replicas,const Node & replicate_node,FunctionLibraryRuntime * flr,Graph * graph,NameRangeMap * input_name_map,const NameAttrList ** function,std::unique_ptr<Graph> * computation,DataTypeVector * arg_types,DataTypeVector * retval_types,ParameterInfo * params_info)3971 /* static */ Status DistributedTPURewritePass::GetIOTypes(
3972 int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
3973 Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
3974 std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
3975 DataTypeVector* retval_types, ParameterInfo* params_info) {
3976 DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types;
3977 TF_RETURN_IF_ERROR(
3978 GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types));
3979 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs",
3980 &broadcast_input_types));
3981 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
3982 "Tguaranteed_constants",
3983 &guaranteed_constant_types));
3984 int num_distributed_vars;
3985 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
3986 "num_distributed_variables",
3987 &num_distributed_vars));
3988 const int num_per_replica_inputs = input_types.size() - num_distributed_vars;
3989
3990 if (num_per_replica_inputs % num_replicas != 0) {
3991 return errors::InvalidArgument(
3992 "Number of inputs to TPUReplicate (", num_per_replica_inputs,
3993 ") is not divisible by the number of replicas (", num_replicas, ").");
3994 }
3995
3996 int num_variables;
3997 TF_RETURN_IF_ERROR(
3998 GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables));
3999
4000 NameRangeMap output_name_map;
4001 TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(),
4002 input_name_map, &output_name_map));
4003
4004 TF_RETURN_IF_ERROR(
4005 GetNodeAttr(replicate_node.attrs(), "computation", function));
4006
4007 *computation = absl::make_unique<Graph>(graph->op_registry());
4008 TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp(
4009 **function, flr, computation->get(), arg_types, retval_types));
4010
4011 *params_info = ParameterInfo(
4012 num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars,
4013 broadcast_input_types.size(), num_variables,
4014 guaranteed_constant_types.size(), retval_types->size());
4015
4016 if (arg_types->size() != params_info->NumInputsToEachReplica()) {
4017 return errors::InvalidArgument(
4018 "Computation argument to TPUReplicate has wrong number of "
4019 "arguments. Expected ",
4020 params_info->NumInputsToEachReplica(), " inputs, got ",
4021 arg_types->size());
4022 }
4023 if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) {
4024 return errors::InvalidArgument(
4025 "Wrong number of outputs from TPUReplicate. Expected ",
4026 params_info->NumOutputsToHost(), " outputs, got ",
4027 replicate_node.num_outputs());
4028 }
4029 if (enable_cross_replica_sharding_mirrored_variables_) {
4030 std::vector<int> mirrored_variable_indices;
4031 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4032 TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
4033 &mirrored_variable_indices));
4034 for (int index : mirrored_variable_indices) {
4035 TF_RET_CHECK(params_info->IsPerReplicaArg(index) ||
4036 params_info->IsDistributedArg(index))
4037 << "Mirrored variables not categorized as per-replica arguments, "
4038 "index: "
4039 << index;
4040 params_info->mutable_mirrored_variable_indices()->insert(index);
4041 }
4042 }
4043 return Status::OK();
4044 }
4045
BuildSequencingNodes(const string & tpu_compilation_device,const Node & replicate_node,Graph * graph,Node ** host_transfer_sequencer,Node ** control_before,Node ** control_after)4046 /* static */ Status DistributedTPURewritePass::BuildSequencingNodes(
4047 const string& tpu_compilation_device, const Node& replicate_node,
4048 Graph* graph, Node** host_transfer_sequencer, Node** control_before,
4049 Node** control_after) {
4050 *host_transfer_sequencer = nullptr;
4051
4052 TF_RETURN_IF_ERROR(
4053 BuildNoopNode(replicate_node,
4054 graph->NewName(strings::StrCat(replicate_node.name(), "/",
4055 "control_before")),
4056 /*device=*/"", graph, control_before));
4057 for (const Edge* e : replicate_node.in_edges()) {
4058 if (!e->IsControlEdge()) {
4059 continue;
4060 }
4061 Node* predecessor = e->src();
4062 if (predecessor->IsSource()) continue;
4063 if (predecessor->type_string() == "NoOp" &&
4064 predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) {
4065 // The node is the sequencer for host transfer operations. Its control
4066 // dependency needs to be placed after the execute node, not before.
4067 if (*host_transfer_sequencer != nullptr) {
4068 return errors::Internal("Replicate node ", replicate_node.name(),
4069 " has two transfer sequencer nodes: ",
4070 (*host_transfer_sequencer)->name(), " and ",
4071 predecessor->name());
4072 }
4073 // Set the correct device to match the other sequencing nodes.
4074 predecessor->set_assigned_device_name(tpu_compilation_device);
4075 *host_transfer_sequencer = predecessor;
4076 } else {
4077 graph->AddControlEdge(predecessor, *control_before);
4078 }
4079 }
4080
4081 TF_RETURN_IF_ERROR(
4082 BuildNoopNode(replicate_node,
4083 graph->NewName(strings::StrCat(replicate_node.name(), "/",
4084 "control_after")),
4085 /*device=*/tpu_compilation_device, graph, control_after));
4086 for (Node* successor : replicate_node.out_nodes()) {
4087 if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) {
4088 graph->AddControlEdge(successor, *control_after);
4089 } else {
4090 graph->AddControlEdge(*control_after, successor);
4091 }
4092 }
4093 return Status::OK();
4094 }
4095
DealWithConstantsAndVariables(const Node & replicate_node,const NameRangeMap & input_name_map,Graph * graph,Node * host_transfer_sequencer,Node * control_before,Node * control_after,absl::Span<const VariableInput> variable_nodes,std::vector<Node * > * guaranteed_constant_nodes,std::vector<Node * > * variable_reads)4096 /* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables(
4097 const Node& replicate_node, const NameRangeMap& input_name_map,
4098 Graph* graph, Node* host_transfer_sequencer, Node* control_before,
4099 Node* control_after, absl::Span<const VariableInput> variable_nodes,
4100 std::vector<Node*>* guaranteed_constant_nodes,
4101 std::vector<Node*>* variable_reads) {
4102 TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs(
4103 replicate_node, input_name_map, guaranteed_constant_nodes));
4104
4105 TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph,
4106 variable_reads));
4107 // Add the control dependency from host transfer nodes.
4108 if (host_transfer_sequencer != nullptr) {
4109 graph->AddControlEdge(host_transfer_sequencer, control_after);
4110 }
4111 return Status::OK();
4112 }
4113
4114 /* static */ Status
BuildCompilationStatusReturnNodes(Node * replicate_node,Node * compile_node,Node ** control_after_compilation,Graph * graph)4115 DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
4116 Node* replicate_node, Node* compile_node, Node** control_after_compilation,
4117 Graph* graph) {
4118 const Edge* compilation_edge = nullptr;
4119 for (const auto* e : replicate_node->out_edges()) {
4120 if (e->IsControlEdge() &&
4121 e->dst()->type_string() == "TPUCompilationResult") {
4122 TF_RET_CHECK(compilation_edge == nullptr)
4123 << "Multiple compilation result nodes attached to the same replicate "
4124 "cluster.";
4125 compilation_edge = e;
4126 }
4127 }
4128
4129 // TODO(jpienaar): This should be checked by default, current tests not using
4130 // this are ones that use the "abort upon successful compile flag" which will
4131 // be removed. Leaving this in until then.
4132 if (compilation_edge != nullptr) {
4133 Node* compilation_status = compilation_edge->dst();
4134 const AttrValue* compile_status_cluster_attr =
4135 compilation_status->attrs().Find(kTPUCompilationResultAttr);
4136 TF_RET_CHECK(compile_status_cluster_attr != nullptr);
4137 const string& compile_status_cluster = compile_status_cluster_attr->s();
4138 TF_RET_CHECK(!compile_status_cluster.empty());
4139 const AttrValue* replicate_cluster_attr =
4140 replicate_node->attrs().Find(kTPUReplicateAttr);
4141 TF_RET_CHECK(replicate_cluster_attr != nullptr);
4142 const string& replicate_cluster = replicate_cluster_attr->s();
4143 TF_RET_CHECK(!replicate_cluster.empty());
4144 TF_RET_CHECK(compile_status_cluster == replicate_cluster);
4145
4146 TF_RETURN_IF_ERROR(
4147 ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status));
4148 graph->AddEdge(compile_node, 0, compilation_status, 0);
4149 }
4150
4151 NodeDef def;
4152 def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph));
4153 // Create an op to assert that compilation succeeded. The alternative would
4154 // have been to have each execute op check and return an error.
4155 def.set_op("TPUCompileSucceededAssert");
4156 MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
4157 Status status;
4158 Node* compile_succeeded = graph->AddNode(def, &status);
4159 compile_succeeded->set_assigned_device_name(
4160 compile_node->assigned_device_name());
4161 TF_RETURN_IF_ERROR(status);
4162 graph->AddEdge(compile_node, 0, compile_succeeded, 0);
4163
4164 // Build a sequencing node for when compilation has completed.
4165 TF_RETURN_IF_ERROR(
4166 BuildNoopNode(*replicate_node,
4167 graph->NewName(strings::StrCat(compile_node->name(), "/",
4168 "after_compilation")),
4169 /*device=*/"", graph, control_after_compilation));
4170 graph->AddControlEdge(compile_succeeded, *control_after_compilation);
4171
4172 return Status::OK();
4173 }
4174
4175 // Updates the head and tail outside compiled nodes so that nodes have the
4176 // correct device and removes the replication and outside compilation attributes
4177 // so that these nodes do not trigger further graph optimization passes.
UpdateHeadTailOutsideCompilation(const std::vector<std::vector<string>> & tf_device_assignment,const std::vector<Node * > & head_tail_outside_compilation_nodes)4178 /* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation(
4179 const std::vector<std::vector<string>>& tf_device_assignment,
4180 const std::vector<Node*>& head_tail_outside_compilation_nodes) {
4181 for (Node* node : head_tail_outside_compilation_nodes) {
4182 int replica_id;
4183 TF_RETURN_IF_ERROR(
4184 GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id));
4185 // Since we set the device, this will now run on a task other than 0. We
4186 // clear the two following attributes so that we don't trigger encapsulation
4187 // again on the remote host (which will fail due to a missing
4188 // _TPUReplicateMetadata node for the cluster).
4189 for (const Edge* e : node->in_edges()) {
4190 // Resource consuming ops should colocate with its resource input.
4191 if (e->src()->IsArg() &&
4192 e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4193 node->set_requested_device(tf_device_assignment[replica_id][0]);
4194 }
4195 }
4196 if (node->requested_device().empty()) {
4197 string cpu_device;
4198 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4199 tf_device_assignment[replica_id][0], &cpu_device));
4200 node->set_requested_device(cpu_device);
4201 }
4202 node->ClearAttr(kTPUReplicateAttr);
4203 node->ClearAttr(kOutsideCompilationAttr);
4204 }
4205 return Status::OK();
4206 }
4207
4208 // Performs the rewrite on a single TPUReplicate node.
RewriteTPUReplicateNode(const string & session_handle,const DeviceSet & device_set,Node * replicate_node,FunctionLibraryDefinition * flib_def,FunctionLibraryRuntime * flr,Node * host_compute_key_placeholder_node,const OutsideCompilationNodeMap & outside_compilation_nodes,const std::vector<Node * > & head_tail_outside_compilation_nodes,NodeToNodeReplicasMap * outside_compilation_node_images,Graph * graph,const GraphShapeInfo & shape_info,TPUReplicateDeviceNamesMapping * tpu_replicate_device_names_mapping,int64 autotuner_thresh)4209 /* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
4210 const string& session_handle, const DeviceSet& device_set,
4211 Node* replicate_node, FunctionLibraryDefinition* flib_def,
4212 FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
4213 const OutsideCompilationNodeMap& outside_compilation_nodes,
4214 const std::vector<Node*>& head_tail_outside_compilation_nodes,
4215 NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
4216 const GraphShapeInfo& shape_info,
4217 TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
4218 int64 autotuner_thresh) {
4219 VLOG(2) << "Rewriting node " << replicate_node->name();
4220
4221 // num_replicas and num_cores_per_replica are the 'virtual' replicas (copies
4222 // of the computation) and cores (virtual cores within computations) specified
4223 // by the user. They will be mapped to physical TPU cores below.
4224 int num_replicas;
4225 int num_cores_per_replica;
4226 int num_tasks; // Number of tasks.
4227 std::vector<std::vector<string>> tf_device_assignment;
4228 std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
4229 string tpu_compilation_device;
4230 TF_RETURN_IF_ERROR(GetDeviceTopology(
4231 device_set, *replicate_node, &num_replicas, &num_cores_per_replica,
4232 &num_tasks, &tf_device_assignment, &xla_device_assignment,
4233 &tpu_compilation_device));
4234
4235 TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation(
4236 tf_device_assignment, head_tail_outside_compilation_nodes));
4237
4238 string replicate;
4239 TF_RETURN_IF_ERROR(
4240 GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate));
4241 tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment);
4242
4243 NameRangeMap input_name_map;
4244 const NameAttrList* function;
4245 std::unique_ptr<Graph> computation;
4246 DataTypeVector arg_types, retval_types;
4247 ParameterInfo params_info;
4248 TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph,
4249 &input_name_map, &function, &computation,
4250 &arg_types, &retval_types, ¶ms_info));
4251
4252 std::vector<InferredShape> arg_shapes, retval_shapes;
4253 TF_RETURN_IF_ERROR(GetArgAndRetvalShapes(
4254 shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes));
4255
4256 TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica));
4257
4258 std::vector<xla::OpSharding> arg_sharding;
4259 std::vector<bool> arg_fast_mem;
4260 std::vector<std::string> arg_names;
4261 std::vector<xla::OpSharding> retval_sharding;
4262 TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores(
4263 num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
4264 retval_shapes, *computation, replicate_node, flr,
4265 allow_xla_spmd_partition_, &arg_sharding, &arg_fast_mem, &retval_sharding,
4266 &arg_names));
4267
4268 VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation,
4269 flib_def);
4270
4271 GraphDef graph_def;
4272 graph->ToGraphDef(&graph_def);
4273 FunctionLibraryDefinition reachable_functions =
4274 flib_def->ReachableDefinitions(graph_def);
4275 uint64 library_fingerprint;
4276
4277 TF_RETURN_IF_ERROR(
4278 FingerprintFunctionLibrary(reachable_functions, &library_fingerprint));
4279 VLOG(1) << "Fingerprint functions: "
4280 << absl::StrJoin(reachable_functions.ListFunctionNames(), ", ");
4281 VLOG(1) << "library_fingerprint: " << library_fingerprint;
4282
4283 // Builds trigger nodes that put barriers around the expansion of
4284 // TPUReplicate. In particular, we must guarantee:
4285 // a) variable reads happen after all predecessors of the original
4286 // TPUReplicate.
4287 // b) variable writes happen before all successors of the original
4288 // TPUReplicate.
4289 // c) all replicas execute, even if output tensors are only requested from
4290 // a subset of replicas. This is necessary both to ensure that variable
4291 // updates happen, but also Send/Recv will deadlock if only one half of
4292 // the communicating pair runs.
4293 Node* host_transfer_sequencer;
4294 Node* control_before;
4295 Node* control_after;
4296 TF_RETURN_IF_ERROR(BuildSequencingNodes(
4297 tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer,
4298 &control_before, &control_after));
4299
4300 // Build a vector of variable nodes that are inputs.
4301 std::vector<VariableInput> variable_inputs;
4302 TF_RETURN_IF_ERROR(
4303 FindVariableInputs(*replicate_node, input_name_map, &variable_inputs));
4304
4305 std::vector<Node*> guaranteed_constant_nodes;
4306 std::vector<Node*> variable_reads;
4307 TF_RETURN_IF_ERROR(DealWithConstantsAndVariables(
4308 *replicate_node, input_name_map, graph, host_transfer_sequencer,
4309 control_before, control_after, variable_inputs,
4310 &guaranteed_constant_nodes, &variable_reads));
4311
4312 // Builds Shape nodes that compute the dynamic shapes of arguments whose
4313 // shapes are not statically known.
4314 std::vector<Node*> dynamic_shape_nodes;
4315 TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes,
4316 params_info, variable_reads, graph,
4317 &dynamic_shape_nodes));
4318
4319 // Builds a TPUCompile node that compiles `clusters` on `compile_device`.
4320 Node* compile_node;
4321 TF_RETURN_IF_ERROR(BuildCompileNode(
4322 replicate_node, *function, library_fingerprint, params_info, arg_shapes,
4323 arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
4324 arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
4325 /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
4326 dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
4327
4328 // Compilation must be sequenced after the control node if the TPU computation
4329 // in a control-flow construct, such as a loop.
4330 graph->AddControlEdge(control_before, compile_node);
4331
4332 Node* control_after_compilation;
4333 TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes(
4334 replicate_node, compile_node, &control_after_compilation, graph));
4335
4336 std::vector<VariableWrite> variable_writes;
4337 TF_RETURN_IF_ERROR(BuildExecuteNodes(
4338 params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
4339 arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
4340 tf_device_assignment, compile_node, variable_reads,
4341 control_after_compilation, control_after, &variable_writes, graph));
4342 bool contains_resource_write_op =
4343 ContainsResourceWriteOp(*graph, reachable_functions);
4344
4345 VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op;
4346 // Skip conditional write if there is no resource writing op inside TPU
4347 // computation.
4348 if (contains_resource_write_op) {
4349 TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after,
4350 variable_writes, graph));
4351 }
4352
4353 if (host_compute_key_placeholder_node != nullptr) {
4354 TF_RETURN_IF_ERROR(ConnectHostComputeNodes(
4355 compile_node, host_compute_key_placeholder_node, graph));
4356 }
4357
4358 HostComputeCoreMap host_compute_core;
4359 TF_RETURN_IF_ERROR(ParseHostComputeCores(
4360 *replicate_node, outside_compilation_nodes, &host_compute_core));
4361 TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes(
4362 tf_device_assignment, host_compute_core, outside_compilation_nodes,
4363 outside_compilation_node_images, graph));
4364
4365 graph->RemoveNode(replicate_node);
4366 return Status::OK();
4367 }
4368
4369 // Adds sharded weight update optimization for each host training loop.
4370 //
4371 // For any host training loop found in the graph, TPUVariableReshard ops
4372 // are inserted to match the best layout chosen by the XLA.
4373 /* static */ Status
PerformHostTrainingLoopOptimization(Graph * graph,FunctionLibraryDefinition * flib_def,FunctionLibraryRuntime * flr)4374 DistributedTPURewritePass::PerformHostTrainingLoopOptimization(
4375 Graph* graph, FunctionLibraryDefinition* flib_def,
4376 FunctionLibraryRuntime* flr) {
4377 std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
4378 Status s = tpu::DetectHostTrainingLoop(
4379 /*current_function_name=*/nullptr,
4380 /*current_function_attr=*/nullptr, flib_def, graph, flr,
4381 &host_training_loops_info);
4382 if (!s.ok()) {
4383 VLOG(2) << "No valid host training loop found. Skipping sharded weight "
4384 << "update optimization.";
4385 return Status::OK();
4386 }
4387
4388 for (const auto& host_loop : host_training_loops_info) {
4389 const auto& function_name = host_loop.encapsulating_function_name;
4390 // `function_name` has value when host training loop is inside a
4391 // function call node. When host training loop is found inside a function
4392 // call node, then, in addition to adding TPUVariableReshard ops, function
4393 // library definition needs to be updated as well.
4394 if (function_name.has_value()) {
4395 const auto& function_attr = host_loop.encapsulating_function_attrs;
4396 TF_RET_CHECK(function_attr.has_value())
4397 << "Unable to find function attribute for function: "
4398 << *function_name;
4399
4400 const FunctionDef* function_def = flib_def->Find(*function_name);
4401 TF_RET_CHECK(function_def)
4402 << "Unable to find function : " << *function_name;
4403
4404 std::unique_ptr<FunctionBody> fbody;
4405 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
4406 *function_def, AttrSlice(&function_attr.value()), flib_def, &fbody));
4407 Graph* function_graph = fbody->graph;
4408 TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop));
4409 TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph,
4410 *function_name, flib_def));
4411 } else {
4412 TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop));
4413 }
4414 }
4415 return Status::OK();
4416 }
4417
PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph * graph)4418 Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible(
4419 Graph* graph) {
4420 ReverseDFS(*graph, {}, PlaceOpsOnTPU);
4421 return Status::OK();
4422 }
4423
Run(const GraphOptimizationPassOptions & options)4424 Status DistributedTPURewritePass::Run(
4425 const GraphOptimizationPassOptions& options) {
4426 VLOG(1) << "DistributedTPURewritePass::Run";
4427
4428 Graph* graph = options.graph->get();
4429
4430 VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph,
4431 options.flib_def);
4432
4433 const auto* config = &options.session_options->config;
4434 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
4435 new ProcessFunctionLibraryRuntime(
4436 nullptr, options.session_options->env, config,
4437 graph->versions().producer(), options.flib_def,
4438 config ? config->graph_options().optimizer_options()
4439 : OptimizerOptions()));
4440
4441 FunctionLibraryRuntime* flr =
4442 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
4443
4444 // This pass can only run in the session master, which should fill
4445 // in the device_set field to the options.
4446 TF_RET_CHECK(options.device_set != nullptr);
4447
4448 // Find all the replicate nodes before mutating the graph.
4449 std::vector<Node*> replicate_nodes;
4450 // Map from compiled subgraph cluster name to the outside_compilation nodes in
4451 // that cluster.
4452 std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes;
4453 std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes;
4454 TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes,
4455 &outside_compilation_nodes,
4456 &head_tail_outside_compilation_nodes));
4457
4458 if (replicate_nodes.empty()) {
4459 // Remove unused TPUPartitionedInput nodes.
4460 for (Node* n : graph->nodes()) {
4461 if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n);
4462 }
4463 return Status::OK();
4464 }
4465
4466 std::unordered_map<string, Node*> host_compute_key_placeholder_map;
4467 TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes(
4468 graph, replicate_nodes, &host_compute_key_placeholder_map));
4469
4470 GraphShapeInfo shape_info;
4471 TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{},
4472 flr->GetFunctionLibraryDefinition(),
4473 &shape_info));
4474 int64 autotuner_thresh = options.session_options->config.experimental()
4475 .xla_fusion_autotuner_thresh();
4476
4477 NodeToNodeReplicasMap outside_compilation_node_images;
4478 TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping;
4479 for (Node* node : replicate_nodes) {
4480 TF_RETURN_IF_ERROR(RewriteTPUReplicateNode(
4481 options.session_handle, *options.device_set, node, options.flib_def,
4482 flr, host_compute_key_placeholder_map[node->name()],
4483 outside_compilation_nodes[node->name()],
4484 head_tail_outside_compilation_nodes[node->name()],
4485 &outside_compilation_node_images, graph, shape_info,
4486 &tpu_replicate_device_names_mapping, autotuner_thresh));
4487 }
4488
4489 // Place the padding nodes generated by dynamic padder on the correct devices.
4490 // TODO(rxsang): Place padding ops on TPUs in
4491 // PlaceUnassignedDeviceNodesOnTPUIfPossible function.
4492 TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph));
4493
4494 std::unordered_map<string, Node*> outside_compilation_inputs;
4495 for (Node* n : graph->op_nodes()) {
4496 string lifted_arg_inputs_attr;
4497 if (n->type_string() == "IdentityN" &&
4498 GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName,
4499 &lifted_arg_inputs_attr)
4500 .ok()) {
4501 outside_compilation_inputs[lifted_arg_inputs_attr] = n;
4502 }
4503 }
4504 for (const auto& iter : outside_compilation_nodes) {
4505 TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges(
4506 iter.second, outside_compilation_node_images,
4507 outside_compilation_inputs, graph));
4508 }
4509 TF_RETURN_IF_ERROR(
4510 RemoveOutsideCompilationNodes(outside_compilation_node_images, graph));
4511 TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes(
4512 graph, *options.flib_def, tpu_replicate_device_names_mapping));
4513
4514 TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph));
4515 VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
4516 options.flib_def);
4517 VLOG(1) << "DistributedTPURewritePass::Run() finished";
4518
4519 if (enable_cross_replica_sharding_mirrored_variables_) {
4520 VLOG(1) << "Starting host training loop optimization.";
4521 VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph,
4522 options.flib_def);
4523 TF_RETURN_IF_ERROR(
4524 PerformHostTrainingLoopOptimization(graph, options.flib_def, flr));
4525 VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph,
4526 options.flib_def);
4527 VLOG(1) << "Host training loop optimization finished.";
4528 }
4529
4530 return Status::OK();
4531 }
4532
4533 bool DistributedTPURewritePass::distribute_vars_ = false;
4534 bool DistributedTPURewritePass::allow_xla_spmd_partition_ = true;
4535 bool DistributedTPURewritePass::
4536 replicate_inputs_outputs_by_default_for_xla_spmd_ = false;
4537 bool DistributedTPURewritePass::
4538 enable_cross_replica_sharding_mirrored_variables_ = true;
4539 bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
4540 bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false;
4541
SetDistributedTpuRewritePassOptions(bool distribute_vars,bool allow_xla_spmd_partition,bool replicate_inputs_outputs_by_default_for_xla_spmd,bool enable_cross_replica_sharding_mirrored_variables,bool enable_automatic_model_parallelism,bool enable_xla_param_broadcast)4542 /*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
4543 bool distribute_vars, bool allow_xla_spmd_partition,
4544 bool replicate_inputs_outputs_by_default_for_xla_spmd,
4545 bool enable_cross_replica_sharding_mirrored_variables,
4546 bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast) {
4547 distribute_vars_ = distribute_vars;
4548 allow_xla_spmd_partition_ = allow_xla_spmd_partition;
4549 replicate_inputs_outputs_by_default_for_xla_spmd_ =
4550 replicate_inputs_outputs_by_default_for_xla_spmd;
4551 enable_cross_replica_sharding_mirrored_variables_ =
4552 enable_cross_replica_sharding_mirrored_variables;
4553 enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
4554 enable_xla_param_broadcast_ = enable_xla_param_broadcast;
4555 }
4556
4557 } // namespace tensorflow
4558