1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
17 
18 #include <deque>
19 #include <map>
20 #include <unordered_map>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/node_hash_set.h"
24 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
25 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
26 #include "tensorflow/core/graph/algorithm.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/gtl/cleanup.h"
30 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
31 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
32 
33 namespace tensorflow {
34 namespace tpu {
35 
36 namespace {
37 
// Element value of the constant "default sharding state" tensor that is fed to
// the unshard ops as their key input: the empty string, i.e. no-op sharding.
constexpr char kDefaultShardingValue[] = {'\0'};
39 
FindEdgeConnecting(const Node * src,const Node * dst)40 const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
41   for (const auto e : src->out_edges()) {
42     if (e->dst()->name() == dst->name()) return &(*e);
43   }
44   return nullptr;
45 }
46 
47 // Contains TPUExecute node and its DT_RESOURCE input nodes that
48 // correspond to model weights.
49 struct ExecuteNodeInfo {
50   Node* execute_node;
51   std::vector<const Edge*> var_inputs;
52 };
53 
54 // Returns whether `node` is in `execute_nodes` or `(identity) -> execute`.
IsExecuteNodeOrIdentityToExecuteNode(const Graph & graph,const std::unordered_set<Node * > & loop_nodes,const absl::flat_hash_set<Node * > & execute_nodes,Node * node)55 bool IsExecuteNodeOrIdentityToExecuteNode(
56     const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
57     const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
58   if (execute_nodes.find(node) != execute_nodes.end()) return true;
59   if (loop_nodes.find(node) == loop_nodes.end()) return false;
60   if (node->IsNextIteration()) return true;
61   if (!node->IsIdentity()) return false;
62 
63   for (const Edge* e : node->out_edges()) {
64     if (e->IsControlEdge()) continue;
65 
66     Node* node = e->dst();
67     if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
68                                               node)) {
69       return false;
70     }
71   }
72 
73   return true;
74 }
75 
76 // From input node to the TPUExecute op, finds the corresponding Enter node
77 // by searching/traversing nodes in below pattern of nodes:
78 // Enter ----> (identity) --->  While body input
79 // Returns nullptr if the Enter node is not found.
FindEnterNodeFromTPUExecuteNodeInput(Node * input_node)80 xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
81   Node* node = input_node;
82   while (node->IsIdentity()) {
83     TF_RETURN_IF_ERROR(node->input_node(0, &node));
84   }
85 
86   if (node->IsEnter()) {
87     return node;
88   }
89   return nullptr;
90 }
91 
ResourceOnlyUsedForTPUExecuteInLoop(const Graph & graph,const std::unordered_set<Node * > & loop_nodes,const Node * enter_node,const absl::flat_hash_set<Node * > execute_nodes)92 xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
93     const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
94     const Node* enter_node, const absl::flat_hash_set<Node*> execute_nodes) {
95   for (const Edge* output_edge : enter_node->out_edges()) {
96     Node* output_node = output_edge->dst();
97     if (output_edge->IsControlEdge() || output_node->IsExit()) continue;
98 
99     // If output node is not execute node, it must be output node
100     // to the while loop body.
101     if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
102                                               output_node)) {
103       return false;
104     }
105   }
106   return true;
107 }
108 
109 // Given a TPUCompile node, find all TPUExecute nodes that executes the compiled
110 // program and its model weight variable inputs as well.
111 // TPUCompileMetadataProto of TPUCompile node must be reset to `new_metadata`
112 // if new reshard ops are added.
ExtractExecuteNodeInfo(const Node * compile_node,const Graph & graph,const std::unordered_set<Node * > & loop_nodes,std::vector<ExecuteNodeInfo> * execute_node_info,TPUCompileMetadataProto * new_metadata)113 Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
114                               const std::unordered_set<Node*>& loop_nodes,  // NOLINT
115                               std::vector<ExecuteNodeInfo>* execute_node_info,
116                               TPUCompileMetadataProto* new_metadata) {
117   string metadata_string;
118   TF_RETURN_IF_ERROR(
119       GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
120   new_metadata->ParsePartialFromString(metadata_string);
121   if (new_metadata->num_cores_per_replica() != 1) {
122     // We do not support model parallelism yet.
123     return Status::OK();
124   }
125 
126   execute_node_info->clear();
127   for (Node* node : compile_node->out_nodes()) {
128     if (node->type_string() == "TPUExecute") {
129       execute_node_info->push_back({node});
130     }
131   }
132   if (execute_node_info->empty()) {
133     return Status::OK();
134   }
135   TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
136       << "Number of replicas does not equal number of execute nodes: "
137       << new_metadata->num_replicas() << " vs " << execute_node_info->size();
138   DataTypeVector arg_types;
139   TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
140                                  "Targs", &arg_types));
141   for (int64 i = 0; i < arg_types.size(); ++i) {
142     if (arg_types[i] != DT_RESOURCE) {
143       continue;
144     }
145     const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
146     if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
147         sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
148       continue;
149     }
150     std::vector<const Edge*> edges(execute_node_info->size());
151     bool is_supported = true;
152     std::unordered_map<Node*, absl::flat_hash_set<Node*>>
153         enter_to_execute_nodes;
154     for (int64 j = 0; j < edges.size(); ++j) {
155       auto execute = (*execute_node_info)[j].execute_node;
156       TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
157       TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
158                    arg_types[i])
159           << "Execute op has an unexpected input type.";
160       // Traverse backwards to find the Enter node from which the input is
161       // passed.
162       // This makes sure that we are checking the usages of all potential
163       // aliases of the input node as well.
164       TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
165                                                edges[j]->src()));
166       if (enter_node == nullptr) {
167         is_supported = false;
168         enter_to_execute_nodes.clear();
169         break;
170       }
171       enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
172     }
173 
174     for (const auto& it : enter_to_execute_nodes) {
175       // Size of execute nodes should be either 1 (per-replica variables) or
176       // num_replicas (distributed variables).
177       if ((it.second.size() != 1) &&
178           (it.second.size() != new_metadata->num_replicas())) {
179         is_supported = false;
180         break;
181       }
182       TF_ASSIGN_OR_RETURN(bool no_other_use,
183                           ResourceOnlyUsedForTPUExecuteInLoop(
184                               graph, loop_nodes, it.first, it.second));
185       if (!no_other_use) {
186         is_supported = false;
187         break;
188       }
189     }
190 
191     // Add the variable input edges only when they are supported for all
192     // executes.
193     if (is_supported) {
194       for (int64 j = 0; j < edges.size(); ++j) {
195         (*execute_node_info)[j].var_inputs.push_back(edges[j]);
196       }
197       new_metadata->mutable_args(i)->set_enable_xla_sharding(
198           TPUCompileMetadataProto::Arg::ALLOWED);
199     }
200   }
201 
202   int64 total = 0;
203   for (const auto& a : new_metadata->args()) {
204     if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
205       total++;
206     }
207   }
208   TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
209       << " total " << total << " var_inputs "
210       << (*execute_node_info)[0].var_inputs.size();
211   if (total == 0) {
212     // We don't need to process anything if no input is added.
213     execute_node_info->clear();
214   }
215   return Status::OK();
216 }
217 
IsTPUCompileOp(const Node & n)218 bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }
219 
FindTPUCompileNodes(const std::string * current_function_name,const AttrValueMap * current_function_attr,const std::unordered_map<string,WhileLoopFrame> & frames,std::vector<HostTrainingLoopInfo> * host_training_loops_info)220 void FindTPUCompileNodes(
221     const std::string* current_function_name,
222     const AttrValueMap* current_function_attr,
223     const std::unordered_map<string, WhileLoopFrame>& frames,
224     std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
225   // Adds frames with no children (i.e., the innermost frames) to a worklist.
226   std::deque<const WhileLoopFrame*> worklist;
227 
228   for (auto& frame : frames) {
229     if (frame.second.num_children == 0) {
230       worklist.push_back(&frame.second);
231     }
232   }
233 
234   // Check TPUCompile node from the innermost while loop to the outermost
235   // while loop.
236   while (!worklist.empty()) {
237     const WhileLoopFrame* frame = worklist.front();
238     worklist.pop_front();
239 
240     for (const auto& n : frame->nodes) {
241       if (!IsTPUCompileOp(*n)) continue;
242 
243       HostTrainingLoopInfo host_training_loop_info;
244       host_training_loop_info.compile_node_name = n->name();
245       host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
246       host_training_loop_info.while_loop_name = frame->name;
247 
248       for (const auto arg : frame->args) {
249         LoopArgInfo arg_info;
250         arg_info.enter_node_name = arg.enter->name();
251         if (arg.exit) arg_info.exit_node_name = arg.exit->name();
252 
253         host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
254       }
255       host_training_loop_info.loop_nodes = frame->nodes;
256 
257       if (current_function_name) {
258         host_training_loop_info.encapsulating_function_name =
259             *current_function_name;
260       }
261       if (current_function_attr) {
262         host_training_loop_info.encapsulating_function_attrs =
263             *current_function_attr;
264       }
265 
266       host_training_loops_info->emplace_back(
267           std::move(host_training_loop_info));
268     }
269 
270     // If the parent has no remaining children, add it to the worklist.
271     --frame->parent->num_children;
272     if (frame->parent->num_children == 0) {
273       worklist.push_back(frame->parent);
274     }
275   }
276 }
277 
278 // From while loop cond node, finds all loop exit nodes by searching/traversing
279 // nodes in below pattern of nodes:
280 // LoopCond -----> Switch -----> Exit
FindLoopExitNodes(const Node & loop_cond)281 std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
282   std::vector<Node*> loop_exit_nodes;
283   for (const auto e_cond : loop_cond.out_edges()) {
284     if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
285     auto switch_node = e_cond->dst();
286 
287     for (const auto e_switch : switch_node->out_edges()) {
288       if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;
289 
290       loop_exit_nodes.push_back(e_switch->dst());
291     }
292   }
293   return loop_exit_nodes;
294 }
295 
296 // Find any one of switch nodes in the while loop by traversing the graph
297 // from while loop condition node.
GetLoopSwitchNode(const Node & loop_cond_node)298 xla::StatusOr<Node*> GetLoopSwitchNode(const Node& loop_cond_node) {
299   Node* loop_switch_node;
300   for (auto n : loop_cond_node.out_nodes()) {
301     if (n->IsSwitch()) {
302       loop_switch_node = n;
303       break;
304     }
305   }
306 
307   TF_RET_CHECK(loop_switch_node->IsSwitch())
308       << "Unable to find any switch nodes.";
309   return loop_switch_node;
310 }
311 
312 // Returns or creates a node in that is executed before each loop iteration
313 // in the while loop.
GetOrCreateBeforeEachIterationNode(Graph * graph,Node * loop_switch_node,Node ** node_out)314 Status GetOrCreateBeforeEachIterationNode(Graph* graph, Node* loop_switch_node,
315                                           Node** node_out) {
316   // If while loop switch node already has a outgoing data to true brach
317   // of the switch op, then reuse that node.
318   for (const auto out_edge : loop_switch_node->out_edges()) {
319     if (out_edge->src_output() == 1) {
320       *node_out = out_edge->dst();
321       return Status::OK();
322     }
323   }
324 
325   // Create Identity node that represents execution at every loop iteration.
326   NodeDef at_loop_iteration_nodedef;
327   at_loop_iteration_nodedef.set_op("Identity");
328   DataType dtype;
329   TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
330 
331   AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
332   at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
333       "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));
334 
335   Status status;
336   Node* at_loop_iteration_node =
337       graph->AddNode(at_loop_iteration_nodedef, &status);
338   TF_RETURN_IF_ERROR(status);
339 
340   graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
341   *node_out = at_loop_iteration_node;
342   return Status::OK();
343 }
344 
345 // Injects NoOp node in that is executed after the very last iteration
346 // of the while loop but before the while loop exit node.
AddNoOpAfterLastIteration(Graph * graph,Node * loop_switch_node,Node ** node_out)347 Status AddNoOpAfterLastIteration(Graph* graph, Node* loop_switch_node,
348                                  Node** node_out) {
349   // Find the exit node from loop switch node.
350   Node* exit_node;
351   for (const auto out_node : loop_switch_node->out_nodes()) {
352     if (out_node->IsExit()) {
353       exit_node = out_node;
354       break;
355     }
356   }
357 
358   TF_RET_CHECK(exit_node != nullptr)
359       << "Cannot find exit node connected to switch node :"
360       << loop_switch_node->name();
361 
362   // Create NoOp that represents execution at the end of while loop
363   // last iteration.
364   NodeDef after_last_loop_iteration;
365   after_last_loop_iteration.set_op("Identity");
366   DataType dtype;
367   TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
368 
369   AddNodeAttr("T", dtype, &after_last_loop_iteration);
370   after_last_loop_iteration.set_name(graph->NewName(strings::StrCat(
371       "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
372 
373   Status status;
374   Node* after_last_iteration_node =
375       graph->AddNode(after_last_loop_iteration, &status);
376   TF_RETURN_IF_ERROR(status);
377 
378   // Newly created node must be executed once after last iteration of the while
379   // loop and before while loop exits.
380   graph->AddEdge(loop_switch_node, 0, after_last_iteration_node, 0);
381   graph->AddControlEdge(after_last_iteration_node, exit_node);
382   *node_out = after_last_iteration_node;
383   return Status::OK();
384 }
385 
386 }  // namespace
387 
DetectHostTrainingLoop(const std::string * current_function_name,const AttrValueMap * current_function_attr,const FunctionLibraryDefinition * library,Graph * graph,FunctionLibraryRuntime * flr,std::vector<HostTrainingLoopInfo> * host_training_loops_info)388 Status DetectHostTrainingLoop(
389     const std::string* current_function_name,
390     const AttrValueMap* current_function_attr,
391     const FunctionLibraryDefinition* library, Graph* graph,
392     FunctionLibraryRuntime* flr,
393     std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
394   std::vector<AssociatedFunctionInfo> associated_function_list;
395   for (const auto* n : graph->nodes()) {
396     const auto associated_functions = GetAssociatedFunctions(*n, library);
397     if (associated_functions.empty()) continue;
398 
399     associated_function_list.insert(associated_function_list.end(),
400                                     associated_functions.begin(),
401                                     associated_functions.end());
402   }
403 
404   Status ret_status = Status::OK();
405   for (const auto& function : associated_function_list) {
406     if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;
407 
408     // Convert the function to Graph.
409     FunctionLibraryRuntime::Handle handle;
410     TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
411                                         AttrSlice(&function.attrs()), &handle));
412     auto cleanup_handle = gtl::MakeCleanup([&]() {
413       auto s = flr->ReleaseHandle(handle);
414       if (!s.ok()) {
415         ret_status.Update(s);
416       }
417     });
418     const FunctionBody* body = flr->GetFunctionBody(handle);
419     Graph* function_graph = body->graph;
420     TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
421         &function.func_name(), &function.attrs(), library, function_graph, flr,
422         host_training_loops_info));
423   }
424 
425   // BuildControlFlowInfo() requires that the graph's source node is connected
426   // to all source nodes in the graph. Many graphs violate this invariant.
427   // As so, add edges to source/sink nodes so that this invariant is kept.
428   FixupSourceAndSinkEdges(graph);
429   std::vector<ControlFlowInfo> cf_info;
430   TF_RETURN_IF_ERROR(
431       BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));
432 
433   std::unordered_map<string, WhileLoopFrame> frames;
434   TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
435   FindTPUCompileNodes(current_function_name, current_function_attr, frames,
436                       host_training_loops_info);
437   return ret_status;
438 }
439 
AddReshardOp(Graph * graph,const HostTrainingLoopInfo & host_loop_info)440 Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
441   const auto& compile_node_name = host_loop_info.compile_node_name;
442   const auto node_name_map = graph->BuildNodeNameIndex();
443   const auto node_it = node_name_map.find(compile_node_name);
444   TF_RET_CHECK(node_it != node_name_map.end())
445       << "Unable to find compile node : " << compile_node_name;
446 
447   const auto compile_node = node_it->second;
448   std::vector<ExecuteNodeInfo> execute_nodes_info;
449 
450   Status status;
451   TPUCompileMetadataProto metadata;
452   status =
453       ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
454                              &execute_nodes_info, &metadata);
455   if (!status.ok()) {
456     LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
457                   "skipping host loop optimization. Status: "
458                << status.ToString();
459     return Status::OK();
460   }
461 
462   if (execute_nodes_info.empty()) {
463     return Status::OK();
464   }
465 
466   // Update the TPUCompileMetadata such that sharding config of the
467   // sharded resource variable inputs is set to ALLOWED instead of
468   // TENTATIVE.
469   string new_metadata_string;
470   metadata.SerializeToString(&new_metadata_string);
471   compile_node->ClearAttr("metadata");
472   compile_node->AddAttr("metadata", new_metadata_string);
473 
474   // Unsharding of the model weight variables must happen only at the very
475   // last loop iteration. As so, add while loop condition predicate as an
476   // input to the sharding switch node. If loop condition is true, we do not
477   // unshard.
478   const auto& cond_node_name = host_loop_info.loop_cond_node_name;
479   auto loop_cond_node_it = node_name_map.find(cond_node_name);
480   TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
481       << "Cannot find loop condition node : " << cond_node_name;
482   auto* loop_condition_node = loop_cond_node_it->second;
483 
484   // In order to make sure that shard/unshard operations are invoked
485   // at the start of every loop body and at the end of last iteration
486   // of the loop, respectively, traverse the graph and find a switch node
487   // of the host training loop.
488   TF_ASSIGN_OR_RETURN(Node * switch_node,
489                       GetLoopSwitchNode(*loop_condition_node));
490 
491   Node* after_last_iteration_node;
492   TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(graph, switch_node,
493                                                &after_last_iteration_node));
494 
495   Node* before_loop_iteration_node;
496   TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
497       graph, switch_node, &before_loop_iteration_node));
498 
499   // Create const op that represents default sharding value
500   // (i.e. no-op sharding).
501   NodeDef default_sharding;
502   default_sharding.set_op("Const");
503   default_sharding.set_name(graph->NewName(strings::StrCat(
504       "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
505   AddNodeAttr("dtype", DT_STRING, &default_sharding);
506 
507   Tensor t(DT_STRING, {3});
508   t.vec<tstring>()(0) = kDefaultShardingValue;
509   t.vec<tstring>()(1) = kDefaultShardingValue;
510   t.vec<tstring>()(2) = kDefaultShardingValue;
511   t.AsProtoTensorContent(
512       (*default_sharding.mutable_attr())["value"].mutable_tensor());
513 
514   Node* default_sharding_node = graph->AddNode(default_sharding, &status);
515   TF_RETURN_IF_ERROR(status);
516   // Add control edge between loop condition to make sure that
517   // default_sharding_node node is inside the while loop frame.
518   graph->AddControlEdge(loop_condition_node, default_sharding_node);
519 
520   // Build a no-op node used to add control edges after unshard nodes.
521   NodeDef after_unshard;
522   after_unshard.set_op("NoOp");
523   after_unshard.set_name(graph->NewName(strings::StrCat(
524       "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
525   auto after_unshard_node = graph->AddNode(after_unshard, &status);
526   TF_RETURN_IF_ERROR(status);
527 
528   for (auto info : execute_nodes_info) {
529     auto execute_node = info.execute_node;
530     // Create Reshard op that optionally shards model weight variables
531     // prior to program execution.
532     NodeDef reshard_node_def;
533     reshard_node_def.set_name(graph->NewName(strings::StrCat(
534         "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
535     reshard_node_def.set_op("TPUReshardVariables");
536     AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
537                 &reshard_node_def);
538     Node* reshard_op_node = graph->AddNode(reshard_node_def, &status);
539     if (!status.ok()) return status;
540 
541     reshard_op_node->set_assigned_device_name(
542         execute_node->assigned_device_name());
543 
544     // Reshard op must execute at every loop iteration prior to
545     // TPUExecute node.
546     graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
547     graph->AddControlEdge(reshard_op_node, execute_node);
548 
549     for (int i = 0; i < info.var_inputs.size(); ++i) {
550       const auto variable_edge = info.var_inputs[i];
551       graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
552                      reshard_op_node, i);
553     }
554 
555     const int new_key_input = info.var_inputs.size();
556     // Add program input edge from the compiler(i.e. compilation key).
557     const auto compilation_key_edge =
558         FindEdgeConnecting(compile_node, execute_node);
559     graph->AddEdge(compile_node, compilation_key_edge->src_output(),
560                    reshard_op_node, new_key_input);
561 
562     // Create VarHandleOp to store sharding state. Sharding state holds string
563     // compilation key that identifies whether the graph is re-compiled and the
564     // variables need to be sharded again.
565     NodeDef var_handle_def;
566     var_handle_def.set_op("VarHandleOp");
567     var_handle_def.set_name(graph->NewName(strings::StrCat(
568         "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
569     AddNodeAttr("dtype", DT_STRING, &var_handle_def);
570     AddNodeAttr("shape", TensorShape({}), &var_handle_def);
571     Node* var_handle_node = graph->AddNode(var_handle_def, &status);
572     if (!status.ok()) return status;
573 
574     // Add control edge between `var_handle_def` node and while loop
575     // loop condition so that `var_handle_def` is inside the same while loop
576     // frame.
577     // TODO(hongjunchoi): Consider adding control edge from another node--such
578     // as input control node.
579     graph->AddControlEdge(loop_condition_node, var_handle_node);
580 
581     // Connect data edge between var handle op and reshard op.
582     const int format_state_input = new_key_input + 1;
583     graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);
584 
585     // Create Reshard op that represents unsharding after TPUExecute.
586     NodeDef unshard_node_def;
587     unshard_node_def.set_name(graph->NewName(strings::StrCat(
588         "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
589     unshard_node_def.set_op("TPUReshardVariables");
590     AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
591                 &unshard_node_def);
592     Node* unshard_op_node = graph->AddNode(unshard_node_def, &status);
593     TF_RETURN_IF_ERROR(status);
594 
595     unshard_op_node->set_assigned_device_name(
596         execute_node->assigned_device_name());
597 
598     for (int i = 0; i < info.var_inputs.size(); ++i) {
599       const auto variable_edge = info.var_inputs[i];
600       // Connect model weight resource variables to unshard op. Since unshard op
601       // must be only invoked after the very last loop iteration, for each while
602       // loop inputs, we traverse backwards to find the switch node of the host
603       // training loop and connect `output_false` field of the switch node with
604       // unshard op.
605       TF_ASSIGN_OR_RETURN(
606           Node * enter_node,
607           FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
608       graph->AddEdge(enter_node, 0, unshard_op_node, i);
609     }
610 
611     // Add control dependency before/after unshard node and the control nodes.
612     graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
613     graph->AddControlEdge(unshard_op_node, after_unshard_node);
614 
615     graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);
616 
617     // Add data edge from sharding state var handle op to unshard op.
618     graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
619   }
620   // Add control dependency from after_unshard_node to all exits nodes. This is
621   // to make sure that the unshard ops will be executed as long as any of the
622   // exits are used.
623   for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
624     graph->AddControlEdge(after_unshard_node, exit);
625   }
626   return Status::OK();
627 }
628 
629 }  // namespace tpu
630 }  // namespace tensorflow
631