1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
17 #define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
18 
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/graph/graph.h"
27 
28 namespace tensorflow {
29 namespace tpu {
30 
31 struct LoopArgInfo {
32   std::string enter_node_name;
33   // Exit nodes are optional for loop invariant while loop args.
34   absl::optional<std::string> exit_node_name;
35 };
36 
37 struct HostTrainingLoopInfo {
38   // Name and attribute information about the function in which
39   // host training loop is included. If host training loop is not
40   // inside a function call, then `function_name` and `function_attrs`
41   // are nullopt.
42   absl::optional<std::string> encapsulating_function_name;
43   absl::optional<AttrValueMap> encapsulating_function_attrs;
44 
45   // TPU Compile node as within a host training loop.
46   std::string compile_node_name;
47 
48   // Name of the while loop in which TPU compile op is located.
49   std::string while_loop_name;
50 
51   // Name of the node that represents loop condition.
52   std::string loop_cond_node_name;
53 
54   // Exit and Enter node names for each loop arguments.
55   std::vector<LoopArgInfo> loop_arguments;
56 
57   std::unordered_set<Node*> loop_nodes;  // NOLINT
58 };
59 
60 // Walks through the `graph`, recursively if functional nodes exist, and
61 // identifies all host training loops. Host training loops are the inner
62 // most while loops that encapsulates TPUCompileOp node. This would be
63 // later used/analyzed to inroduce host loop specific optimizations such
64 // as adding sharded weight update.
65 Status DetectHostTrainingLoop(
66     const std::string* current_function_name,
67     const AttrValueMap* current_function_attr,
68     const FunctionLibraryDefinition* library, Graph* graph,
69     FunctionLibraryRuntime* flr,
70     std::vector<HostTrainingLoopInfo>* host_training_loops_info);
71 
72 // Injects VariableReshardOps to before and after TPUExecute op inside
73 // host training loop body. This effectively applies sharded weight update
74 // on model weight variables.
75 Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info);
76 
77 }  // namespace tpu
78 }  // namespace tensorflow
79 
80 #endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
81