/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eval_const_tensor.h"

#include <deque>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

using shape_inference::InferenceContext;

namespace {

// Tries to infer tensor output based on the input shapes of the node. In some
// cases, the shapes of the inputs are sufficient for inferring the contents of
// the output tensor. For example, a Shape op with fully defined input shapes
// can have its output tensor inferred.
Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
                                             const ShapeRefiner& refiner,
                                             Tensor* output, bool* success) {
  *success = false;
  const Node* node = edge.src();
  InferenceContext* c = refiner.GetContext(node);
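  // The refiner only has an InferenceContext for nodes it has already
  // processed; without one we cannot reason about the input shapes.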
  if (c == nullptr) {
    return errors::FailedPrecondition("Node does not have context.");
  }

  if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::InvalidArgument(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Rank") {
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
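      // The number of elements is the product of all dimension sizes.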
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return Status::OK();
}

// Returns true if 'node' has a registered CPU kernel.
bool HasCpuKernel(const Node& node) {
  return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
                       /*kernel_class_name=*/nullptr)
      .ok();
}

// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts it into 'out_graph'. If the subgraph is statically computable,
// 'is_constant_graph' will be set to true.
Status ExtractConstantSubgraph(
    const Node& target_node, const ShapeRefiner& refiner,
    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
    bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;

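  // Bail out early (leaving *is_constant_graph == false) if the target node
  // itself can never be constant-folded: stateful ops, Merge nodes, and
  // placeholders either depend on runtime state or are fed by the user.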
  if (target_node.op_def().is_stateful()) {
    return Status::OK();
  }

  if (IsMerge(&target_node)) {
    return Status::OK();
  }

  if (target_node.type_string() == "PlaceholderWithDefault") {
    return Status::OK();
  }

  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel.
  if (!HasCpuKernel(target_node)) {
    return Status::OK();
  }

  // TODO(skyewm): should more of the filtering applied in input nodes below be
  // applied to target_node here?

  // Identify the possibly constant subgraph by recursively iterating backwards
  // through the inputs to 'target_node' until we either 1) find an already
  // existing input to our subgraph 'const_inputs', 2) discover our graph is
  // not constant, or 3) hit a root node.

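  // For each node already visited, track its copy in 'out_graph' and whether
  // its non-control input edges have been queued, so that a node's inputs are
  // expanded at most once.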
  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(&target_node);
  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
  old_to_new_and_recursed[&target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node.in_edges()) {
    // TODO(skyewm): control edges will be meaningful if/when we handle control
    // flow (e.g. constants in cond branches are triggered via control edges).
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant.
    if (current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in. In addition, control flow constructs may depend on control
    // edges which aren't handled by this method. Don't constant fold through
    // merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return Status::OK();
    }

    if (!HasCpuKernel(*current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant.
    if (current_node->num_inputs() == 0) {
      if (!current_node->IsConstant()) {
        // Generator node is not a constant, so subgraph is not
        // constant.
        *is_constant_graph = false;
        return Status::OK();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

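    // Tensors are identified as "<node name>:<output index>", matching the
    // naming convention used for feeds and fetches.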
    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        *current_edge, refiner, &tensor_inferred,
        &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already,
    // then add to the list of inputs to feed and do not recurse further.
    if (cached_values != nullptr) {
      auto it = cached_values->find(output_tensor_name);
      if (it != cached_values->end() &&
          const_inputs_added.count(output_tensor_name) == 0) {
        const_inputs->emplace_back(output_tensor_name, it->second);
        const_inputs_added.insert(output_tensor_name);
        continue;
      }
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }

  return Status::OK();
}

}  // namespace

Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
                              const OpRegistryInterface& ops,
                              int32 graph_def_version, bool* evaluated,
                              Tensor* result, GraphRunner* graph_runner,
                              std::unordered_map<string, Tensor>* cached_values,
                              int64 max_cached_value_size,
                              bool disable_constant_propagation) {
  *evaluated = false;
  const Node* src = tensor.node;

  // Simple case: the source node is a constant
  if (src->IsConstant()) {
    if (result->FromProto(src->def().attr().at("value").tensor())) {
      *evaluated = true;
      return Status::OK();
    }
  }

  if (disable_constant_propagation) {
    return Status::OK();
  }

  bool is_constant_graph = false;
  Graph subgraph(&ops);
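  // Stamp the subgraph with the caller's GraphDef producer version so its ops
  // are interpreted with the same semantics as the original graph.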
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version);
  subgraph.set_versions(versions);

  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
                                             &subgraph, &is_constant_graph,
                                             &const_inputs));
  if (!is_constant_graph) {
    return Status::OK();
  }
  const string output_tensor_name =
      strings::StrCat(src->name(), ":", tensor.index);
  std::vector<Tensor> outputs;

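  // If the caller did not supply a GraphRunner, create a temporary one that
  // lives only for the duration of this evaluation.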
  std::unique_ptr<GraphRunner> graph_runner_storage;
  if (graph_runner == nullptr) {
    // TODO(skyewm): Convert to std::make_unique when available.
    graph_runner_storage.reset(new GraphRunner(Env::Default()));
    graph_runner = graph_runner_storage.get();
  }

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If some kernels in the constant graph are not registered
  // in this process, GraphRunner::Run may fail, in which case
  // we cannot propagate constants, so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph.  As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (cached_values != nullptr &&
        outputs[0].TotalBytes() <= max_cached_value_size) {
      (*cached_values)[output_tensor_name] = outputs[0];
    }
  }
  return Status::OK();
}

}  // namespace tensorflow