/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eval_const_tensor.h"

#include <deque>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

using shape_inference::InferenceContext;

namespace {

// Tries to infer a node's output tensor from the shapes of its inputs. In
// some cases, the input shapes alone are sufficient to determine the contents
// of the output tensor. For example, a Shape op with fully defined input
// shapes can have its output tensor inferred; the same holds for Rank and
// Size.
Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
                                             const ShapeRefiner& refiner,
                                             Tensor* output, bool* success) {
  *success = false;
  const Node* node = edge.src();
  InferenceContext* c = refiner.GetContext(node);
  if (c == nullptr) {
    return errors::FailedPrecondition("Node does not have context.");
  }

  if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::InvalidArgument(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Rank") {
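    // The rank of the input can be known even when individual dimensions are
    // not; Rank always produces a scalar int32.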
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
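    // Size is the product of all dimensions, so every dimension of the input
    // must be fully defined.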
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
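  // For any other op type, *success remains false and the caller falls back
  // to evaluating the node as part of a constant subgraph.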
  return Status::OK();
}

// Returns true if 'node' has a registered CPU kernel.
bool HasCpuKernel(const Node& node) {
  return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
                       /*kernel_class_name=*/nullptr)
      .ok();
}

// Extracts the statically computable subgraph ending at 'target_node' and
// inserts it into 'out_graph'. If the subgraph is statically computable,
// 'is_constant_graph' is set to true.
Status ExtractConstantSubgraph(
    const Node& target_node, const ShapeRefiner& refiner,
    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
    bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;

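  // The target node itself must be foldable: stateful ops, Merge,
  // PlaceholderWithDefault, and ops without a registered CPU kernel are never
  // constant-folded, so leave *is_constant_graph false for them.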
  if (target_node.op_def().is_stateful()) {
    return Status::OK();
  }

  if (IsMerge(&target_node)) {
    return Status::OK();
  }

  if (target_node.type_string() == "PlaceholderWithDefault") {
    return Status::OK();
  }

  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel.
  if (!HasCpuKernel(target_node)) {
    return Status::OK();
  }

  // TODO(skyewm): should more of the filtering applied in input nodes below be
  // applied to target_node here?

  // Identify the possibly constant subgraph by iterating backwards through the
  // inputs to 'target_node' until we either 1) find an input already added to
  // 'const_inputs', 2) discover the graph is not constant, or 3) hit a root
  // node.

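  // Maps a node in the original graph to its copy in 'out_graph', and records
  // whether the original node's input edges have already been enqueued for
  // visiting.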
  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(&target_node);
  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
  old_to_new_and_recursed[&target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node.in_edges()) {
    // TODO(skyewm): control edges will be meaningful if/when we handle control
    // flow (e.g. constants in cond branches are triggered via control edges).
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

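  // Assume the subgraph is constant until a non-constant node is found below.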
  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant.
    if (current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in. In addition, control flow constructs may depend on control
    // edges which aren't handled by this method. Don't constant fold through
    // merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return Status::OK();
    }

    if (!HasCpuKernel(*current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant.
    if (current_node->num_inputs() == 0) {
      if (!current_node->IsConstant()) {
        // Generator node is not a constant, so subgraph is not
        // constant.
        *is_constant_graph = false;
        return Status::OK();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        *current_edge, refiner, &tensor_inferred,
        &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already,
    // then add to the list of inputs to feed and do not recurse further.
    if (cached_values != nullptr) {
      auto it = cached_values->find(output_tensor_name);
      if (it != cached_values->end() &&
          const_inputs_added.count(output_tensor_name) == 0) {
        const_inputs->emplace_back(output_tensor_name, it->second);
        const_inputs_added.insert(output_tensor_name);
        continue;
      }
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }

  return Status::OK();
}

}  // namespace

Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
                              const OpRegistryInterface& ops,
                              int32 graph_def_version, bool* evaluated,
                              Tensor* result, GraphRunner* graph_runner,
                              std::unordered_map<string, Tensor>* cached_values,
                              int64 max_cached_value_size,
                              bool disable_constant_propagation) {
  *evaluated = false;
  const Node* src = tensor.node;

  // Simple case: the source node is a constant.
  if (src->IsConstant()) {
    if (result->FromProto(src->def().attr().at("value").tensor())) {
      *evaluated = true;
      return Status::OK();
    }
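    // If FromProto fails, fall through and try evaluating the node via a
    // constant subgraph below.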
  }

  if (disable_constant_propagation) {
    return Status::OK();
  }

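  // Build a scratch graph for the constant subgraph, carrying over the
  // caller's GraphDef producer version.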
  bool is_constant_graph = false;
  Graph subgraph(&ops);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version);
  subgraph.set_versions(versions);

  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
                                             &subgraph, &is_constant_graph,
                                             &const_inputs));
  if (!is_constant_graph) {
    return Status::OK();
  }
  const string output_tensor_name =
      strings::StrCat(src->name(), ":", tensor.index);
  std::vector<Tensor> outputs;

  std::unique_ptr<GraphRunner> graph_runner_storage;
  if (graph_runner == nullptr) {
    // TODO(skyewm): Convert to std::make_unique when available.
    graph_runner_storage.reset(new GraphRunner(Env::Default()));
    graph_runner = graph_runner_storage.get();
  }

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // GraphRunner::Run may fail if some kernels in the constant graph are not
  // registered in the process, in which case we cannot propagate constants,
  // so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph. As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (cached_values != nullptr &&
        outputs[0].TotalBytes() <= max_cached_value_size) {
      (*cached_values)[output_tensor_name] = outputs[0];
    }
  }
  return Status::OK();
}

}  // namespace tensorflow