1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/const_analysis.h"
17 
18 #include <unordered_map>
19 #include <unordered_set>
20 
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 
31 namespace tensorflow {
32 
33 Status GetCompileTimeConstInputs(const Node* node,
34                                  std::vector<int>* const_input_idxs,
35                                  FunctionLibraryRuntime* flib_runtime);
36 
37 // Backwards dataflow analysis that finds arguments to a graph that must be
38 // compile-time constants.
BackwardsConstAnalysis(const Graph & g,std::vector<bool> * compile_time_const_arg_indices,std::vector<bool> * compile_time_const_nodes,FunctionLibraryRuntime * flib_runtime,std::function<bool (const Edge &)> edge_filter)39 Status BackwardsConstAnalysis(const Graph& g,
40                               std::vector<bool>* compile_time_const_arg_indices,
41                               std::vector<bool>* compile_time_const_nodes,
42                               FunctionLibraryRuntime* flib_runtime,
43                               std::function<bool(const Edge&)> edge_filter) {
44   std::vector<bool> compile_time_const_nodes_impl;
45   if (compile_time_const_nodes) {
46     CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
47   } else {
48     compile_time_const_nodes_impl.resize(g.num_node_ids());
49     compile_time_const_nodes = &compile_time_const_nodes_impl;
50   }
51 
52   Status status;
53   auto visit = [&](Node* node) {
54     if (!status.ok()) return;
55 
56     // If this is a metadata-only op, don't propagate the const requirement.
57     if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
58       return;
59     }
60 
61     // If this node must be const, and it isn't a metadata op, then all of its
62     // parents must be const.
63     if ((*compile_time_const_nodes)[node->id()]) {
64       if (node->type_string() == "_Arg") {
65         int index;
66         status = GetNodeAttr(node->attrs(), "index", &index);
67         if (!status.ok()) return;
68         if (compile_time_const_arg_indices) {
69           (*compile_time_const_arg_indices)[index] = true;
70         }
71         return;
72       }
73       for (const Edge* pred : node->in_edges()) {
74         if (!pred->IsControlEdge() && edge_filter(*pred)) {
75           // If the src node of the `pred` is an IdentityN do not mark it as a
76           // compile-time const. Only mark the corresponding input to the
77           // IdentityN node as a const.
78           // Note: XLA IdentityN op simply forwards its inputs so this is safe.
79           while (edge_filter(*pred) &&
80                  pred->src()->type_string() == "IdentityN") {
81             status = pred->src()->input_edge(pred->src_output(), &pred);
82             if (!status.ok()) return;
83           }
84           if (edge_filter(*pred)) {
85             (*compile_time_const_nodes)[pred->src()->id()] = true;
86           }
87         }
88       }
89       return;
90     }
91 
92     // Mark any compile-time constant operator arguments as const.
93     std::vector<int> const_input_idxs;
94     status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);
95 
96     if (!status.ok()) {
97       return;
98     }
99 
100     for (Edge const* edge : node->in_edges()) {
101       if (!edge->IsControlEdge() &&
102           absl::c_binary_search(const_input_idxs, edge->dst_input()) &&
103           edge_filter(*edge)) {
104         // Do not mark IdentityN nodes as compile-time const.
105         // If the src node of the `pred` is an IdentityN do not mark it as a
106         // compile-time const. Only mark the corresponding input to the
107         // IdentityN node as a const.
108         // Note: XLA IdentityN op simply forwards its inputs so this is safe.
109         while (edge_filter(*edge) &&
110                edge->src()->type_string() == "IdentityN") {
111           status = edge->src()->input_edge(edge->src_output(), &edge);
112           if (!status.ok()) return;
113         }
114         if (edge_filter(*edge)) {
115           (*compile_time_const_nodes)[edge->src()->id()] = true;
116         }
117       }
118     }
119   };
120 
121   // Post-order traversal visits nodes in reverse topological order for an
122   // acyclic graph.
123   DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
124       [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
125   return status;
126 }
127 
GetCompileTimeConstInputs(const Node * node,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)128 Status GetCompileTimeConstInputs(const Node* node,
129                                  std::vector<int>* const_input_idxs,
130                                  FunctionLibraryRuntime* flib_runtime) {
131   if (node->type_string() != "While") {
132     return XlaOpRegistry::CompileTimeConstantInputs(node->def(), node->op_def(),
133                                                     const_input_idxs);
134   }
135   // For While nodes, recurse into the body and cond graphs.
136   // TODO(b/124403063): Implement similar functionality for cond nodes and other
137   // functional ops.
138   NameAttrList cond_function;
139   TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "cond", &cond_function));
140   NameAttrList body_function;
141   TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "body", &body_function));
142   FunctionLibraryRuntime::Handle cond_handle;
143   FunctionLibraryRuntime::Handle body_handle;
144   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
145       cond_function.name(), AttrSlice(&cond_function.attr()), &cond_handle));
146   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
147       body_function.name(), AttrSlice(&body_function.attr()), &body_handle));
148   const FunctionBody* fcond = flib_runtime->GetFunctionBody(cond_handle);
149   const FunctionBody* fbody = flib_runtime->GetFunctionBody(body_handle);
150   TF_RET_CHECK(fcond);
151   TF_RET_CHECK(fbody);
152   int num_inputs = fbody->fdef.signature().input_arg_size();
153 
154   // Stores which of the loop inputs are expected to be compile time constants.
155   std::vector<bool> compile_time_const_arg_indices(num_inputs);
156   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
157       *(fcond->graph), &compile_time_const_arg_indices,
158       /*compile_time_const_nodes=*/nullptr, flib_runtime));
159   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
160       *(fbody->graph), &compile_time_const_arg_indices,
161       /*compile_time_const_nodes=*/nullptr, flib_runtime));
162   for (int i = 0; i < num_inputs; i++) {
163     if (compile_time_const_arg_indices[i]) {
164       // Check that this input is actually a loop invariant.
165       // NOTE(srbs): Ideally this should raise an error if the loop body
166       // requires the input at this index to be a compile time const but it is
167       // not a loop invariant. However, that causes problems because const
168       // analysis is performed for the entire graph (in the
169       // MarkForCompilationPass for example) and not just for the ops
170       // that will actually be run using XLA kernels. So we silently return here
171       // and let the error be raised during the actual compilation of the
172       // XLA graph.
173       Node* arg_i = fbody->arg_nodes[i];
174       Node* ret_i = fbody->ret_nodes[i];
175       const Node* ret_i_input_0;
176       TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0));
177       if (ret_i_input_0->id() == arg_i->id()) {
178         const_input_idxs->push_back(i);
179       }
180     }
181   }
182   return Status::OK();
183 }
184 
185 }  // namespace tensorflow
186